mirror of https://github.com/0xERR0R/blocky.git
feat: add timeout to bootstrap (#1158)
This commit is contained in:
parent
65ff6847ad
commit
6f60bea5c2
|
@ -19,6 +19,10 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Bootstrap allows resolving hostnames using the configured bootstrap DNS.
|
||||
type Bootstrap struct {
|
||||
log *logrus.Entry
|
||||
|
@ -27,6 +31,7 @@ type Bootstrap struct {
|
|||
bootstraped bootstrapedResolvers
|
||||
|
||||
connectIPVersion config.IPVersion
|
||||
timeout time.Duration
|
||||
|
||||
// To allow replacing during tests
|
||||
systemResolver *net.Resolver
|
||||
|
@ -38,17 +43,25 @@ type Bootstrap struct {
|
|||
// NewBootstrap creates and returns a new Bootstrap.
|
||||
// Internally, it uses a CachingResolver and an UpstreamResolver.
|
||||
func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
||||
log := log.PrefixedLog("bootstrap")
|
||||
logger := log.PrefixedLog("bootstrap")
|
||||
|
||||
timeout := defaultTimeout
|
||||
if cfg.Upstreams.Timeout.IsAboveZero() {
|
||||
timeout = cfg.Upstreams.Timeout.ToDuration()
|
||||
}
|
||||
|
||||
// Create b in multiple steps: Bootstrap and UpstreamResolver have a cyclic dependency
|
||||
// This also prevents the GC to clean up these two structs, but is not currently an
|
||||
// issue since they stay allocated until the process terminates
|
||||
b = &Bootstrap{
|
||||
log: log,
|
||||
log: logger,
|
||||
connectIPVersion: cfg.ConnectIPVersion,
|
||||
|
||||
systemResolver: net.DefaultResolver,
|
||||
dialer: &net.Dialer{},
|
||||
timeout: timeout,
|
||||
dialer: &net.Dialer{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
|
||||
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
|
||||
|
@ -57,7 +70,7 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
}
|
||||
|
||||
if len(bootstraped) == 0 {
|
||||
log.Infof("bootstrapDns is not configured, will use system resolver")
|
||||
logger.Infof("bootstrapDns is not configured, will use system resolver")
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
@ -109,18 +122,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
|||
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
||||
// Use system resolver if no bootstrap is configured
|
||||
if b.resolver == nil {
|
||||
cfg := config.GetConfig()
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), b.timeout)
|
||||
defer cancel()
|
||||
|
||||
timeout := cfg.Upstreams.Timeout
|
||||
if timeout.IsAboveZero() {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout.ToDuration())
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
return b.systemResolver.LookupIP(ctx, cfg.ConnectIPVersion.Net(), host)
|
||||
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
|
||||
}
|
||||
|
||||
if ips, ok := b.bootstraped[r]; ok {
|
||||
|
@ -134,7 +139,9 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
|||
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
|
||||
func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
||||
if b.resolver == nil {
|
||||
return &http.Transport{}
|
||||
return &http.Transport{
|
||||
DialContext: b.dialer.DialContext,
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
|
@ -143,11 +150,11 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
|||
}
|
||||
|
||||
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log := b.log.WithField("network", network).WithField("addr", addr)
|
||||
logger := b.log.WithField("network", network).WithField("addr", addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("dial error: %s", err)
|
||||
logger.Errorf("dial error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
@ -168,14 +175,14 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
|
|||
// Resolve the host with the bootstrap DNS
|
||||
ips, err := b.resolve(host, qTypes)
|
||||
if err != nil {
|
||||
log.Errorf("resolve error: %s", err)
|
||||
logger.Errorf("resolve error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip := ips[rand.Intn(len(ips))] //nolint:gosec
|
||||
|
||||
log.WithField("ip", ip).Tracef("dialing %s", host)
|
||||
logger.WithField("ip", ip).Tracef("dialing %s", host)
|
||||
|
||||
// Use the standard dialer to actually connect
|
||||
addrWithIP := net.JoinHostPort(ip.String(), port)
|
||||
|
|
|
@ -76,7 +76,6 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
transport := sut.NewHTTPTransport()
|
||||
|
||||
Expect(transport).ShouldNot(BeNil())
|
||||
Expect(*transport).Should(BeZero()) //nolint:govet
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var systemResolverBootstrap = &Bootstrap{}
|
||||
var systemResolverBootstrap = &Bootstrap{dialer: newMockDialer()}
|
||||
|
||||
var _ = Describe("Resolver", func() {
|
||||
Describe("Chains", func() {
|
||||
|
|
Loading…
Reference in New Issue