diff --git a/httpclient/client.go b/httpclient/client.go index 73805a5..05df790 100644 --- a/httpclient/client.go +++ b/httpclient/client.go @@ -15,7 +15,7 @@ type Client struct { client heimdall.Doer retrier heimdall.Retriable plugins []heimdall.Plugin - timeout time.Duration + timeout *time.Duration retryCount int } @@ -29,7 +29,7 @@ var _ heimdall.Client = (*Client)(nil) // NewClient returns a new instance of http Client func NewClient(opts ...Option) *Client { client := Client{ - timeout: defaultHTTPTimeout, + client: &http.Client{Timeout: defaultHTTPTimeout}, retryCount: defaultRetryCount, retrier: heimdall.NewNoRetrier(), } @@ -38,11 +38,7 @@ func NewClient(opts ...Option) *Client { opt(&client) } - if client.client == nil { - client.client = &http.Client{ - Timeout: client.timeout, - } - } + client.updateHTTPTimeout() return &client } @@ -207,3 +203,13 @@ func (c *Client) reportRequestEnd(request *http.Request, response *http.Response plugin.OnRequestEnd(request, response) } } + +func (c *Client) updateHTTPTimeout() { + if c.timeout == nil { + return + } + + if client, ok := c.client.(*http.Client); ok && client != nil { + client.Timeout = *c.timeout + } +} diff --git a/httpclient/options.go b/httpclient/options.go index a43cf02..fd525a2 100644 --- a/httpclient/options.go +++ b/httpclient/options.go @@ -9,10 +9,12 @@ import ( // Option represents the client options type Option func(*Client) -// WithHTTPTimeout sets hystrix timeout +// WithHTTPTimeout sets timeout for http.Client func WithHTTPTimeout(timeout time.Duration) Option { return func(c *Client) { - c.timeout = timeout + c.timeout = &timeout + + c.updateHTTPTimeout() // hystrix.WithHTTPTimeout relies on this } } @@ -34,5 +36,7 @@ func WithRetrier(retrier heimdall.Retriable) Option { func WithHTTPClient(client heimdall.Doer) Option { return func(c *Client) { c.client = client + + c.updateHTTPTimeout() // hystrix.WithHTTPTimeout relies on this } } diff --git a/httpclient/options_test.go b/httpclient/options_test.go index b71d97d..547eb3c 100644 --- a/httpclient/options_test.go +++ b/httpclient/options_test.go @@ -29,11 +29,60 @@ func TestOptionsAreSet(t *testing.T) { ) assert.Equal(t, client, c.client) - assert.Equal(t, httpTimeout, c.timeout) + assert.NotEqual(t, httpTimeout, client.client.Timeout) // can't override custom implementation + assert.Equal(t, httpTimeout, *c.timeout) assert.Equal(t, retrier, c.retrier) assert.Equal(t, noOfRetries, c.retryCount) } +func TestWithClientWihhoutHTTPTimeoutShouldNotOverrideUserHTTPClientTimeout(t *testing.T) { + t.Parallel() + + client := &http.Client{Timeout: 25 * time.Millisecond} + + c := NewClient( + WithHTTPClient(client), + ) + + assert.Equal(t, client, c.client) + assert.Equal(t, 25*time.Millisecond, client.Timeout) // overrides user provided *http.Client + assert.Nil(t, c.timeout) +} + +func TestWithHTTPTimeoutOverridesUserHTTPClientTimeout(t *testing.T) { + t.Parallel() + + httpTimeout := 10 * time.Second + + client := &http.Client{Timeout: 25 * time.Millisecond} + + c := NewClient( + WithHTTPClient(client), + WithHTTPTimeout(httpTimeout), + ) + + assert.Equal(t, client, c.client) + assert.Equal(t, httpTimeout, client.Timeout) // overrides user provided *http.Client + assert.Equal(t, httpTimeout, *c.timeout) +} + +func TestWithHTTPTimeoutOverridesUserHTTPClientTimeout_InverseSeq(t *testing.T) { + t.Parallel() + + httpTimeout := 10 * time.Second + + client := &http.Client{Timeout: 25 * time.Millisecond} + + c := NewClient( + WithHTTPTimeout(httpTimeout), + WithHTTPClient(client), + ) + + assert.Equal(t, client, c.client) + assert.Equal(t, httpTimeout, client.Timeout) // overrides user provided *http.Client + assert.Equal(t, httpTimeout, *c.timeout) +} + func TestOptionsHaveDefaults(t *testing.T) { t.Parallel() @@ -45,7 +94,10 @@ func TestOptionsHaveDefaults(t *testing.T) { c := NewClient() assert.Equal(t, http.DefaultClient, c.client) - assert.Equal(t, httpTimeout, c.timeout) + assert.Nil(t, c.timeout) + httpClient, ok := c.client.(*http.Client) + assert.True(t, ok) + assert.Equal(t, httpTimeout, httpClient.Timeout) assert.Equal(t, retrier, c.retrier) assert.Equal(t, noOfRetries, c.retryCount) } diff --git a/hystrix/helper_test.go b/hystrix/helper_test.go index 1b4bd8b..a07a930 100644 --- a/hystrix/helper_test.go +++ b/hystrix/helper_test.go @@ -1,7 +1,9 @@ package hystrix import ( + "net/http" "sync" + "time" metricCollector "github.com/afex/hystrix-go/hystrix/metric_collector" ) @@ -82,3 +84,13 @@ func (r *simpleMetricRegistry) Register(name string) metricCollector.MetricColle r.collectors[name] = collector return collector } + +type delayedCancelDoer struct { + Delay time.Duration +} + +func (d delayedCancelDoer) Do(r *http.Request) (*http.Response, error) { + <-r.Context().Done() + time.Sleep(d.Delay) + return nil, r.Context().Err() +} diff --git a/hystrix/hystrix_client.go b/hystrix/hystrix_client.go index 9eb35ce..a0fcfc2 100644 --- a/hystrix/hystrix_client.go +++ b/hystrix/hystrix_client.go @@ -23,7 +23,6 @@ type fallbackCtxFunc func(context.Context, error) error type Client struct { client *httpclient.Client - timeout time.Duration hystrixTimeout time.Duration hystrixCommandName string maxConcurrentRequests int @@ -38,7 +37,6 @@ type Client struct { const ( defaultHystrixRetryCount = 0 - defaultHTTPTimeout = 30 * time.Second defaultHystrixTimeout = 30 * time.Second defaultMaxConcurrentRequests = 100 defaultErrorPercentThreshold = 25 @@ -56,7 +54,6 @@ var err5xx = goerrors.New("server returned 5xx status code") func NewClient(opts ...Option) *Client { client := Client{ client: httpclient.NewClient(), - timeout: defaultHTTPTimeout, hystrixTimeout: defaultHystrixTimeout, maxConcurrentRequests: defaultMaxConcurrentRequests, errorPercentThreshold: defaultErrorPercentThreshold, diff --git a/hystrix/hystrix_client_test.go b/hystrix/hystrix_client_test.go index a72d16e..0a33fd3 100644 --- a/hystrix/hystrix_client_test.go +++ b/hystrix/hystrix_client_test.go @@ -647,6 +647,7 @@ func TestHystrixHTTPClientDoContextCancelled(t *testing.T) { r := newSimpleMetricRegistry() client := NewClient( + WithHTTPClient(delayedCancelDoer{Delay: 100 * time.Millisecond}), // making sure hystrix pickups context cancel before run failure WithCommandName(cmdName), WithRetryCount(3), WithRetrier(heimdall.NewRetrierFunc(func(retry int) time.Duration { diff --git a/hystrix/options.go b/hystrix/options.go index cccfb9f..e36fc32 100644 --- a/hystrix/options.go +++ b/hystrix/options.go @@ -19,10 +19,10 @@ func WithCommandName(name string) Option { } } -// WithHTTPTimeout sets hystrix timeout +// WithHTTPTimeout sets timeout for http.Client func WithHTTPTimeout(timeout time.Duration) Option { return func(c *Client) { - c.timeout = timeout + httpclient.WithHTTPTimeout(timeout)(c.client) } } diff --git a/hystrix/options_test.go b/hystrix/options_test.go index f2beed6..82d5d1f 100644 --- a/hystrix/options_test.go +++ b/hystrix/options_test.go @@ -23,7 +23,6 @@ func TestOptionsAreSet(t *testing.T) { WithStatsDCollector("localhost:8125", "myapp.hystrix"), ) - assert.Equal(t, 10*time.Second, c.timeout) assert.Equal(t, "test", c.hystrixCommandName) assert.Equal(t, time.Duration(1100), c.hystrixTimeout) assert.Equal(t, 10, c.maxConcurrentRequests) @@ -39,7 +38,6 @@ func TestOptionsHaveDefaults(t *testing.T) { c := NewClient(WithCommandName("test-defaults")) - assert.Equal(t, 30*time.Second, c.timeout) assert.Equal(t, "test-defaults", c.hystrixCommandName) assert.Equal(t, 30*time.Second, c.hystrixTimeout) assert.Equal(t, 100, c.maxConcurrentRequests) @@ -147,3 +145,35 @@ func ExampleWithStatsDCollector() { fmt.Println("Response status : ", res.StatusCode) // Output: Response status : 200 } + +func TestWithHTTPTimeoutOverridesUserHTTPClientTimeout(t *testing.T) { + t.Parallel() + + httpTimeout := 10 * time.Second + + client := &http.Client{Timeout: 25 * time.Millisecond} + + c := NewClient( + WithHTTPClient(client), + WithHTTPTimeout(httpTimeout), + ) + + assert.NotNil(t, c) + assert.Equal(t, httpTimeout, client.Timeout) // overrides user provided *http.Client +} + +func TestWithHTTPTimeoutOverridesUserHTTPClientTimeout_InverseSeq(t *testing.T) { + t.Parallel() + + httpTimeout := 10 * time.Second + + client := &http.Client{Timeout: 25 * time.Millisecond} + + c := NewClient( + WithHTTPTimeout(httpTimeout), + WithHTTPClient(client), + ) + + assert.NotNil(t, c) + assert.Equal(t, httpTimeout, client.Timeout) // overrides user provided *http.Client +}