refactor(bootstrap): replace `Dialer.Timeout` with a `context` deadline

This commit is contained in:
ThinkChaos 2023-11-29 18:13:54 -05:00
parent 603d374405
commit 0f69630563
2 changed files with 37 additions and 34 deletions

View File

@ -16,12 +16,12 @@ func (c Duration) ToDuration() time.Duration {
return time.Duration(c)
}
// IsAboveZero returns true if duration is above zero
// IsAboveZero returns true if duration is strictly greater than zero.
func (c Duration) IsAboveZero() bool {
return c.ToDuration() > 0
}
// IsAtLeastZero returns true if duration is at least zero
// IsAtLeastZero returns true if duration is greater or equal to zero.
func (c Duration) IsAtLeastZero() bool {
return c.ToDuration() >= 0
}

View File

@ -21,25 +21,34 @@ import (
"golang.org/x/exp/maps"
)
const (
defaultTimeout = 5 * time.Second
)
var errArbitrarySystemResolverRequest = errors.New(
"cannot resolve arbitrary requests using the system resolver",
)
type bootstrapConfig struct {
config.BootstrapDNSConfig
connectIPVersion config.IPVersion
timeout config.Duration
}
func newBootstrapConfig(cfg *config.Config) *bootstrapConfig {
return &bootstrapConfig{
BootstrapDNSConfig: cfg.BootstrapDNS,
connectIPVersion: cfg.ConnectIPVersion,
timeout: cfg.Upstreams.Timeout,
}
}
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
type Bootstrap struct {
configurable[*config.BootstrapDNSConfig]
configurable[*bootstrapConfig]
typed
resolver Resolver
bootstraped bootstrapedResolvers
connectIPVersion config.IPVersion
timeout time.Duration
// To allow replacing during tests
systemResolver *net.Resolver
dialer interface {
@ -50,25 +59,15 @@ type Bootstrap struct {
// NewBootstrap creates and returns a new Bootstrap.
// Internally, it uses a CachingResolver and an UpstreamResolver.
func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err error) {
timeout := defaultTimeout
if cfg.Upstreams.Timeout.IsAboveZero() {
timeout = cfg.Upstreams.Timeout.ToDuration()
}
// Create b in multiple steps: Bootstrap and UpstreamResolver have a cyclic dependency
// This also prevents the GC to clean up these two structs, but is not currently an
// issue since they stay allocated until the process terminates
b = &Bootstrap{
configurable: withConfig(&cfg.BootstrapDNS),
configurable: withConfig(newBootstrapConfig(cfg)),
typed: withType("bootstrap"),
connectIPVersion: cfg.ConnectIPVersion,
systemResolver: net.DefaultResolver,
timeout: timeout,
dialer: &net.Dialer{
Timeout: timeout,
},
dialer: new(net.Dialer),
}
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
@ -126,7 +125,7 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.cfg.Host
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing tests easier
return newIPSet([]net.IP{ip}), nil
}
@ -139,20 +138,24 @@ func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSe
}
func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
// Use system resolver if no bootstrap is configured
if b.resolver == nil {
ctx, cancel := context.WithTimeout(ctx, b.timeout)
defer cancel()
return b.systemResolver.LookupIP(ctx, b.connectIPVersion.Net(), host)
}
if ips, ok := b.bootstraped[r]; ok {
// Special path for bootstraped upstreams to avoid infinite recursion
return ips, nil
}
return b.resolve(ctx, host, b.connectIPVersion.QTypes())
if b.cfg.timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration())
defer cancel()
}
// Use system resolver if no bootstrap is configured
if b.resolver == nil {
return b.systemResolver.LookupIP(ctx, b.cfg.connectIPVersion.Net(), host)
}
return b.resolve(ctx, host, b.cfg.connectIPVersion.QTypes())
}
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
@ -181,8 +184,8 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
var qTypes []dns.Type
switch {
case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
qTypes = b.connectIPVersion.QTypes()
case b.cfg.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
qTypes = b.cfg.connectIPVersion.QTypes()
case strings.HasSuffix(network, "4"):
qTypes = config.IPVersionV4.QTypes()
case strings.HasSuffix(network, "6"):