refactor(config): ensure `upstreams.timeout` is always valid

This commit is contained in:
ThinkChaos 2023-11-22 12:29:57 -05:00
parent 0f69630563
commit ef29cdc45e
5 changed files with 34 additions and 15 deletions

View File

@ -578,6 +578,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
func (cfg *Config) validate(logger *logrus.Entry) {
cfg.MinTLSServeVer.validate(logger)
cfg.Upstreams.validate(logger)
}
// ConvertPort converts string representation into a valid port (0 - 65535)

View File

@ -10,7 +10,7 @@ const UpstreamDefaultCfgName = "default"
// Upstreams upstream servers configuration
type Upstreams struct {
Init Init `yaml:"init"`
Timeout Duration `yaml:"timeout" default:"2s"`
Timeout Duration `yaml:"timeout" default:"2s"` // always > 0
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
UserAgent string `yaml:"userAgent"`
@ -18,6 +18,15 @@ type Upstreams struct {
type UpstreamGroups map[string][]Upstream
func (c *Upstreams) validate(logger *logrus.Entry) {
defaults := mustDefault[Upstreams]()
if !c.Timeout.IsAboveZero() {
logger.Warnf("upstreams.timeout <= 0, setting to %s", defaults.Timeout)
c.Timeout = defaults.Timeout
}
}
// IsEnabled implements `config.Configurable`.
func (c *Upstreams) IsEnabled() bool {
return len(c.Groups) != 0

View File

@ -61,6 +61,25 @@ var _ = Describe("ParallelBestConfig", func() {
))
})
})
Describe("validate", func() {
It("should compute defaults", func() {
cfg.Timeout = -1
cfg.validate(logger)
Expect(cfg.Timeout).Should(BeNumerically(">", 0))
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout")))
})
It("should not override valid user values", func() {
cfg.validate(logger)
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout")))
})
})
})
Context("UpstreamGroupConfig", func() {

View File

@ -143,12 +143,8 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string
return ips, nil
}
if b.cfg.timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
defer cancel()
// Use system resolver if no bootstrap is configured
if b.resolver == nil {

View File

@ -274,14 +274,8 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
ip = ips.Current()
upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path)
ctx := ctx // make sure we don't overwrite the outer function's context
if r.cfg.Timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
defer cancel()
}
ctx, cancel := context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
defer cancel()
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
if err != nil {