blocky/resolver/mocks_test.go

237 lines
4.9 KiB
Go

package resolver
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
)
// mockResolver is a resolver test double.
//
// Calls are recorded on the embedded `mock.Mock`. The optional function
// fields let a test compute the result of `Resolve` instead of stubbing
// fixed return values; `Resolve` consults them in the order they are
// declared here and uses the first one that is set.
type mockResolver struct {
	mock.Mock
	NextResolver

	// ResolveFn, when set, produces the complete result of Resolve.
	ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error)

	// ResponseFn, when set, produces the DNS message that Resolve wraps in a
	// response.
	ResponseFn func(req *dns.Msg) *dns.Msg

	// AnswerFn, when set, is asked for an answer for each question of the
	// request. Returning a nil message means "no answer for this question".
	AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
}
// Type implements `Resolver`.
//
// The mock always reports the same fixed type name.
func (r *mockResolver) Type() string {
	const mockType = "mock"

	return mockType
}
// IsEnabled implements `config.Configurable`.
//
// The call is recorded on the embedded mock and its first configured return
// value is handed back. The type assertion panics if that value is not a
// bool.
func (r *mockResolver) IsEnabled() bool {
	return r.Called().Get(0).(bool)
}
// LogConfig implements `config.Configurable`.
//
// Nothing is logged: the logger argument is ignored and the call is only
// recorded on the embedded mock.
func (r *mockResolver) LogConfig(_ *logrus.Entry) {
	r.Called()
}
// Resolve records the call on the embedded mock and then builds the result
// from the first configured source, in this order:
//
//  1. ResolveFn: its result is returned unchanged.
//  2. ResponseFn: the message it returns is wrapped in a RESOLVED response.
//  3. AnswerFn: it is called for each question in turn. The first non-nil
//     answer is returned, an error stops the lookup and is returned wrapped,
//     and when no question yields an answer a reply with rcode
//     `dns.RcodeBadName` is returned.
//  4. Otherwise the return values configured on the mock are used.
//
// The mock call is recorded before any of the function fields is consulted,
// so it happens even when one of them ends up providing the result.
func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
	args := r.Called(req)

	// Every synthesized response carries the same empty reason and RESOLVED
	// type; only the DNS message differs.
	resolved := func(msg *dns.Msg) *model.Response {
		return &model.Response{
			Res:    msg,
			Reason: "",
			RType:  model.ResponseTypeRESOLVED,
		}
	}

	switch {
	case r.ResolveFn != nil:
		return r.ResolveFn(ctx, req)

	case r.ResponseFn != nil:
		return resolved(r.ResponseFn(req.Req)), nil

	case r.AnswerFn != nil:
		for _, q := range req.Req.Question {
			msg, err := r.AnswerFn(dns.Type(q.Qtype), q.Name)
			if err != nil {
				return nil, fmt.Errorf("AnswerFn error: %w", err)
			}

			if msg != nil {
				return resolved(msg), nil
			}
		}

		// No question produced an answer.
		fallback := new(dns.Msg)
		fallback.SetRcode(req.Req, dns.RcodeBadName)

		return resolved(fallback), nil
	}

	// A first return value that is not a *model.Response (including an
	// untyped nil) results in a nil response.
	resp, _ := args.Get(0).(*model.Response)

	return resp, args.Error(1)
}
// Loopback addresses used by autoAnswer as fake answers.
var (
	// autoAnswerIPv4 is the IPv4 loopback address.
	autoAnswerIPv4 = net.ParseIP("127.0.0.1")

	// autoAnswerIPv6 is the IPv6 loopback address.
	autoAnswerIPv6 = net.IPv6loopback
)
// autoAnswer provides a valid fake answer.
//
// To be used as a value for `mockResolver.AnswerFn`.
//
// A queries are answered with autoAnswerIPv4 and AAAA queries with
// autoAnswerIPv6. The message is built by `util.NewMsgWithAnswer`, which is
// passed 60 as its second argument (presumably the TTL; confirm against
// that helper). Every other query type yields a nil message and an error
// naming the type.
func autoAnswer(qType dns.Type, qName string) (*dns.Msg, error) {
	var ip net.IP

	switch uint16(qType) {
	case dns.TypeA:
		ip = autoAnswerIPv4
	case dns.TypeAAAA:
		ip = autoAnswerIPv6
	default:
		// Format through dns.Type's String method rather than indexing
		// dns.TypeToString directly: for a type missing from that map the
		// lookup yields "", whereas String falls back to "TYPE<n>", so the
		// message always identifies the offending type.
		return nil, fmt.Errorf("autoAnswer not implemented for qType=%s", qType)
	}

	return util.NewMsgWithAnswer(qName, 60, qType, ip.String())
}
// newTestBootstrap creates a test Bootstrap.
//
// The Bootstrap is built from a fixed YAML configuration and its resolver is
// then replaced by a fresh mockResolver. When response is non-nil, that mock
// is set up to return it (with a nil error) for "Resolve" calls with any
// argument; when it is nil, no expectation is registered on the mock.
//
// A configuration or construction error is reported through
// `util.FatalOnError`.
func newTestBootstrap(ctx context.Context, response *dns.Msg) *Bootstrap {
	const cfgTxt = `
upstream: https://mock
ips:
- 0.0.0.0
`

	var bootstrapCfg config.BootstrapDNS

	err := yaml.UnmarshalStrict([]byte(cfgTxt), &bootstrapCfg)
	util.FatalOnError("test bootstrap config is broken, did you change the struct?", err)

	bootstrap, err := NewBootstrap(ctx, &config.Config{BootstrapDNS: bootstrapCfg})
	util.FatalOnError("can't create bootstrap", err)

	// Swap the resolver created from the config for the mock.
	upstream := &mockResolver{}
	bootstrap.resolver = upstream

	if response == nil {
		return bootstrap
	}

	upstream.
		On("Resolve", mock.Anything).
		Return(&model.Response{Res: response}, nil)

	return bootstrap
}
// newTestDOHUpstream creates a test DoH Upstream.
//
// It starts an `httptest` TLS server and returns the upstream parsed from
// that server's URL. For each request the handler:
//
//   - reads the body and unpacks it as a DNS message,
//   - passes the message to fn and marks the result as a reply to it,
//   - sets the content type to "application/dns-message",
//   - calls every non-nil reqFn with the response writer, before the body is
//     written, and
//   - writes the packed reply.
//
// Any failure along the way is reported through `util.FatalOnError`.
//
// NOTE: the server is not closed by this function and no handle to it is
// returned to the caller.
func newTestDOHUpstream(fn func(request *dns.Msg) (response *dns.Msg),
	reqFn ...func(w http.ResponseWriter),
) config.Upstream {
	handler := func(w http.ResponseWriter, r *http.Request) {
		rawQuery, err := io.ReadAll(r.Body)
		util.FatalOnError("can't read request: ", err)

		query := new(dns.Msg)
		err = query.Unpack(rawQuery)
		util.FatalOnError("can't deserialize message: ", err)

		reply := fn(query)
		reply.SetReply(query)

		rawReply, err := reply.Pack()
		util.FatalOnError("can't serialize message: ", err)

		w.Header().Set("content-type", "application/dns-message")

		for _, modify := range reqFn {
			if modify == nil {
				continue
			}

			modify(w)
		}

		_, err = w.Write(rawReply)
		util.FatalOnError("can't write response: ", err)
	}

	server := httptest.NewTLSServer(http.HandlerFunc(handler))

	upstream, err := config.ParseUpstream(server.URL)
	util.FatalOnError("can't resolve address: ", err)

	return upstream
}
// mockDialer is a test dialer. It records dial attempts on the embedded
// `mock.Mock` and never opens a connection.
type mockDialer struct {
	mock.Mock
}

