From 0f69630563389176951dfc7223b1be6400ad07a5 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Wed, 29 Nov 2023 18:13:54 -0500 Subject: [PATCH] refactor(bootstrap): replace `Dialer.Timeout` with a `context` deadline --- config/duration.go | 4 +-- resolver/bootstrap.go | 67 ++++++++++++++++++++++--------------------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/config/duration.go b/config/duration.go index 1a2c7a9b..6c99b773 100644 --- a/config/duration.go +++ b/config/duration.go @@ -16,12 +16,12 @@ func (c Duration) ToDuration() time.Duration { return time.Duration(c) } -// IsAboveZero returns true if duration is above zero +// IsAboveZero returns true if duration is strictly greater than zero. func (c Duration) IsAboveZero() bool { return c.ToDuration() > 0 } -// IsAtLeastZero returns true if duration is at least zero +// IsAtLeastZero returns true if duration is greater or equal to zero. func (c Duration) IsAtLeastZero() bool { return c.ToDuration() >= 0 } diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index ef5537ab..e5cab7a1 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -21,25 +21,34 @@ import ( "golang.org/x/exp/maps" ) -const ( - defaultTimeout = 5 * time.Second -) - var errArbitrarySystemResolverRequest = errors.New( "cannot resolve arbitrary requests using the system resolver", ) +type bootstrapConfig struct { + config.BootstrapDNSConfig + + connectIPVersion config.IPVersion + timeout config.Duration +} + +func newBootstrapConfig(cfg *config.Config) *bootstrapConfig { + return &bootstrapConfig{ + BootstrapDNSConfig: cfg.BootstrapDNS, + + connectIPVersion: cfg.ConnectIPVersion, + timeout: cfg.Upstreams.Timeout, + } +} + // Bootstrap allows resolving hostnames using the configured bootstrap DNS. type Bootstrap struct { - configurable[*config.BootstrapDNSConfig] + configurable[*bootstrapConfig] typed resolver Resolver bootstraped bootstrapedResolvers - connectIPVersion config.IPVersion - timeout time.Duration - // To allow replacing during tests systemResolver *net.Resolver dialer interface { @@ -50,25 +59,15 @@ type Bootstrap struct { // NewBootstrap creates and returns a new Bootstrap. // Internally, it uses a CachingResolver and an UpstreamResolver. func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err error) { - timeout := defaultTimeout - if cfg.Upstreams.Timeout.IsAboveZero() { - timeout = cfg.Upstreams.Timeout.ToDuration() - } - // Create b in multiple steps: Bootstrap and UpstreamResolver have a cyclic dependency // This also prevents the GC to clean up these two structs, but is not currently an // issue since they stay allocated until the process terminates b = &Bootstrap{ - configurable: withConfig(&cfg.BootstrapDNS), + configurable: withConfig(newBootstrapConfig(cfg)), typed: withType("bootstrap"), - connectIPVersion: cfg.ConnectIPVersion, - systemResolver: net.DefaultResolver, - timeout: timeout, - dialer: &net.Dialer{ - Timeout: timeout, - }, + dialer: new(net.Dialer), } bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams) @@ -126,7 +125,7 @@ func (b *Bootstrap) Resolve(ctx context.Context, request *model.Request) (*model func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) { hostname := r.cfg.Host - if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier + if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing tests easier return newIPSet([]net.IP{ip}), nil } @@ -139,20 +138,24 @@ func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSe } func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) { - // Use system resolver if no bootstrap is configured - if b.resolver == nil { - ctx, cancel := context.WithTimeout(ctx, b.timeout) - defer cancel() - - return b.systemResolver.LookupIP(ctx, b.connectIPVersion.Net(), host) - } - if ips, ok := b.bootstraped[r]; ok { // Special path for bootstraped upstreams to avoid infinite recursion return ips, nil } - return b.resolve(ctx, host, b.connectIPVersion.QTypes()) + if b.cfg.timeout.IsAboveZero() { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, b.cfg.timeout.ToDuration()) + defer cancel() + } + + // Use system resolver if no bootstrap is configured + if b.resolver == nil { + return b.systemResolver.LookupIP(ctx, b.cfg.connectIPVersion.Net(), host) + } + + return b.resolve(ctx, host, b.cfg.connectIPVersion.QTypes()) } // NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames @@ -181,8 +184,8 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net. var qTypes []dns.Type switch { - case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured - qTypes = b.connectIPVersion.QTypes() + case b.cfg.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured + qTypes = b.cfg.connectIPVersion.QTypes() case strings.HasSuffix(network, "4"): qTypes = config.IPVersionV4.QTypes() case strings.HasSuffix(network, "6"):