test(bootstrap): add connectIPVersion tests for HTTP Transport

This commit is contained in:
ThinkChaos 2023-01-21 18:01:23 -05:00
parent 2cb826db22
commit 012c8d49f8
3 changed files with 237 additions and 68 deletions

View File

@ -28,7 +28,11 @@ type Bootstrap struct {
connectIPVersion config.IPVersion
// To allow replacing during tests
systemResolver *net.Resolver
dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
}
// NewBootstrap creates and returns a new Bootstrap.
@ -42,7 +46,9 @@ func NewBootstrap(cfg *config.Config) (b *Bootstrap, err error) {
b = &Bootstrap{
log: log,
connectIPVersion: cfg.ConnectIPVersion,
systemResolver: net.DefaultResolver, // allow replacing it during tests
systemResolver: net.DefaultResolver,
dialer: &net.Dialer{},
}
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
@ -130,52 +136,52 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
return &http.Transport{}
}
dialer := net.Dialer{}
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
log := b.log.WithField("network", network).WithField("addr", addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("dial error: %s", err)
return nil, err
}
var qTypes []dns.Type
switch {
case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
qTypes = b.connectIPVersion.QTypes()
case strings.HasSuffix(network, "4"):
qTypes = config.IPVersionV4.QTypes()
case strings.HasSuffix(network, "6"):
qTypes = config.IPVersionV6.QTypes()
default:
qTypes = config.IPVersionDual.QTypes()
}
// Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes)
if err != nil {
log.Errorf("resolve error: %s", err)
return nil, err
}
ip := ips[rand.Intn(len(ips))] //nolint:gosec
log.WithField("ip", ip).Tracef("dialing %s", host)
// Use the standard dialer to actually connect
addrWithIP := net.JoinHostPort(ip.String(), port)
return dialer.DialContext(ctx, network, addrWithIP)
},
DialContext: b.dialContext,
}
}
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
log := b.log.WithField("network", network).WithField("addr", addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("dial error: %s", err)
return nil, err
}
var qTypes []dns.Type
switch {
case b.connectIPVersion != config.IPVersionDual: // ignore `network` if a specific version is configured
qTypes = b.connectIPVersion.QTypes()
case strings.HasSuffix(network, "4"):
qTypes = config.IPVersionV4.QTypes()
case strings.HasSuffix(network, "6"):
qTypes = config.IPVersionV6.QTypes()
default:
qTypes = config.IPVersionDual.QTypes()
}
// Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes)
if err != nil {
log.Errorf("resolve error: %s", err)
return nil, err
}
ip := ips[rand.Intn(len(ips))] //nolint:gosec
log.WithField("ip", ip).Tracef("dialing %s", host)
// Use the standard dialer to actually connect
addrWithIP := net.JoinHostPort(ip.String(), port)
return b.dialer.DialContext(ctx, network, addrWithIP)
}
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
ips = make([]net.IP, 0, len(qTypes))

View File