// newMockDialer returns a mockDialer with nothing configured on its mock.
func newMockDialer() *mockDialer {
	return new(mockDialer)
}

// DialContext records the call, with all of its arguments, on the mock and
// always returns the shared aMockConn together with a nil error. Return
// values configured on the mock are not used.
func (d *mockDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	d.Called(ctx, network, addr)

	return aMockConn, nil
}
// aMockConn is the shared connection instance returned by
// mockDialer.DialContext.
var aMockConn = new(mockConn)

// mockConn is a placeholder connection with the `net.Conn` method set.
//
// It only exists so that a connection value can be handed out: every method
// panics with "not implemented" when called.
type mockConn struct{}

// Read always panics.
func (c *mockConn) Read(_ []byte) (int, error) {
	panic("not implemented")
}

// Write always panics.
func (c *mockConn) Write(_ []byte) (int, error) {
	panic("not implemented")
}

// Close always panics.
func (c *mockConn) Close() error {
	panic("not implemented")
}

// LocalAddr always panics.
func (c *mockConn) LocalAddr() net.Addr {
	panic("not implemented")
}

// RemoteAddr always panics.
func (c *mockConn) RemoteAddr() net.Addr {
	panic("not implemented")
}

// SetDeadline always panics.
func (c *mockConn) SetDeadline(_ time.Time) error {
	panic("not implemented")
}

// SetReadDeadline always panics.
func (c *mockConn) SetReadDeadline(_ time.Time) error {
	panic("not implemented")
}

// SetWriteDeadline always panics.
func (c *mockConn) SetWriteDeadline(_ time.Time) error {
	panic("not implemented")
}