diff --git a/config/config.go b/config/config.go index 9f41e59e..9bd72c8f 100644 --- a/config/config.go +++ b/config/config.go @@ -578,6 +578,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool { func (cfg *Config) validate(logger *logrus.Entry) { cfg.MinTLSServeVer.validate(logger) + cfg.Upstreams.validate(logger) } // ConvertPort converts string representation into a valid port (0 - 65535) diff --git a/config/upstreams.go b/config/upstreams.go index 698301ce..cf1090e7 100644 --- a/config/upstreams.go +++ b/config/upstreams.go @@ -10,7 +10,7 @@ const UpstreamDefaultCfgName = "default" // Upstreams upstream servers configuration type Upstreams struct { Init Init `yaml:"init"` - Timeout Duration `yaml:"timeout" default:"2s"` + Timeout Duration `yaml:"timeout" default:"2s"` // always > 0 Groups UpstreamGroups `yaml:"groups"` Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` UserAgent string `yaml:"userAgent"` @@ -18,6 +18,15 @@ type Upstreams struct { type UpstreamGroups map[string][]Upstream +func (c *Upstreams) validate(logger *logrus.Entry) { + defaults := mustDefault[Upstreams]() + + if !c.Timeout.IsAboveZero() { + logger.Warnf("upstreams.timeout <= 0, setting to %s", defaults.Timeout) + c.Timeout = defaults.Timeout + } +} + // IsEnabled implements `config.Configurable`. func (c *Upstreams) IsEnabled() bool { return len(c.Groups) != 0 diff --git a/config/upstreams_test.go b/config/upstreams_test.go index 6a2000d9..95d93b3c 100644 --- a/config/upstreams_test.go +++ b/config/upstreams_test.go @@ -61,6 +61,25 @@ var _ = Describe("ParallelBestConfig", func() { )) }) }) + + Describe("validate", func() { + It("should compute defaults", func() { + cfg.Timeout = -1 + + cfg.validate(logger) + + Expect(cfg.Timeout).Should(BeNumerically(">", 0)) + + Expect(hook.Calls).ShouldNot(BeEmpty()) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout"))) + }) + + It("should not override valid user values", func() { + cfg.validate(logger) + + Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout"))) + }) + }) }) Context("UpstreamGroupConfig", func() { diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index e5cab7a1..6e75c341 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -143,12 +143,8 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string return ips, nil } - if b.cfg.timeout.IsAboveZero() { - var cancel context.CancelFunc - - ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration()) - defer cancel() - } + ctx, cancel := context.WithTimeout(ctx, b.cfg.timeout.ToDuration()) + defer cancel() // Use system resolver if no bootstrap is configured if b.resolver == nil { diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index f3fcc3ae..a7449028 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -274,14 +274,8 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) ip = ips.Current() upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path) - ctx := ctx // make sure we don't overwrite the outer function's context - - if r.cfg.Timeout.IsAboveZero() { - var cancel context.CancelFunc - - ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration()) - defer cancel() - } + ctx, cancel := context.WithTimeout(ctx, r.cfg.Timeout.ToDuration()) + defer cancel() response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol) if err != nil {