@ -400,55 +400,161 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
})
})
When("connectIPVersion is", func() {
Describe("connectIPVersion", func() {
var (
m *mockResolver
ipVersion config.IPVersion
m *mockResolver
dialIPVersion config.IPVersion
)
BeforeEach(func() {
sutConfig.ConnectIPVersion = ipVersion
dialIPVersion = config.IPVersionDual
})
JustBeforeEach(func() {
m = &mockResolver{AnswerFn: autoAnswer}
sut.resolver = m
m.On("Resolve", mock.Anything).
Times(len(ipVersion.QTypes())).
Times(len(sutConfig.ConnectIPVersion.QTypes())).
Run(func(args mock.Arguments) {
req, ok := args.Get(0).(*model.Request)
Expect(ok).Should(BeTrue())
qType := dns.Type(req.Req.Question[0].Qtype)
Expect(qType).Should(BeElementOf(ipVersion.QTypes()))
if sutConfig.ConnectIPVersion != config.IPVersionDual {
Expect(qType).Should(BeElementOf(sutConfig.ConnectIPVersion.QTypes()))
} else {
Expect(qType).Should(BeElementOf(dialIPVersion.QTypes()))
}
})
})
Describe("resolve", func() {
AfterEach(func() {
_, err := sut.resolveUpstream(nil, "example.com")
Expect(err).Should(Succeed())
m.AssertExpectations(GinkgoT())
})
Context("using dual", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionV4
})
sut.resolver = m
It("should query both IPv4 and IPv6", func() {})
})
Context("using v4", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionV4
})
It("should query IPv4 only", func() {})
})
Context("using v6", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionV6
})
It("should query IPv6 only", func() {})
})
})
AfterEach(func() {
_, err := sut.resolveUpstream(nil, "example.com")
Expect(err).Should(Succeed())
Describe("HTTP Transport", func() {
var (
d *mockDialer
m.AssertExpectations(GinkgoT())
})
validIPs []string
)
Context("dual", func() {
ipVersion = config.IPVersionDual
JustBeforeEach(func() {
d = newMockDialer()
sut.dialer = d
It("should query both IPv4 and IPv6", func() {})
})
m.On("Resolve", mock.Anything).Once()
Context("v4", func() {
ipVersion = config.IPVersionV4
d.On("DialContext", mock.Anything, mock.Anything, mock.Anything).
Once().
Run(func(args mock.Arguments) {
network, ok := args.Get(1).(string)
Expect(ok).Should(BeTrue())
Expect(network).Should(Equal(dialIPVersion.Net()))
It("should query IPv4 only", func() {})
})
addr, ok := args.Get(2).(string)
Expect(ok).Should(BeTrue())
Context("v6", func() {
ipVersion = config.IPVersionV6
ip, port, err := net.SplitHostPort(addr)
Expect(err).Should(Succeed())
Expect(ip).Should(BeElementOf(validIPs))
Expect(port).Should(Equal("0"))
})
})
It("should query IPv6 only", func() {})
AfterEach(func() {
t := sut.NewHTTPTransport()
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0")
Expect(err).Should(Succeed())
Expect(conn).Should(Equal(aMockConn))
d.AssertExpectations(GinkgoT())
})
Context("using dual", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionDual
validIPs = []string{autoAnswerIPv4.String(), autoAnswerIPv6.String()}
})
It("should dial one of IPv4 and IPv6", func() {})
Context("and dialing IPv4", func() {
BeforeEach(func() {
dialIPVersion = config.IPVersionV4 // overrides ipVersion
validIPs = []string{autoAnswerIPv4.String()}
})
It("should use IPv4 only", func() {})
})
Context("and dialing IPv6", func() {
BeforeEach(func() {
dialIPVersion = config.IPVersionV6 // overrides ipVersion
validIPs = []string{autoAnswerIPv6.String()}
})
It("should use IPv6 only", func() {})
})
})
Context("using v4", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionV4
validIPs = []string{autoAnswerIPv4.String()}
})
It("should dial IPv4 only", func() {})
It("should ignore the dial IP version", func() {
dialIPVersion = config.IPVersionV6 // overridden by ipVersion
})
})
Context("using v6", func() {
BeforeEach(func() {
sutConfig.ConnectIPVersion = config.IPVersionV6
validIPs = []string{autoAnswerIPv6.String()}
})
It("should dial IPv6 only", func() {})
It("should ignore the dial IP version", func() {
dialIPVersion = config.IPVersionV4 // overridden by ipVersion
})
})
})
})

View File

@ -1,11 +1,13 @@
package resolver
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
@ -80,6 +82,11 @@ func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
return nil, args.Error(1)
}
var (
autoAnswerIPv4 = net.IPv4(127, 0, 0, 1)
autoAnswerIPv6 = net.IPv6loopback
)
// autoAnswer provides a valid fake answer.
//
// To be used as a value for `mockResolver.AnswerFn`.
@ -88,9 +95,9 @@ func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) {
switch uint16(qType) {
case dns.TypeA:
ip = net.IPv4zero
ip = autoAnswerIPv4
case dns.TypeAAAA:
ip = net.IPv6zero
ip = autoAnswerIPv6
default:
return nil, fmt.Errorf("autoAnswer not implemented for qType=%s", dns.TypeToString[uint16(qType)])
}
@ -155,3 +162,53 @@ func newTestDOHUpstream(fn func(request *dns.Msg) (response *dns.Msg),
return upstream
}
type mockDialer struct {
mock.Mock
}
func newMockDialer() *mockDialer {
return &mockDialer{}
}
func (d *mockDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
d.Called(ctx, network, addr)
return aMockConn, nil
}
var aMockConn = &mockConn{}
type mockConn struct{}
func (c *mockConn) Read(b []byte) (n int, err error) {
panic("not implemented")
}
func (c *mockConn) Write(b []byte) (n int, err error) {
panic("not implemented")
}
func (c *mockConn) Close() error {
panic("not implemented")
}
func (c *mockConn) LocalAddr() net.Addr {
panic("not implemented")
}
func (c *mockConn) RemoteAddr() net.Addr {
panic("not implemented")
}
func (c *mockConn) SetDeadline(t time.Time) error {
panic("not implemented")
}
func (c *mockConn) SetReadDeadline(t time.Time) error {
panic("not implemented")
}
func (c *mockConn) SetWriteDeadline(t time.Time) error {
panic("not implemented")
}