blocky/resolver/mocks.go

127 lines
2.7 KiB
Go
Raw Normal View History

2020-01-12 18:23:35 +01:00
package resolver
import (
2020-02-17 22:06:10 +01:00
"io/ioutil"
2020-01-12 18:23:35 +01:00
"net"
2020-02-17 22:06:10 +01:00
"net/http"
"net/http/httptest"
2020-01-12 18:23:35 +01:00
"strconv"
"strings"
2021-08-25 22:06:34 +02:00
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
2020-01-12 18:23:35 +01:00
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"
)
// resolverMock is a testify-based mock resolver for unit tests.
// The embedded mock.Mock records calls (Resolve calls r.Called) so tests can
// set expectations with On(...).Return(...). NextResolver is embedded so the
// mock can be placed in a resolver chain — presumably it supplies the
// next-resolver accessor methods; verify against NextResolver's definition.
type resolverMock struct {
	mock.Mock
	NextResolver
}
// Configuration returns the mock's configuration lines. The mock has no
// configuration, so the result is always a nil slice.
func (r *resolverMock) Configuration() []string {
	return nil
}
func (r *resolverMock) Resolve(req *Request) (*Response, error) {
args := r.Called(req)
2020-02-14 21:58:58 +01:00
resp, ok := args.Get(0).(*Response)
if ok {
return resp, args.Error((1))
}
return nil, args.Error(1)
2020-01-12 18:23:35 +01:00
}
// TestDOHUpstream creates a mock DoH Upstream
2020-02-17 22:40:58 +01:00
func TestDOHUpstream(fn func(request *dns.Msg) (response *dns.Msg),
reqFn ...func(w http.ResponseWriter)) config.Upstream {
2020-02-17 22:06:10 +01:00
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't read request: ", err)
2020-02-17 22:06:10 +01:00
msg := new(dns.Msg)
err = msg.Unpack(body)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't deserialize message: ", err)
2020-02-17 22:06:10 +01:00
response := fn(msg)
response.SetReply(msg)
b, err := response.Pack()
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't serialize message: ", err)
2020-02-17 22:06:10 +01:00
w.Header().Set("content-type", "application/dns-message")
2020-02-17 22:40:58 +01:00
for _, f := range reqFn {
2020-05-04 22:20:13 +02:00
if f != nil {
f(w)
}
2020-02-17 22:40:58 +01:00
}
2020-02-17 22:06:10 +01:00
_, err = w.Write(b)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't write response: ", err)
2020-02-17 22:06:10 +01:00
}))
upstream, err := config.ParseUpstream(server.URL)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't resolve address: ", err)
2020-02-17 22:06:10 +01:00
return upstream
}
2021-02-26 21:44:53 +01:00
// TestUDPUpstream creates a mock UDP upstream
2020-05-04 22:20:13 +02:00
//nolint:funlen
2020-01-12 18:23:35 +01:00
func TestUDPUpstream(fn func(request *dns.Msg) (response *dns.Msg)) config.Upstream {
a, err := net.ResolveUDPAddr("udp4", ":0")
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't resolve address: ", err)
2020-01-12 18:23:35 +01:00
ln, err := net.ListenUDP("udp4", a)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't create connection: ", err)
2020-01-12 18:23:35 +01:00
ladr := ln.LocalAddr().String()
host := strings.Split(ladr, ":")[0]
2021-02-04 22:20:43 +01:00
p, err := strconv.ParseUint(strings.Split(ladr, ":")[1], 10, 16)
2020-01-12 18:23:35 +01:00
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't convert port: ", err)
2020-01-12 18:23:35 +01:00
port := uint16(p)
go func() {
for {
buffer := make([]byte, 1024)
n, addr, err := ln.ReadFromUDP(buffer)
2021-01-19 21:52:24 +01:00
util.FatalOnError("error on reading from udp: ", err)
2020-01-12 18:23:35 +01:00
msg := new(dns.Msg)
err = msg.Unpack(buffer[0 : n-1])
2020-01-17 21:53:15 +01:00
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't deserialize message: ", err)
2020-01-12 18:23:35 +01:00
response := fn(msg)
2020-05-04 22:20:13 +02:00
// nil should indicate an error
if response == nil {
_, _ = ln.WriteToUDP([]byte("dummy"), addr)
continue
2020-05-04 22:20:13 +02:00
}
rCode := response.Rcode
2020-01-12 18:23:35 +01:00
response.SetReply(msg)
2020-05-04 22:20:13 +02:00
if rCode != 0 {
response.Rcode = rCode
}
2020-01-12 18:23:35 +01:00
b, err := response.Pack()
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't serialize message: ", err)
2020-01-12 18:23:35 +01:00
_, err = ln.WriteToUDP(b, addr)
2021-01-19 21:52:24 +01:00
util.FatalOnError("can't write to UDP: ", err)
2020-01-12 18:23:35 +01:00
}
}()
return config.Upstream{Net: "tcp+udp", Host: host, Port: port}
2020-01-12 18:23:35 +01:00
}