feat: add timeout to bootstrap (#1158)

This commit is contained in:
Dimitri Herzog 2023-09-20 22:41:55 +02:00 committed by GitHub
parent 65ff6847ad
commit 6f60bea5c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 22 deletions

View File

@ -19,6 +19,10 @@ import (
"github.com/sirupsen/logrus"
)
const (
defaultTimeout = 5 * time.Second
)
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
type Bootstrap struct {
log *logrus.Entry
@ -27,6 +31,7 @@ type Bootstrap struct {
bootstraped bootstrapedResolvers
connectIPVersion config.IPVersion
timeout time.Duration
// To allow replacing during tests
systemResolver *net.Resolver
@ -38,17 +43,25 @@ type Bootstrap struct {
// NewBootstrap creates and returns a new Bootstrap.
// Internally, it uses a CachingResolver and an UpstreamResolver.
func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
log := log.PrefixedLog("bootstrap")
logger := log.PrefixedLog("bootstrap")
timeout := defaultTimeout
if cfg.Upstreams.Timeout.IsAboveZero() {
timeout = cfg.Upstreams.Timeout.ToDuration()
}
// Create b in multiple steps: Bootstrap and UpstreamResolver have a cyclic dependency
// This also prevents the GC to clean up these two structs, but is not currently an
// issue since they stay allocated until the process terminates
b = &Bootstrap{
log: log,
log: logger,
connectIPVersion: cfg.ConnectIPVersion,
systemResolver: net.DefaultResolver,
dialer: &net.Dialer{},
timeout: timeout,
dialer: &net.Dialer{
Timeout: timeout,
},
}
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
@ -57,7 +70,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
}
if len(bootstraped) == 0 {
log.Infof("bootstrapDns is not configured, will use system resolver")
logger.Infof("bootstrapDns is not configured, will use system resolver")
return b, nil
}
@ -109,18 +122,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
// Use system resolver if no bootstrap is configured
if b.resolver == nil {
cfg := config.GetConfig()
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), b.timeout)
defer cancel()
timeout := cfg.Upstreams.Timeout
if timeout.IsAboveZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration())
defer cancel()
}
return b.systemResolver.LookupIP(ctx, cfg.ConnectIPVersion.Net(), host)
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
}
if ips, ok := b.bootstraped[r]; ok {
@ -134,7 +139,9 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
func (b *Bootstrap) NewHTTPTransport() *http.Transport {
if b.resolver == nil {
return &http.Transport{}
return &http.Transport{
DialContext: b.dialer.DialContext,
}
}
return &http.Transport{
@ -143,11 +150,11 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
}
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log := b.log.WithField("network", network).WithField("addr", addr)
logger := b.log.WithField("network", network).WithField("addr", addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("dial error: %s", err)
logger.Errorf("dial error: %s", err)
return nil, err
}
@ -168,14 +175,14 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
// Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes)
if err != nil {
log.Errorf("resolve error: %s", err)
logger.Errorf("resolve error: %s", err)
return nil, err
}
ip := ips[rand.Intn(len(ips))] //nolint:gosec
log.WithField("ip", ip).Tracef("dialing %s", host)
logger.WithField("ip", ip).Tracef("dialing %s", host)
// Use the standard dialer to actually connect
addrWithIP := net.JoinHostPort(ip.String(), port)

View File

@ -76,7 +76,6 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
transport := sut.NewHTTPTransport()
Expect(transport).ShouldNot(BeNil())
Expect(*transport).Should(BeZero()) //nolint:govet
})
})

View File

@ -11,7 +11,7 @@ import (
. "github.com/onsi/gomega"
)
var systemResolverBootstrap = &Bootstrap{}
var systemResolverBootstrap = &Bootstrap{dialer: newMockDialer()}
var _ = Describe("Resolver", func() {
Describe("Chains", func() {