mirror of https://github.com/0xERR0R/blocky.git
refactor(config): ensure `upstreams.timeout` is always valid
This commit is contained in:
parent
0f69630563
commit
ef29cdc45e
|
@ -578,6 +578,7 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
|
|||
|
||||
func (cfg *Config) validate(logger *logrus.Entry) {
|
||||
cfg.MinTLSServeVer.validate(logger)
|
||||
cfg.Upstreams.validate(logger)
|
||||
}
|
||||
|
||||
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||
|
|
|
@ -10,7 +10,7 @@ const UpstreamDefaultCfgName = "default"
|
|||
// Upstreams upstream servers configuration
|
||||
type Upstreams struct {
|
||||
Init Init `yaml:"init"`
|
||||
Timeout Duration `yaml:"timeout" default:"2s"`
|
||||
Timeout Duration `yaml:"timeout" default:"2s"` // always > 0
|
||||
Groups UpstreamGroups `yaml:"groups"`
|
||||
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
|
||||
UserAgent string `yaml:"userAgent"`
|
||||
|
@ -18,6 +18,15 @@ type Upstreams struct {
|
|||
|
||||
type UpstreamGroups map[string][]Upstream
|
||||
|
||||
func (c *Upstreams) validate(logger *logrus.Entry) {
|
||||
defaults := mustDefault[Upstreams]()
|
||||
|
||||
if !c.Timeout.IsAboveZero() {
|
||||
logger.Warnf("upstreams.timeout <= 0, setting to %s", defaults.Timeout)
|
||||
c.Timeout = defaults.Timeout
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled implements `config.Configurable`.
|
||||
func (c *Upstreams) IsEnabled() bool {
|
||||
return len(c.Groups) != 0
|
||||
|
|
|
@ -61,6 +61,25 @@ var _ = Describe("ParallelBestConfig", func() {
|
|||
))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("validate", func() {
|
||||
It("should compute defaults", func() {
|
||||
cfg.Timeout = -1
|
||||
|
||||
cfg.validate(logger)
|
||||
|
||||
Expect(cfg.Timeout).Should(BeNumerically(">", 0))
|
||||
|
||||
Expect(hook.Calls).ShouldNot(BeEmpty())
|
||||
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout")))
|
||||
})
|
||||
|
||||
It("should not override valid user values", func() {
|
||||
cfg.validate(logger)
|
||||
|
||||
Expect(hook.Messages).ShouldNot(ContainElement(ContainSubstring("timeout")))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("UpstreamGroupConfig", func() {
|
||||
|
|
|
@ -143,12 +143,8 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
if b.cfg.timeout.IsAboveZero() {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
|
||||
defer cancel()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
|
||||
defer cancel()
|
||||
|
||||
// Use system resolver if no bootstrap is configured
|
||||
if b.resolver == nil {
|
||||
|
|
|
@ -274,14 +274,8 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request)
|
|||
ip = ips.Current()
|
||||
upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path)
|
||||
|
||||
ctx := ctx // make sure we don't overwrite the outer function's context
|
||||
|
||||
if r.cfg.Timeout.IsAboveZero() {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
|
||||
defer cancel()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, r.cfg.Timeout.ToDuration())
|
||||
defer cancel()
|
||||
|
||||
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue