mirror of https://github.com/0xERR0R/blocky.git
156 lines
4.0 KiB
Go
156 lines
4.0 KiB
Go
package resolver
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
|
|
"github.com/0xERR0R/blocky/config"
|
|
"github.com/0xERR0R/blocky/log"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var systemResolverBootstrap = &Bootstrap{
|
|
dialer: newMockDialer(),
|
|
configurable: withConfig(newBootstrapConfig(&config.Config{Upstreams: defaultUpstreamsConfig})),
|
|
}
|
|
|
|
var _ = Describe("Resolver", func() {
|
|
Describe("Chains", func() {
|
|
var (
|
|
r1 ChainedResolver
|
|
r2 ChainedResolver
|
|
r3 ChainedResolver
|
|
r4 Resolver
|
|
)
|
|
|
|
BeforeEach(func() {
|
|
r1 = &mockResolver{}
|
|
r2 = &mockResolver{}
|
|
r3 = &mockResolver{}
|
|
r4 = &NoOpResolver{}
|
|
})
|
|
|
|
Describe("Chain", func() {
|
|
It("should create a chain iterable using `GetNext`", func() {
|
|
ch := Chain(r1, r2, r3, r4)
|
|
Expect(ch).ShouldNot(BeNil())
|
|
Expect(ch).Should(Equal(r1))
|
|
Expect(r1.GetNext()).Should(Equal(r2))
|
|
Expect(r2.GetNext()).Should(Equal(r3))
|
|
Expect(r3.GetNext()).Should(Equal(r4))
|
|
})
|
|
|
|
It("should not link a final ChainedResolver", func() {
|
|
ch := Chain(r1, r2)
|
|
Expect(ch).ShouldNot(BeNil())
|
|
|
|
Expect(r1.GetNext()).Should(Equal(r2))
|
|
Expect(r2.GetNext()).Should(BeNil())
|
|
})
|
|
})
|
|
|
|
Describe("GetFromChainWithType", func() {
|
|
It("should return resolver with type", func() {
|
|
ch := Chain(&CustomDNSResolver{}, &BlockingResolver{})
|
|
res, err := GetFromChainWithType[*BlockingResolver](ch)
|
|
var expectedResolver *BlockingResolver
|
|
Expect(err).Should(Succeed())
|
|
Expect(res).Should(BeAssignableToTypeOf(expectedResolver))
|
|
})
|
|
It("should fail if chain does not contain the desired type", func() {
|
|
ch := Chain(&CustomDNSResolver{}, &BlockingResolver{})
|
|
_, err := GetFromChainWithType[*FilteringResolver](ch)
|
|
|
|
Expect(err).Should(HaveOccurred())
|
|
})
|
|
})
|
|
|
|
Describe("ForEach", func() {
|
|
It("should iterate on all resolvers in the chain", func() {
|
|
ch := Chain(r1, r2, r3, r4)
|
|
Expect(ch).ShouldNot(BeNil())
|
|
|
|
var itResult []Resolver
|
|
|
|
ForEach(ch, func(r Resolver) {
|
|
itResult = append(itResult, r)
|
|
})
|
|
|
|
Expect(itResult).ShouldNot(BeEmpty())
|
|
Expect(itResult).Should(Equal([]Resolver{r1, r2, r3, r4}))
|
|
})
|
|
})
|
|
|
|
Describe("LogResolverConfig", func() {
|
|
It("should call the resolver's `LogConfig`", func() {
|
|
logger := logrus.NewEntry(log.Log())
|
|
|
|
m := &mockResolver{}
|
|
m.On("IsEnabled").Return(true)
|
|
m.On("LogConfig")
|
|
|
|
LogResolverConfig(m, logger)
|
|
|
|
m.AssertExpectations(GinkgoT())
|
|
})
|
|
|
|
When("the resolver is disabled", func() {
|
|
It("should not call the resolver's `LogConfig`", func() {
|
|
logger := logrus.NewEntry(log.Log())
|
|
|
|
m := &mockResolver{}
|
|
m.On("IsEnabled").Return(false)
|
|
|
|
LogResolverConfig(m, logger)
|
|
|
|
m.AssertExpectations(GinkgoT())
|
|
})
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Name", func() {
|
|
var (
|
|
ctx context.Context
|
|
cancelFn context.CancelFunc
|
|
)
|
|
BeforeEach(func() {
|
|
ctx, cancelFn = context.WithCancel(context.Background())
|
|
DeferCleanup(cancelFn)
|
|
})
|
|
When("'Name' is called", func() {
|
|
It("should return resolver name", func() {
|
|
br, _ := NewBlockingResolver(ctx, config.Blocking{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
|
name := Name(br)
|
|
Expect(name).Should(Equal("blocking"))
|
|
})
|
|
})
|
|
When("'Name' is called on a NamedResolver", func() {
|
|
It("should return its custom name", func() {
|
|
br, _ := NewBlockingResolver(ctx, config.Blocking{BlockType: "zeroIP"}, nil, systemResolverBootstrap)
|
|
|
|
cfg := config.RewriterConfig{Rewrite: map[string]string{"not": "empty"}}
|
|
r := NewRewriterResolver(cfg, br)
|
|
|
|
name := Name(r)
|
|
Expect(name).Should(Equal("blocking w/ rewrite"))
|
|
})
|
|
})
|
|
})
|
|
})
|
|
|
|
func expectValidResolverType(sut Resolver) {
|
|
By("it must not contain spaces", func() {
|
|
Expect(sut.Type()).ShouldNot(ContainSubstring(" "))
|
|
})
|
|
By("it must be lower case", func() {
|
|
Expect(sut.Type()).Should(Equal(strings.ToLower(sut.Type())))
|
|
})
|
|
By("it must not contain 'resolver'", func() {
|
|
Expect(sut.Type()).ShouldNot(ContainSubstring("resolver"))
|
|
})
|
|
}
|