mirror of https://github.com/0xERR0R/blocky.git
test(bootstrap): add connectIPVersion tests for HTTP Transport
This commit is contained in:
parent
2cb826db22
commit
012c8d49f8
|
@ -28,7 +28,11 @@ type Bootstrap struct {
|
|||
|
||||
connectIPVersion config.IPVersion
|
||||
|
||||
// To allow replacing during tests
|
||||
systemResolver *net.Resolver
|
||||
dialer interface {
|
||||
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
}
|
||||
|
||||
// NewBootstrap creates and returns a new Bootstrap.
|
||||
|
@ -42,7 +46,9 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
|
|||
b = &Bootstrap{
|
||||
log: log,
|
||||
connectIPVersion: cfg.ConnectIPVersion,
|
||||
systemResolver: net.DefaultResolver, // allow replacing it during tests
|
||||
|
||||
systemResolver: net.DefaultResolver,
|
||||
dialer: &net.Dialer{},
|
||||
}
|
||||
|
||||
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
|
||||
|
@ -130,52 +136,52 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
|||
return &http.Transport{}
|
||||
}
|
||||
|
||||
dialer := net.Dialer{}
|
||||
|
||||
return &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log := b.log.WithField("network", network).WithField("addr", addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("dial error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var qTypes []dns.Type
|
||||
|
||||
switch {
|
||||
case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
|
||||
qTypes = b.connectIPVersion.QTypes()
|
||||
case strings.HasSuffix(network, "4"):
|
||||
qTypes = config.IPVersionV4.QTypes()
|
||||
case strings.HasSuffix(network, "6"):
|
||||
qTypes = config.IPVersionV6.QTypes()
|
||||
default:
|
||||
qTypes = config.IPVersionDual.QTypes()
|
||||
}
|
||||
|
||||
// Resolve the host with the bootstrap DNS
|
||||
ips, err := b.resolve(host, qTypes)
|
||||
if err != nil {
|
||||
log.Errorf("resolve error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip := ips[rand.Intn(len(ips))] //nolint:gosec
|
||||
|
||||
log.WithField("ip", ip).Tracef("dialing %s", host)
|
||||
|
||||
// Use the standard dialer to actually connect
|
||||
addrWithIP := net.JoinHostPort(ip.String(), port)
|
||||
|
||||
return dialer.DialContext(ctx, network, addrWithIP)
|
||||
},
|
||||
DialContext: b.dialContext,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log := b.log.WithField("network", network).WithField("addr", addr)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("dial error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var qTypes []dns.Type
|
||||
|
||||
switch {
|
||||
case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
|
||||
qTypes = b.connectIPVersion.QTypes()
|
||||
case strings.HasSuffix(network, "4"):
|
||||
qTypes = config.IPVersionV4.QTypes()
|
||||
case strings.HasSuffix(network, "6"):
|
||||
qTypes = config.IPVersionV6.QTypes()
|
||||
default:
|
||||
qTypes = config.IPVersionDual.QTypes()
|
||||
}
|
||||
|
||||
// Resolve the host with the bootstrap DNS
|
||||
ips, err := b.resolve(host, qTypes)
|
||||
if err != nil {
|
||||
log.Errorf("resolve error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip := ips[rand.Intn(len(ips))] //nolint:gosec
|
||||
|
||||
log.WithField("ip", ip).Tracef("dialing %s", host)
|
||||
|
||||
// Use the standard dialer to actually connect
|
||||
addrWithIP := net.JoinHostPort(ip.String(), port)
|
||||
|
||||
return b.dialer.DialContext(ctx, network, addrWithIP)
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||
ips = make([]net.IP, 0, len(qTypes))
|
||||
|
||||
|
|
|
@ -400,55 +400,161 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
})
|
||||
})
|
||||
|
||||
When("connectIPVersion is", func() {
|
||||
Describe("connectIPVersion", func() {
|
||||
var (
|
||||
m *mockResolver
|
||||
ipVersion config.IPVersion
|
||||
m *mockResolver
|
||||
dialIPVersion config.IPVersion
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = ipVersion
|
||||
dialIPVersion = config.IPVersionDual
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
m = &mockResolver{AnswerFn: autoAnswer}
|
||||
sut.resolver = m
|
||||
|
||||
m.On("Resolve", mock.Anything).
|
||||
Times(len(ipVersion.QTypes())).
|
||||
Times(len(sutConfig.ConnectIPVersion.QTypes())).
|
||||
Run(func(args mock.Arguments) {
|
||||
req, ok := args.Get(0).(*model.Request)
|
||||
Expect(ok).Should(BeTrue())
|
||||
|
||||
qType := dns.Type(req.Req.Question[0].Qtype)
|
||||
Expect(qType).Should(BeElementOf(ipVersion.QTypes()))
|
||||
|
||||
if sutConfig.ConnectIPVersion != config.IPVersionDual {
|
||||
Expect(qType).Should(BeElementOf(sutConfig.ConnectIPVersion.QTypes()))
|
||||
} else {
|
||||
Expect(qType).Should(BeElementOf(dialIPVersion.QTypes()))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("resolve", func() {
|
||||
AfterEach(func() {
|
||||
_, err := sut.resolveUpstream(nil, "example.com")
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
Context("using dual", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionV4
|
||||
})
|
||||
|
||||
sut.resolver = m
|
||||
It("should query both IPv4 and IPv6", func() {})
|
||||
})
|
||||
|
||||
Context("using v4", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionV4
|
||||
})
|
||||
|
||||
It("should query IPv4 only", func() {})
|
||||
})
|
||||
|
||||
Context("using v6", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionV6
|
||||
})
|
||||
|
||||
It("should query IPv6 only", func() {})
|
||||
})
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_, err := sut.resolveUpstream(nil, "example.com")
|
||||
Expect(err).Should(Succeed())
|
||||
Describe("HTTP Transport", func() {
|
||||
var (
|
||||
d *mockDialer
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
validIPs []string
|
||||
)
|
||||
|
||||
Context("dual", func() {
|
||||
ipVersion = config.IPVersionDual
|
||||
JustBeforeEach(func() {
|
||||
d = newMockDialer()
|
||||
sut.dialer = d
|
||||
|
||||
It("should query both IPv4 and IPv6", func() {})
|
||||
})
|
||||
m.On("Resolve", mock.Anything).Once()
|
||||
|
||||
Context("v4", func() {
|
||||
ipVersion = config.IPVersionV4
|
||||
d.On("DialContext", mock.Anything, mock.Anything, mock.Anything).
|
||||
Once().
|
||||
Run(func(args mock.Arguments) {
|
||||
network, ok := args.Get(1).(string)
|
||||
Expect(ok).Should(BeTrue())
|
||||
Expect(network).Should(Equal(dialIPVersion.Net()))
|
||||
|
||||
It("should query IPv4 only", func() {})
|
||||
})
|
||||
addr, ok := args.Get(2).(string)
|
||||
Expect(ok).Should(BeTrue())
|
||||
|
||||
Context("v6", func() {
|
||||
ipVersion = config.IPVersionV6
|
||||
ip, port, err := net.SplitHostPort(addr)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(ip).Should(BeElementOf(validIPs))
|
||||
Expect(port).Should(Equal("0"))
|
||||
})
|
||||
})
|
||||
|
||||
It("should query IPv6 only", func() {})
|
||||
AfterEach(func() {
|
||||
t := sut.NewHTTPTransport()
|
||||
|
||||
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0")
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(conn).Should(Equal(aMockConn))
|
||||
|
||||
d.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
Context("using dual", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionDual
|
||||
validIPs = []string{autoAnswerIPv4.String(), autoAnswerIPv6.String()}
|
||||
})
|
||||
|
||||
It("should dial one of IPv4 and IPv6", func() {})
|
||||
|
||||
Context("and dialing IPv4", func() {
|
||||
BeforeEach(func() {
|
||||
dialIPVersion = config.IPVersionV4 // overrides ipVersion
|
||||
validIPs = []string{autoAnswerIPv4.String()}
|
||||
})
|
||||
|
||||
It("should use IPv4 only", func() {})
|
||||
})
|
||||
|
||||
Context("and dialing IPv6", func() {
|
||||
BeforeEach(func() {
|
||||
dialIPVersion = config.IPVersionV6 // overrides ipVersion
|
||||
validIPs = []string{autoAnswerIPv6.String()}
|
||||
})
|
||||
|
||||
It("should use IPv6 only", func() {})
|
||||
})
|
||||
})
|
||||
|
||||
Context("using v4", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionV4
|
||||
validIPs = []string{autoAnswerIPv4.String()}
|
||||
})
|
||||
|
||||
It("should dial IPv4 only", func() {})
|
||||
|
||||
It("should ignore the dial IP version", func() {
|
||||
dialIPVersion = config.IPVersionV6 // overridden by ipVersion
|
||||
})
|
||||
})
|
||||
|
||||
Context("using v6", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig.ConnectIPVersion = config.IPVersionV6
|
||||
validIPs = []string{autoAnswerIPv6.String()}
|
||||
})
|
||||
|
||||
It("should dial IPv6 only", func() {})
|
||||
|
||||
It("should ignore the dial IP version", func() {
|
||||
dialIPVersion = config.IPVersionV4 // overridden by ipVersion
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -80,6 +82,11 @@ func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
|
|||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
var (
|
||||
autoAnswerIPv4 = net.IPv4(127, 0, 0, 1)
|
||||
autoAnswerIPv6 = net.IPv6loopback
|
||||
)
|
||||
|
||||
// autoAnswer provides a valid fake answer.
|
||||
//
|
||||
// To be used as a value for `mockResolver.AnswerFn`.
|
||||
|
@ -88,9 +95,9 @@ func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) {
|
|||
|
||||
switch uint16(qType) {
|
||||
case dns.TypeA:
|
||||
ip = net.IPv4zero
|
||||
ip = autoAnswerIPv4
|
||||
case dns.TypeAAAA:
|
||||
ip = net.IPv6zero
|
||||
ip = autoAnswerIPv6
|
||||
default:
|
||||
return nil, fmt.Errorf("autoAnswer not implemented for qType=%s", dns.TypeToString[uint16(qType)])
|
||||
}
|
||||
|
@ -155,3 +162,53 @@ func newTestDOHUpstream(fn func(request *dns.Msg) (response *dns.Msg),
|
|||
|
||||
return upstream
|
||||
}
|
||||
|
||||
type mockDialer struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func newMockDialer() *mockDialer {
|
||||
return &mockDialer{}
|
||||
}
|
||||
|
||||
func (d *mockDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
d.Called(ctx, network, addr)
|
||||
|
||||
return aMockConn, nil
|
||||
}
|
||||
|
||||
var aMockConn = &mockConn{}
|
||||
|
||||
type mockConn struct{}
|
||||
|
||||
func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(b []byte) (n int, err error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) LocalAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) RemoteAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) SetDeadline(t time.Time) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) SetReadDeadline(t time.Time) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *mockConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue