blocky/resolver/ecs_resolver_test.go

289 lines
7.4 KiB
Go

package resolver
import (
"context"
"net"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/creasty/defaults"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/mock"
)
var _ = Describe("EcsResolver", func() {
var (
sut *ECSResolver
sutConfig config.ECS
m *mockResolver
mockAnswer *dns.Msg
err error
origIP net.IP
ecsIP net.IP
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
err = defaults.Set(&sutConfig)
Expect(err).Should(Succeed())
mockAnswer = new(dns.Msg)
origIP = net.ParseIP("1.2.3.4").To4()
ecsIP = net.ParseIP("4.3.2.1").To4()
})
JustBeforeEach(func() {
if m == nil {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{
Res: mockAnswer,
RType: ResponseTypeCUSTOMDNS,
Reason: "Test",
}, nil)
}
sut = NewECSResolver(sutConfig).(*ECSResolver)
sut.Next(m)
})
When("ECS is disabled", func() {
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
})
When("ECS is enabled", func() {
BeforeEach(func() {
sutConfig.UseAsClient = true
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
When("use ECS client ip is enabled", func() {
BeforeEach(func() {
sutConfig.UseAsClient = true
})
It("should change ClientIP with subnet 32", func(ctx context.Context) {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
It("shouldn't change ClientIP with subnet 24", func(ctx context.Context) {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, 24)
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
When("add ECS information", func() {
BeforeEach(func() {
sutConfig.IPv4Mask = 32
sutConfig.IPv6Mask = 128
})
It("should add ECS information with subnet 32", func(ctx context.Context) {
request := newRequest("example.com.", A)
request.ClientIP = origIP
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
It("should add ECS information with subnet 128", func(ctx context.Context) {
request := newRequest("example.com.", AAAA)
request.ClientIP = net.ParseIP("2001:db8::68")
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
When("forward ECS information", func() {
BeforeEach(func() {
sutConfig.IPv4Mask = 32
sutConfig.IPv6Mask = 128
sutConfig.Forward = true
})
It("should forward ECS information with subnet 32", func(ctx context.Context) {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req)
Expect(so.Address).Should(Equal(ecsIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
When("subnet mask is 24", func() {
BeforeEach(func() {
sutConfig.IPv4Mask = 24
})
It("should modify ECS information", func(ctx context.Context) {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req)
Expect(so.Address).Should(Equal(net.ParseIP("4.3.2.0").To4()))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
It("should forward ECS information with subnet 128", func(ctx context.Context) {
request := newRequest("example.com.", AAAA)
request.ClientIP = net.ParseIP("2001:db8::68")
addEcsOption(request.Req, net.ParseIP("2001:db8::68"), 128)
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](req.Req)
Expect(so.Address).Should(Equal(net.ParseIP("2001:db8::68")))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
})
Context("maskIP", func() {
It("should mask IPv4", func() {
ip := net.ParseIP("192.168.10.123")
mask := config.ECSv4Mask(24)
mip, err := maskIP(ip, mask)
Expect(err).Should(Succeed())
Expect(mip).Should(Equal(net.ParseIP("192.168.10.0").To4()))
})
})
})
// addEcsOption adds the subnet information to the request as EDNS0 option
func addEcsOption(req *dns.Msg, ip net.IP, netmask uint8) {
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
e.SourceScope = ecsSourceScope
e.Family = ecsFamilyIPv4
e.SourceNetmask = netmask
e.Address = ip
util.SetEdns0Option(req, e)
}
// respondWith creates a new Response with the given request and message
func respondWith(res *dns.Msg) *Response {
return &Response{Res: res, RType: ResponseTypeRESOLVED, Reason: "Test"}
}