refactor: make use of contexts in more places

- `CacheControl.FlushCaches`
- `Querier.Query`
- `Resolver.Resolve`

Besides all the API churn, this leads to `ParallelBestResolver`,
`StrictResolver` and `UpstreamResolver` simplification: timeouts only
need to be setup in one place, `UpstreamResolver`.

We also benefit from using HTTP request contexts, so if the client
closes the connection we stop processing on our side.
This commit is contained in:
ThinkChaos 2023-11-19 15:47:50 -05:00
parent e4ebc16ccc
commit eae99ec550
52 changed files with 796 additions and 627 deletions

View File

@ -40,11 +40,11 @@ type ListRefresher interface {
} }
type Querier interface { type Querier interface {
Query(question string, qType dns.Type) (*model.Response, error) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
} }
type CacheControl interface { type CacheControl interface {
FlushCaches() FlushCaches(ctx context.Context)
} }
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) { func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
return ListRefresh200Response{}, nil return ListRefresh200Response{}, nil
} }
func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) { func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
qType := dns.Type(dns.StringToType[request.Body.Type]) qType := dns.Type(dns.StringToType[request.Body.Type])
if qType == dns.Type(dns.TypeNone) { if qType == dns.Type(dns.TypeNone) {
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
} }
resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType) resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
}), nil }), nil
} }
func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context, func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
_ CacheFlushRequestObject, _ CacheFlushRequestObject,
) (CacheFlushResponseObject, error) { ) (CacheFlushResponseObject, error) {
i.cacheControl.FlushCaches() i.cacheControl.FlushCaches(ctx)
return CacheFlush200Response{}, nil return CacheFlush200Response{}, nil
} }

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"time" "time"
// . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util" "github.com/0xERR0R/blocky/util"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus) return args.Get(0).(BlockingStatus)
} }
func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) { func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(question, qType) args := m.Called(ctx, question, qType)
return args.Get(0).(*model.Response), args.Error(1) return args.Get(0).(*model.Response), args.Error(1)
} }
func (m *CacheControlMock) FlushCaches() { func (m *CacheControlMock) FlushCaches(ctx context.Context) {
_ = m.Called() _ = m.Called(ctx)
} }
var _ = Describe("API implementation tests", func() { var _ = Describe("API implementation tests", func() {
@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock *ListRefreshMock listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl sut *OpenAPIInterfaceImpl
ctx context.Context
cancelFn context.CancelFunc
) )
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
blockingControlMock = &BlockingControlMock{} blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{} querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{} listRefreshMock = &ListRefreshMock{}
@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
) )
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
querierMock.On("Query", "google.com.", A).Return(&model.Response{ querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
Res: queryResponse, Res: queryResponse,
Reason: "reason", Reason: "reason",
}, nil) }, nil)
resp, err := sut.Query(context.Background(), QueryRequestObject{ resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{ Body: &ApiQueryRequest{
Query: "google.com", Type: "A", Query: "google.com", Type: "A",
}, },
@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
}) })
It("should return 400 on wrong parameter", func() { It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(context.Background(), QueryRequestObject{ resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{ Body: &ApiQueryRequest{
Query: "google.com", Query: "google.com",
Type: "WRONGTYPE", Type: "WRONGTYPE",
@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() { It("should return 200 on success", func() {
listRefreshMock.On("RefreshLists").Return(nil) listRefreshMock.On("RefreshLists").Return(nil)
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp200 ListRefresh200Response var resp200 ListRefresh200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200)) Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 500 on failure", func() { It("should return 500 on failure", func() {
listRefreshMock.On("RefreshLists").Return(errors.New("failed")) listRefreshMock.On("RefreshLists").Return(errors.New("failed"))
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp500 ListRefresh500TextResponse var resp500 ListRefresh500TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp500)) Expect(resp).Should(BeAssignableToTypeOf(resp500))
@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
duration := "3s" duration := "3s"
grroups := "gr1,gr2" grroups := "gr1,gr2"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{ Params: DisableBlockingParams{
Duration: &duration, Duration: &duration,
Groups: &grroups, Groups: &grroups,
@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 400 on failure", func() { It("should return 400 on failure", func() {
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed")) blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{}) resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp400 DisableBlocking400TextResponse var resp400 DisableBlocking400TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp400)) Expect(resp).Should(BeAssignableToTypeOf(resp400))
@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 400 on wrong duration parameter", func() { It("should return 400 on wrong duration parameter", func() {
wrongDuration := "4sds" wrongDuration := "4sds"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{ Params: DisableBlockingParams{
Duration: &wrongDuration, Duration: &wrongDuration,
}, },
@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() { It("should return 200 on success", func() {
blockingControlMock.On("EnableBlocking").Return() blockingControlMock.On("EnableBlocking").Return()
resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{}) resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp200 EnableBlocking200Response var resp200 EnableBlocking200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200)) Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
AutoEnableInSec: 47, AutoEnableInSec: 47,
}) })
resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{}) resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp200 BlockingStatus200JSONResponse var resp200 BlockingStatus200JSONResponse
Expect(resp).Should(BeAssignableToTypeOf(resp200)) Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
Describe("Cache API", func() { Describe("Cache API", func() {
When("Cache flush is called", func() { When("Cache flush is called", func() {
It("should return 200 on success", func() { It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return() cacheControlMock.On("FlushCaches", ctx).Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{}) resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
var resp200 CacheFlush200Response var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200)) Expect(resp).Should(BeAssignableToTypeOf(resp200))

View File

@ -37,7 +37,7 @@ type Options struct {
// OnExpirationCallback will be called just before an element gets expired and will // OnExpirationCallback will be called just before an element gets expired and will
// be removed from cache. This function can return new value and TTL to leave the // be removed from cache. This function can return new value and TTL to leave the
// element in the cache or nil to remove it // element in the cache or nil to remove it
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration) type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)
// OnCacheHitCallback will be called on cache get if entry was found // OnCacheHitCallback will be called on cache get if entry was found
type OnCacheHitCallback func(key string) type OnCacheHitCallback func(key string)
@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
l, _ := lru.New(defaultSize) l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{ c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval, cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) { preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
return nil, 0 return nil, 0
}, },
onCacheHit: func(key string) {}, onCacheHit: func(key string) {},
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
var keysToDelete []string var keysToDelete []string
for _, key := range expiredKeys { for _, key := range expiredKeys {
newVal, newTTL := e.preExpirationFn(key) newVal, newTTL := e.preExpirationFn(context.Background(), key)
if newVal != nil { if newVal != nil {
e.Put(key, newVal, newTTL) e.Put(key, newVal, newTTL)
} else { } else {

View File

@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
Describe("preExpiration function", func() { Describe("preExpiration function", func() {
When("function is defined", func() { When("function is defined", func() {
It("should update the value and TTL if function returns values", func() { It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) { fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "v2" v2 := "v2"
return &v2, time.Second return &v2, time.Second
@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
}) })
It("should update the value and TTL if function returns values on cleanup if element is expired", func() { It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
fn := func(key string) (val *string, ttl time.Duration) { fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "val2" v2 := "val2"
return &v2, time.Second return &v2, time.Second
@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
}) })
It("should delete the key if function returns nil", func() { It("should delete the key if function returns nil", func() {
fn := func(key string) (val *string, ttl time.Duration) { fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
return nil, 0 return nil, 0
} }
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn) cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)

View File

@ -25,11 +25,11 @@ type cacheValue[T any] struct {
type OnEntryReloadedCallback func(key string) type OnEntryReloadedCallback func(key string)
// ReloadEntryFn reloads a prefetched entry by key // ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration) type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)
type PrefetchingOptions[T any] struct { type PrefetchingOptions[T any] struct {
Options Options
ReloadFn func(cacheKey string) (*T, time.Duration) ReloadFn ReloadEntryFn[T]
PrefetchThreshold int PrefetchThreshold int
PrefetchExpires time.Duration PrefetchExpires time.Duration
PrefetchMaxItemsCount int PrefetchMaxItemsCount int
@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold) return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
} }
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) { func (e *PrefetchingExpiringLRUCache[T]) onExpired(
ctx context.Context, cacheKey string,
) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) { if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey) loadedVal, ttl := e.reloadFn(ctx, cacheKey)
if loadedVal != nil { if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil { if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey) e.onPrefetchEntryReloaded(cacheKey)

View File

@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
}, },
PrefetchThreshold: 2, PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond, PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) { ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2" v := "v2"
return &v, 50 * time.Millisecond return &v, 50 * time.Millisecond
@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
}, },
PrefetchThreshold: 2, PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond, PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) { ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2" v := "v2"
return &v, 50 * time.Millisecond return &v, 50 * time.Millisecond
@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
Options: Options{ Options: Options{
CleanupInterval: 100 * time.Millisecond, CleanupInterval: 100 * time.Millisecond,
}, },
ReloadFn: func(cacheKey string) (*string, time.Duration) { ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2" v := "v2"
return &v, 50 * time.Millisecond return &v, 50 * time.Millisecond
@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
}, },
PrefetchThreshold: 2, PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond, PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) { ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2" v := "v2"
return &v, 50 * time.Millisecond return &v, 50 * time.Millisecond

View File

@ -154,16 +154,16 @@ func NewBlockingResolver(ctx context.Context,
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{ res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
CleanupInterval: defaultBlockingCleanUpInterval, CleanupInterval: defaultBlockingCleanUpInterval,
}, func(key string) (val *[]net.IP, ttl time.Duration) { }, func(ctx context.Context, key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(key) return res.queryForFQIdentifierIPs(ctx, key)
}) })
if res.redisClient != nil { if res.redisClient != nil {
setupRedisEnabledSubscriber(ctx, res) go res.redisSubscriber(ctx)
} }
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) { err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
go res.initFQDNIPCache() go res.initFQDNIPCache(ctx)
}) })
if err != nil { if err != nil {
@ -173,29 +173,27 @@ func NewBlockingResolver(ctx context.Context,
return res, nil return res, nil
} }
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) { func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
go func() { for {
for { select {
select { case em := <-r.redisClient.EnabledChannel:
case em := <-c.redisClient.EnabledChannel: if em != nil {
if em != nil { r.log().Debug("Received state from redis: ", em)
c.log().Debug("Received state from redis: ", em)
if em.State { if em.State {
c.internalEnableBlocking() r.internalEnableBlocking()
} else { } else {
err := c.internalDisableBlocking(em.Duration, em.Groups) err := r.internalDisableBlocking(em.Duration, em.Groups)
if err != nil { if err != nil {
c.log().Warn("Blocking couldn't be disabled:", err) r.log().Warn("Blocking couldn't be disabled:", err)
}
} }
} }
case <-ctx.Done():
return
} }
case <-ctx.Done():
return
} }
}() }
} }
// RefreshLists triggers the refresh of all black and white lists in the cache // RefreshLists triggers the refresh of all black and white lists in the cache
@ -358,7 +356,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
return false return false
} }
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string, func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []string,
request *model.Request, logger *logrus.Entry, request *model.Request, logger *logrus.Entry,
) (bool, *model.Response, error) { ) (bool, *model.Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request") logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
@ -371,7 +369,7 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 { if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
logger.WithField("groups", groups).Debugf("domain is whitelisted") logger.WithField("groups", groups).Debugf("domain is whitelisted")
resp, err := r.next.Resolve(request) resp, err := r.next.Resolve(ctx, request)
return true, resp, err return true, resp, err
} }
@ -393,18 +391,18 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
} }
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked // Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "blacklist_resolver") logger := log.WithPrefix(request.Log, "blacklist_resolver")
groupsToCheck := r.groupsToCheckForClient(request) groupsToCheck := r.groupsToCheckForClient(request)
if len(groupsToCheck) > 0 { if len(groupsToCheck) > 0 {
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger) handled, resp, err := r.handleBlacklist(ctx, groupsToCheck, request, logger)
if handled { if handled {
return resp, err return resp, err
} }
} }
respFromNext, err := r.next.Resolve(request) respFromNext, err := r.next.Resolve(ctx, request)
if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil { if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
for _, rr := range respFromNext.Res.Answer { for _, rr := range respFromNext.Res.Answer {
@ -574,7 +572,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
} }
} }
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP, time.Duration) { func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
prefixedLog := log.WithPrefix(r.log(), "client_id_cache") prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
var result []net.IP var result []net.IP
@ -582,7 +580,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
var ttl time.Duration var ttl time.Duration
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} { for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(&model.Request{ resp, err := r.next.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)), Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
Log: prefixedLog, Log: prefixedLog,
}) })
@ -606,12 +604,12 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
return &result, ttl return &result, ttl
} }
func (r *BlockingResolver) initFQDNIPCache() { func (r *BlockingResolver) initFQDNIPCache(ctx context.Context) {
identifiers := maps.Keys(r.clientGroupsBlock) identifiers := maps.Keys(r.clientGroupsBlock)
for _, identifier := range identifiers { for _, identifier := range identifiers {
if isFQDN(identifier) { if isFQDN(identifier) {
iPs, ttl := r.queryForFQIdentifierIPs(identifier) iPs, ttl := r.queryForFQIdentifierIPs(ctx, identifier)
r.fqdnIPCache.Put(identifier, iPs, ttl) r.fqdnIPCache.Put(identifier, iPs, ttl)
} }
} }

View File

@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
} }
Bus().Publish(ApplicationStarted, "") Bus().Publish(ApplicationStarted, "")
Eventually(func(g Gomega) { Eventually(func(g Gomega) {
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))). g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
Should(And( Should(And(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"), BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 60)), HaveTTL(BeNumerically("==", 60)),
@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Domain is on the black list", func() { When("Domain is on the black list", func() {
It("should block request", func() { It("should block request", func() {
Eventually(sut.Resolve). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")). WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
Should( Should(
SatisfyAll( SatisfyAll(
@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("client name is defined in client groups block", func() { When("client name is defined in client groups block", func() {
It("should block the A query if domain is on the black list (single)", func() { It("should block the A query if domain is on the black list (single)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the A query if domain is on the black list (multipart 1)", func() { It("should block the A query if domain is on the black list (multipart 1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the A query if domain is on the black list (multipart 2)", func() { It("should block the A query if domain is on the black list (multipart 2)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the A query if domain is on the black list (merged)", func() { It("should block the A query if domain is on the black list (merged)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"), BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the AAAA query if domain is on the black list", func() { It("should block the AAAA query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", AAAA, "::"), BeDNSRecord("domain1.com.", AAAA, "::"),
@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the HTTPS query if domain is on the black list", func() { It("should block the HTTPS query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
Should(HaveReturnCode(dns.RcodeNameError)) Should(HaveReturnCode(dns.RcodeNameError))
}) })
It("should block the MX query if domain is on the black list", func() { It("should block the MX query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
Should(HaveReturnCode(dns.RcodeNameError)) Should(HaveReturnCode(dns.RcodeNameError))
}) })
}) })
When("Client ip is defined in client groups block", func() { When("Client ip is defined in client groups block", func() {
It("should block the query if domain is on the black list", func() { It("should block the query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
It("should not block the query for 10.43.8.63 if domain is on the black list", func() { It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
}) })
It("should not block the query for 10.43.8.80 if domain is on the black list", func() { It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
It("should block the query for 10.43.8.64 if domain is on the black list", func() { It("should block the query for 10.43.8.64 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block the query for 10.43.8.79 if domain is on the black list", func() { It("should block the query for 10.43.8.79 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Client has multiple names and for each name a client group block definition exists", func() { When("Client has multiple names and for each name a client group block definition exists", func() {
It("should block query if domain is in one group", func() { It("should block query if domain is in one group", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
It("should block query if domain is in another group too", func() { It("should block query if domain is in another group too", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"), BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
When("Client name matches wildcard", func() { When("Client name matches wildcard", func() {
It("should block query if domain is in one group", func() { It("should block query if domain is in one group", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Default group is defined", func() { When("Default group is defined", func() {
It("should block domains from default group for each client", func() { It("should block domains from default group for each client", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should return NXDOMAIN if query is blocked", func() { It("should return NXDOMAIN if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should return answer with specified TTL", func() { It("should return answer with specified TTL", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should return custom IP with specified TTL", func() { It("should return custom IP with specified TTL", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "12.12.12.12"), BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should return ipv4 address for A query if query is blocked", func() { It("should return ipv4 address for A query if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "12.12.12.12"), BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should return ipv6 address for AAAA query if query is blocked", func() { It("should return ipv6 address for AAAA query if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"), BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should use fallback for ipv6 and return zero ip", func() { It("should use fallback for ipv6 and return zero ip", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", AAAA, "::"), BeDNSRecord("blocked3.com.", AAAA, "::"),
@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
}) })
It("should block query, if lookup result contains blacklisted IP", func() { It("should block query, if lookup result contains blacklisted IP", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "0.0.0.0"), BeDNSRecord("example.com.", A, "0.0.0.0"),
@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
) )
}) })
It("should block query, if lookup result contains blacklisted IP", func() { It("should block query, if lookup result contains blacklisted IP", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", AAAA, "::"), BeDNSRecord("example.com.", AAAA, "::"),
@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3} mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
}) })
It("should block the query, if response contains a CNAME with domain on a blacklist", func() { It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "0.0.0.0"), BeDNSRecord("example.com.", A, "0.0.0.0"),
@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
} }
}) })
It("Should not be blocked", func() { It("Should not be blocked", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
It("should block everything else except domains on the white list with default group", func() { It("should block everything else except domains on the white list with default group", func() {
By("querying domain on the whitelist", func() { By("querying domain on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("querying another domain, which is not on the whitelist", func() { By("querying another domain, which is not on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("google.com.", A, "0.0.0.0"), BeDNSRecord("google.com.", A, "0.0.0.0"),
@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
It("should block everything else except domains on the white list "+ It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() { "if multiple white list only groups are defined", func() {
By("querying domain on the whitelist", func() { By("querying domain on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("querying another domain, which is not on the whitelist", func() { By("querying another domain, which is not on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"), BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
It("should block everything else except domains on the white list "+ It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() { "if multiple white list only groups are defined", func() {
By("querying domain on the whitelist group 1", func() { By("querying domain on the whitelist group 1", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("querying another domain, which is in the whitelist group 1", func() { By("querying another domain, which is in the whitelist group 1", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
}) })
It("should not block if DNS answer contains IP from the white list", func() { It("should not block if DNS answer contains IP from the white list", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.145.123.145"), BeDNSRecord("example.com.", A, "123.145.123.145"),
@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
When("domain is not on the black list", func() { When("domain is not on the black list", func() {
It("should delegate to next resolver", func() { It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
} }
}) })
It("should delegate to next resolver", func() { It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking is called", func() { When("Disable blocking is called", func() {
It("no query should be blocked", func() { It("no query should be blocked", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("Perform query to ensure that the blocking status is active (group1)", func() { By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (defaultGroup)", func() { By("perform the same query again (defaultGroup)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (group1)", func() { By("perform the same query again (group1)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (defaultGroup)", func() { By("perform the same query again (defaultGroup)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("Perform query to ensure that the blocking status is active (group1)", func() { By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for all groups is called with a duration parameter", func() { When("Disable blocking for all groups is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() { It("No query should be blocked only for passed amount of time", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
By("Perform query to ensure that the blocking status is active (group1)", func() { By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("perform the same query again to ensure that this query will not be blocked (group1)", func() { By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
// wait 1 sec // wait 1 sec
Eventually(enabled, "1s").Should(Receive(BeTrue())) Eventually(enabled, "1s").Should(Receive(BeTrue()))
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
)) ))
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for one group is called with a duration parameter", func() { When("Disable blocking for one group is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() { It("No query should be blocked only for passed amount of time", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)) ))
}) })
By("Perform query to ensure that the blocking status is active (group1)", func() { By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}) })
By("perform the same query again to ensure that this query will not be blocked (group1)", func() { By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
// now is blocking disabled, query the url again // now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
// wait 1 sec // wait 1 sec
Eventually(enabled, "1s").Should(Receive(BeTrue())) Eventually(enabled, "1s").Should(Receive(BeTrue()))
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"), BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
)) ))
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"), BeDNSRecord("domain1.com.", A, "0.0.0.0"),

View File

@ -106,14 +106,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
return b, nil return b, nil
} }
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) { func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.upstream.Host hostname := r.upstream.Host
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil return newIPSet([]net.IP{ip}), nil
} }
ips, err := b.resolveUpstream(r, hostname) ips, err := b.resolveUpstream(ctx, r, hostname)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,10 +121,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
return newIPSet(ips), nil return newIPSet(ips), nil
} }
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) { func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
// Use system resolver if no bootstrap is configured // Use system resolver if no bootstrap is configured
if b.resolver == nil { if b.resolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), b.timeout) ctx, cancel := context.WithTimeout(ctx, b.timeout)
defer cancel() defer cancel()
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host) return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
@ -135,7 +135,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
return ips, nil return ips, nil
} }
return b.resolve(host, b.connectIPVersion.QTypes()) return b.resolve(ctx, host, b.connectIPVersion.QTypes())
} }
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames // NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
@ -175,7 +175,7 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
} }
// Resolve the host with the bootstrap DNS // Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes) ips, err := b.resolve(ctx, host, qTypes)
if err != nil { if err != nil {
logger.Errorf("resolve error: %s", err) logger.Errorf("resolve error: %s", err)
@ -192,11 +192,11 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
return b.dialer.DialContext(ctx, network, addrWithIP) return b.dialer.DialContext(ctx, network, addrWithIP)
} }
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) { func (b *Bootstrap) resolve(ctx context.Context, hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
ips = make([]net.IP, 0, len(qTypes)) ips = make([]net.IP, 0, len(qTypes))
for _, qType := range qTypes { for _, qType := range qTypes {
qIPs, qErr := b.resolveType(hostname, qType) qIPs, qErr := b.resolveType(ctx, hostname, qType)
if qErr != nil { if qErr != nil {
err = multierror.Append(err, qErr) err = multierror.Append(err, qErr)
@ -213,7 +213,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
return return
} }
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) { func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.Type) (ips []net.IP, err error) {
if ip := net.ParseIP(hostname); ip != nil { if ip := net.ParseIP(hostname); ip != nil {
return []net.IP{ip}, nil return []net.IP{ip}, nil
} }
@ -223,7 +223,7 @@ func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP,
Log: b.log, Log: b.log,
} }
rsp, err := b.resolver.Resolve(&req) rsp, err := b.resolver.Resolve(ctx, &req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
}, },
} }
_, err := sut.resolveUpstream(nil, "example.com") _, err := sut.resolveUpstream(ctx, nil, "example.com")
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
Expect(usedSystemResolver).Should(Receive(BeTrue())) Expect(usedSystemResolver).Should(Receive(BeTrue()))
}) })
@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("called from bootstrap.upstream", func() { When("called from bootstrap.upstream", func() {
It("uses hardcoded IPs", func() { It("uses hardcoded IPs", func() {
ips, err := sut.resolveUpstream(bootstrapUpstream, "host") ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host")
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs)) Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs))
@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("hostname is an IP", func() { When("hostname is an IP", func() {
It("returns immediately", func() { It("returns immediately", func() {
ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes()) ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes())
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(ips).Should(ContainElement(net.IPv4zero)) Expect(ips).Should(ContainElement(net.IPv4zero))
@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("localhost", []dns.Type{AAAA}) ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA})
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(ips).Should(HaveLen(1)) Expect(ips).Should(HaveLen(1))
@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr) bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
ips, err := sut.resolve("localhost", []dns.Type{A}) ips, err := sut.resolve(ctx, "localhost", []dns.Type{A})
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A}) ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A})
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host")) Expect(err.Error()).Should(ContainSubstring("no such host"))
@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
r := newUpstreamResolverUnchecked(upstream, sut) r := newUpstreamResolverUnchecked(upstream, sut)
rsp, err := r.Resolve(mainReq) rsp, err := r.Resolve(ctx, mainReq)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1)) Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1))
Expect(rsp.Res.Question[0].Name).Should(Equal("example.com.")) Expect(rsp.Res.Question[0].Name).Should(Equal("example.com."))
@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
// implicit expectation of 0 bootstrapUpstream.Resolve calls // implicit expectation of 0 bootstrapUpstream.Resolve calls
_, err = t.DialContext(context.Background(), "ip", "!bad-addr!") _, err = t.DialContext(ctx, "ip", "!bad-addr!")
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
}) })
@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
t := sut.NewHTTPTransport() t := sut.NewHTTPTransport()
_, err = t.DialContext(context.Background(), "ip", "abc:123") _, err = t.DialContext(ctx, "ip", "abc:123")
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
t := sut.NewHTTPTransport() t := sut.NewHTTPTransport()
_, err = t.DialContext(context.Background(), "ip", "abc:123") _, err = t.DialContext(ctx, "ip", "abc:123")
Expect(err).ShouldNot(Succeed()) Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host")) Expect(err.Error()).Should(ContainSubstring("no such host"))
@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
Describe("resolve", func() { Describe("resolve", func() {
AfterEach(func() { AfterEach(func() {
_, err := sut.resolveUpstream(nil, "example.com") _, err := sut.resolveUpstream(ctx, nil, "example.com")
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
AfterEach(func() { AfterEach(func() {
t := sut.NewHTTPTransport() t := sut.NewHTTPTransport()
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0") conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0")
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(conn).Should(Equal(aMockConn)) Expect(conn).Should(Equal(aMockConn))
@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
}) })
It("uses both", func() { It("uses both", func() {
_, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)}) _, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)})
Expect(err).To(Succeed()) Expect(err).To(Succeed())

View File

@ -59,7 +59,7 @@ func newCachingResolver(ctx context.Context,
configureCaches(ctx, c, &cfg) configureCaches(ctx, c, &cfg)
if c.redisClient != nil { if c.redisClient != nil {
setupRedisCacheSubscriber(ctx, c) go c.redisSubscriber(ctx)
c.redisClient.GetRedisCache() c.redisClient.GetRedisCache()
} }
@ -105,14 +105,14 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
} }
} }
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Duration) { func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey) qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log() logger := r.log()
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
req := newRequest(dns.Fqdn(domainName), qType, logger) req := newRequest(dns.Fqdn(domainName), qType, logger)
response, err := r.next.Resolve(req) response, err := r.next.Resolve(ctx, req)
if err == nil { if err == nil {
if response.Res.Rcode == dns.RcodeSuccess { if response.Res.Rcode == dns.RcodeSuccess {
@ -132,22 +132,20 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Durat
return nil, 0 return nil, 0
} }
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) { func (r *CachingResolver) redisSubscriber(ctx context.Context) {
go func() { for {
for { select {
select { case rc := <-r.redisClient.CacheChannel:
case rc := <-c.redisClient.CacheChannel: if rc != nil {
if rc != nil { r.log().Debug("Received key from redis: ", rc.Key)
c.log().Debug("Received key from redis: ", rc.Key) ttl := r.adjustTTLs(rc.Response.Res.Answer)
ttl := c.adjustTTLs(rc.Response.Res.Answer) r.putInCache(rc.Key, rc.Response, ttl, false)
c.putInCache(rc.Key, rc.Response, ttl, false)
}
case <-ctx.Done():
return
} }
case <-ctx.Done():
return
} }
}() }
} }
// LogConfig implements `config.Configurable`. // LogConfig implements `config.Configurable`.
@ -159,13 +157,13 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
// Resolve checks if the current query should use the cache and if the result is already in // Resolve checks if the current query should use the cache and if the result is already in
// the cache and returns it or delegates to the next resolver // the cache and returns it or delegates to the next resolver
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) { func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver") logger := log.WithPrefix(request.Log, "caching_resolver")
if !r.IsEnabled() || !isRequestCacheable(request) { if !r.IsEnabled() || !isRequestCacheable(request) {
logger.Debug("skip cache") logger.Debug("skip cache")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
for _, question := range request.Req.Question { for _, question := range request.Req.Question {
@ -191,7 +189,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
} }
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver") logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(request) response, err = r.next.Resolve(ctx, request)
if err == nil { if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer) cacheTTL := r.adjustTTLs(response.Res.Answer)
@ -319,7 +317,7 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
} }
} }
func (r *CachingResolver) FlushCaches() { func (r *CachingResolver) FlushCaches(context.Context) {
r.log().Debug("flush caches") r.log().Debug("flush caches")
r.resultCache.Clear() r.resultCache.Clear()
} }

View File

@ -114,7 +114,7 @@ var _ = Describe("CachingResolver", func() {
})).Should(Succeed()) })).Should(Succeed())
// first request // first request
_, _ = sut.Resolve(newRequest("example.com.", A)) _, _ = sut.Resolve(ctx, newRequest("example.com.", A))
// Domain is not prefetched // Domain is not prefetched
Expect(domainPrefetched).ShouldNot(Receive()) Expect(domainPrefetched).ShouldNot(Receive())
@ -124,7 +124,7 @@ var _ = Describe("CachingResolver", func() {
// now query again > threshold // now query again > threshold
for i := 0; i < prefetchThreshold+1; i++ { for i := 0; i < prefetchThreshold+1; i++ {
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
} }
@ -132,7 +132,7 @@ var _ = Describe("CachingResolver", func() {
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true))) Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
// and it should hit from prefetch cache // and it should hit from prefetch cache
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -156,7 +156,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("should cache response and use response's TTL for multiple records", func() { It("should cache response and use response's TTL for multiple records", func() {
By("first request", func() { By("first request", func() {
result, err := sut.Resolve(newRequest("example.com.", A)) result, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(result). Expect(result).
Should( Should(
@ -176,7 +176,7 @@ var _ = Describe("CachingResolver", func() {
By("second request", func() { By("second request", func() {
Eventually(func(g Gomega) { Eventually(func(g Gomega) {
result, err := sut.Resolve(newRequest("example.com.", A)) result, err := sut.Resolve(ctx, newRequest("example.com.", A))
g.Expect(err).Should(Succeed()) g.Expect(err).Should(Succeed())
g.Expect(result). g.Expect(result).
Should( Should(
@ -218,7 +218,7 @@ var _ = Describe("CachingResolver", func() {
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) { _ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
totalCacheCount <- d totalCacheCount <- d
}) })
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -239,7 +239,7 @@ var _ = Describe("CachingResolver", func() {
domain <- true domain <- true
}) })
g.Expect(sut.Resolve(newRequest("example.com.", A))). g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -264,7 +264,7 @@ var _ = Describe("CachingResolver", func() {
It("should cache response and use min caching time as TTL", func() { It("should cache response and use min caching time as TTL", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -276,7 +276,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", A)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -299,7 +301,7 @@ var _ = Describe("CachingResolver", func() {
It("should cache response and use min caching time as TTL", func() { It("should cache response and use min caching time as TTL", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -310,7 +312,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -344,7 +348,7 @@ var _ = Describe("CachingResolver", func() {
It("Shouldn't cache any responses", func() { It("Shouldn't cache any responses", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -355,7 +359,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -378,7 +384,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("should cache response and use max caching time as TTL if response TTL is bigger", func() { It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -389,7 +395,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -417,7 +425,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("should cache response and return 0 TTL if entry is expired", func() { It("should cache response and return 0 TTL if entry is expired", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -430,7 +438,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)). Eventually(sut.Resolve, "2s").
WithContext(ctx).
WithArguments(newRequest("example.com.", A)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
@ -461,7 +471,7 @@ var _ = Describe("CachingResolver", func() {
}) })
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError), HaveReturnCode(dns.RcodeNameError),
@ -472,7 +482,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED NEGATIVE"), HaveReason("CACHED NEGATIVE"),
@ -495,7 +507,7 @@ var _ = Describe("CachingResolver", func() {
It("response shouldn't be cached", func() { It("response shouldn't be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError), HaveReturnCode(dns.RcodeNameError),
@ -506,7 +518,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReason(""), HaveReason(""),
@ -529,7 +543,7 @@ var _ = Describe("CachingResolver", func() {
It("response should be cached", func() { It("response should be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -540,7 +554,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"), HaveReason("CACHED"),
@ -563,7 +579,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("Should be cached", func() { It("Should be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", MX))). Expect(sut.Resolve(ctx, newRequest("google.de.", MX))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -575,7 +591,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("google.de.", MX)).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"), HaveReason("CACHED"),
@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("Should not be cached", func() { It("Should not be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))). Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -611,7 +629,7 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))). Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("Should not be cached", func() { It("Should not be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))). Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -645,7 +663,7 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))). Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -671,7 +689,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("Should not be cached", func() { It("Should not be cached", func() {
By("first request", func() { By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))). Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess), HaveReturnCode(dns.RcodeSuccess),
@ -688,7 +706,9 @@ var _ = Describe("CachingResolver", func() {
}) })
By("second request", func() { By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("google.de.", A)).
Should(SatisfyAll( Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"), HaveReason("CACHED"),
@ -750,7 +770,7 @@ var _ = Describe("CachingResolver", func() {
}) })
It("put in redis", func() { It("put in redis", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(HaveResponseType(ResponseTypeRESOLVED)) Should(HaveResponseType(ResponseTypeRESOLVED))
Eventually(func() []string { Eventually(func() []string {
@ -772,7 +792,9 @@ var _ = Describe("CachingResolver", func() {
} }
redisClient.CacheChannel <- redisMockMsg redisClient.CacheChannel <- redisMockMsg
Eventually(sut.Resolve).WithArguments(request). Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(request).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeCACHED), HaveResponseType(ResponseTypeCACHED),

View File

@ -32,7 +32,7 @@ func NewClientNamesResolver(ctx context.Context,
) (cr *ClientNamesResolver, err error) { ) (cr *ClientNamesResolver, err error) {
var r Resolver var r Resolver
if !cfg.Upstream.IsDefault() { if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams) r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -59,17 +59,17 @@ func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
} }
// Resolve tries to resolve the client name from the ip address // Resolve tries to resolve the client name from the ip address
func (r *ClientNamesResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
clientNames := r.getClientNames(request) clientNames := r.getClientNames(ctx, request)
request.ClientNames = clientNames request.ClientNames = clientNames
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; ")) request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
// returns names of client // returns names of client
func (r *ClientNamesResolver) getClientNames(request *model.Request) []string { func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model.Request) []string {
if request.RequestClientID != "" { if request.RequestClientID != "" {
return []string{request.RequestClientID} return []string{request.RequestClientID}
} }
@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
return cpy return cpy
} }
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver")) names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
r.cache.Put(ip.String(), &names, time.Hour) r.cache.Put(ip.String(), &names, time.Hour)
@ -111,7 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
} }
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise // tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) { func (r *ClientNamesResolver) resolveClientNames(
ctx context.Context, ip net.IP, logger *logrus.Entry,
) (result []string) {
// try client mapping first // try client mapping first
result = r.getNameFromIPMapping(ip, result) result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 { if len(result) > 0 {
@ -124,7 +126,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
reverse, _ := dns.ReverseAddr(ip.String()) reverse, _ := dns.ReverseAddr(ip.String())
resp, err := r.externalResolver.Resolve(&model.Request{ resp, err := r.externalResolver.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)), Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
Log: logger, Log: logger,
}) })

View File

@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sut *ClientNamesResolver sut *ClientNamesResolver
sutConfig config.ClientLookupConfig sutConfig config.ClientLookupConfig
m *mockResolver m *mockResolver
ctx context.Context
cancelFn context.CancelFunc ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
JustBeforeEach(func() { JustBeforeEach(func() {
var err error var err error
ctx, cancelFn = context.WithCancel(context.Background()) ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn) DeferCleanup(cancelFn)
@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use clientID if set", func() { It("should use clientID if set", func() {
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123") request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}) })
It("should use IP as fallback if clientID not set", func() { It("should use IP as fallback if clientID not set", func() {
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "") request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve defined name with ipv4 address", func() { It("should resolve defined name with ipv4 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve defined name with ipv6 address", func() { It("should resolve defined name with ipv6 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}) })
It("should resolve multiple names defined names", func() { It("should resolve multiple names defined names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() { It("should resolve client name", func() {
By("first request", func() { By("first request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
By("second request", func() { By("second request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
By("third request", func() { By("third request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve all client names", func() { It("should resolve all client names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() { It("should resolve client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve the client name depending to defined order", func() { It("should resolve the client name depending to defined order", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use fallback for client name", func() { It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}) })
It("should use fallback for client name", func() { It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}) })
It("should resolve no names", func() { It("should resolve no names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
}) })
It("should use fallback for client name", func() { It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"strings" "strings"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -23,14 +24,14 @@ type ConditionalUpstreamResolver struct {
// NewConditionalUpstreamResolver returns new resolver instance // NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver( func NewConditionalUpstreamResolver(
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ConditionalUpstreamResolver, error) { ) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams)) m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstreams := range cfg.Mapping.Upstreams { for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams} cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams) r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,7 +49,9 @@ func NewConditionalUpstreamResolver(
return &r, nil return &r, nil
} }
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) { func (r *ConditionalUpstreamResolver) processRequest(
ctx context.Context, request *model.Request,
) (bool, *model.Response, error) {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0]) domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion domain := domainFromQuestion
@ -56,7 +59,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
// try with domain with and without sub-domains // try with domain with and without sub-domains
for len(domain) > 0 { for len(domain) > 0 {
if resolver, found := r.mapping[domain]; found { if resolver, found := r.mapping[domain]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request) resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
return true, resp, err return true, resp, err
} }
@ -68,7 +71,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
} }
} }
} else if resolver, found := r.mapping["."]; found { } else if resolver, found := r.mapping["."]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request) resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
return true, resp, err return true, resp, err
} }
@ -77,11 +80,11 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
} }
// Resolve uses the conditional resolver to resolve the query // Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "conditional_resolver") logger := log.WithPrefix(request.Log, "conditional_resolver")
if len(r.mapping) > 0 { if len(r.mapping) > 0 {
resolved, resp, err := r.processRequest(request) resolved, resp, err := r.processRequest(ctx, request)
if resolved { if resolved {
return resp, err return resp, err
} }
@ -89,17 +92,17 @@ func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Re
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
func (r *ConditionalUpstreamResolver) internalResolve(reso Resolver, doFQ, do string, func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso Resolver, doFQ, do string,
req *model.Request, req *model.Request,
) (*model.Response, error) { ) (*model.Response, error) {
// internal request resolution // internal request resolution
logger := log.WithPrefix(req.Log, "conditional_resolver") logger := log.WithPrefix(req.Log, "conditional_resolver")
req.Req.Question[0].Name = dns.Fqdn(doFQ) req.Req.Question[0].Name = dns.Fqdn(doFQ)
response, err := reso.Resolve(req) response, err := reso.Resolve(ctx, req)
if err == nil { if err == nil {
response.Reason = "CONDITIONAL" response.Reason = "CONDITIONAL"

View File

@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
var ( var (
sut *ConditionalUpstreamResolver sut *ConditionalUpstreamResolver
m *mockResolver m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
}) })
DeferCleanup(refuseTestUpstream.Close) DeferCleanup(refuseTestUpstream.Close)
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
Mapping: config.ConditionalUpstreamMapping{ Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{ Upstreams: map[string][]config.Upstream{
"fritz.box": {fbTestUpstream.Start()}, "fritz.box": {fbTestUpstream.Start()},
@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
Describe("Resolve conditional DNS queries via defined DNS server", func() { Describe("Resolve conditional DNS queries via defined DNS server", func() {
When("conditional resolver returns error code", func() { When("conditional resolver returns error code", func() {
It("Should be returned without changes", func() { It("Should be returned without changes", func() {
Expect(sut.Resolve(newRequest("refused.domain.", A))). Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("Query is exact equal defined condition in mapping", func() { When("Query is exact equal defined condition in mapping", func() {
Context("first mapping entry", func() { Context("first mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() { It("Should resolve the IP of conditional DNS", func() {
Expect(sut.Resolve(newRequest("fritz.box.", A))). Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("fritz.box.", A, "123.124.122.122"), BeDNSRecord("fritz.box.", A, "123.124.122.122"),
@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
}) })
Context("last mapping entry", func() { Context("last mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() { It("Should resolve the IP of conditional DNS", func() {
Expect(sut.Resolve(newRequest("other.box.", A))). Expect(sut.Resolve(ctx, newRequest("other.box.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("other.box.", A, "192.192.192.192"), BeDNSRecord("other.box.", A, "192.192.192.192"),
@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
}) })
When("Query is a subdomain of defined condition in mapping", func() { When("Query is a subdomain of defined condition in mapping", func() {
It("Should resolve the IP of subdomain", func() { It("Should resolve the IP of subdomain", func() {
Expect(sut.Resolve(newRequest("test.fritz.box.", A))). Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"), BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
}) })
When("Query is not fqdn and . condition is defined in mapping", func() { When("Query is not fqdn and . condition is defined in mapping", func() {
It("Should resolve the IP of .", func() { It("Should resolve the IP of .", func() {
Expect(sut.Resolve(newRequest("test.", A))). Expect(sut.Resolve(ctx, newRequest("test.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("test.", A, "168.168.168.168"), BeDNSRecord("test.", A, "168.168.168.168"),
@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
Describe("Delegation to next resolver", func() { Describe("Delegation to next resolver", func() {
When("Query doesn't match defined mapping", func() { When("Query doesn't match defined mapping", func() {
It("should delegate to next resolver", func() { It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequest("google.com.", A))). Expect(sut.Resolve(ctx, newRequest("google.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("upstream is invalid", func() { When("upstream is invalid", func() {
It("errors during construction", func() { It("errors during construction", func() {
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
Mapping: config.ConditionalUpstreamMapping{ Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{ Upstreams: map[string][]config.Upstream{
".": {config.Upstream{Host: "example.com"}}, ".": {config.Upstream{Host: "example.com"}},

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"net" "net"
"strings" "strings"
@ -123,7 +124,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
} }
// Resolve uses internal mapping to resolve the query // Resolve uses internal mapping to resolve the query
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "custom_dns_resolver") logger := log.WithPrefix(request.Log, "custom_dns_resolver")
reverseResp := r.handleReverseDNS(request) reverseResp := r.handleReverseDNS(request)
@ -140,5 +141,5 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"net" "net"
"time" "time"
@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() {
sut *CustomDNSResolver sut *CustomDNSResolver
m *mockResolver m *mockResolver
cfg config.CustomDNSConfig cfg config.CustomDNSConfig
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
cfg = config.CustomDNSConfig{ cfg = config.CustomDNSConfig{
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{ Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
"custom.domain": {net.ParseIP("192.168.143.123")}, "custom.domain": {net.ParseIP("192.168.143.123")},
@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is true", func() { Context("filterUnmappedTypes is true", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = true }) BeforeEach(func() { cfg.FilterUnmappedTypes = true })
It("defined ip4 query should be resolved", func() { It("defined ip4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("custom.domain.", A))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"), BeDNSRecord("custom.domain.", A, "192.168.143.123"),
@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
}) })
It("TXT query for defined mapping should return NOERROR and empty result", func() { It("TXT query for defined mapping should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", TXT))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
}) })
It("ip6 query should return NOERROR and empty result", func() { It("ip6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is false", func() { Context("filterUnmappedTypes is false", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = false }) BeforeEach(func() { cfg.FilterUnmappedTypes = false })
It("defined ip4 query should be resolved", func() { It("defined ip4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("custom.domain.", A))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"), BeDNSRecord("custom.domain.", A, "192.168.143.123"),
@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
}) })
It("TXT query for defined mapping should be delegated to next resolver", func() { It("TXT query for defined mapping should be delegated to next resolver", func() {
Expect(sut.Resolve(newRequest("custom.domain.", TXT))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
}) })
It("ip6 query should return NOERROR and empty result", func() { It("ip6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
When("Ip 6 mapping is defined for custom domain ", func() { When("Ip 6 mapping is defined for custom domain ", func() {
It("ip6 query should be resolved", func() { It("ip6 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))). Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"), BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() {
When("Multiple IPs are defined for custom domain ", func() { When("Multiple IPs are defined for custom domain ", func() {
It("all IPs for the current type should be returned", func() { It("all IPs for the current type should be returned", func() {
By("IPv6 query", func() { By("IPv6 query", func() {
Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))). Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"), BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
By("IPv4 query", func() { By("IPv4 query", func() {
Expect(sut.Resolve(newRequest("multiple.ips.", A))). Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
WithTransform(ToAnswer, SatisfyAll( WithTransform(ToAnswer, SatisfyAll(
@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() {
When("Reverse DNS request is received", func() { When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() { It("should resolve the defined domain name", func() {
By("ipv4", func() { By("ipv4", func() {
Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))). Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))).
Should( Should(
SatisfyAll( SatisfyAll(
WithTransform(ToAnswer, SatisfyAll( WithTransform(ToAnswer, SatisfyAll(
@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
By("ipv6", func() { By("ipv6", func() {
Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.", Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
PTR))). PTR))).
Should( Should(
SatisfyAll( SatisfyAll(
@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() {
}) })
When("Domain mapping is defined", func() { When("Domain mapping is defined", func() {
It("subdomain must also match", func() { It("subdomain must also match", func() {
Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))). Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"), BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() {
Describe("Delegating to next resolver", func() { Describe("Delegating to next resolver", func() {
When("no mapping for domain exist", func() { When("no mapping for domain exist", func() {
It("should delegate to next resolver", func() { It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"net" "net"
@ -46,7 +47,7 @@ func NewECSResolver(cfg config.ECS) ChainedResolver {
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver // Resolve adds the subnet information as EDNS0 option to the request of the next resolver
// and sets the client IP from the EDNS0 option to the request if this option is enabled // and sets the client IP from the EDNS0 option to the request if this option is enabled
func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if r.cfg.IsEnabled() { if r.cfg.IsEnabled() {
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req) so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set // Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
@ -67,7 +68,7 @@ func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
} }
} }
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
// setSubnet appends the subnet information to the request as EDNS0 option // setSubnet appends the subnet information to the request as EDNS0 option

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"net" "net"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -25,6 +26,9 @@ var _ = Describe("EcsResolver", func() {
err error err error
origIP net.IP origIP net.IP
ecsIP net.IP ecsIP net.IP
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -34,6 +38,9 @@ var _ = Describe("EcsResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
err = defaults.Set(&sutConfig) err = defaults.Set(&sutConfig)
Expect(err).ShouldNot(HaveOccurred()) Expect(err).ShouldNot(HaveOccurred())
@ -86,13 +93,13 @@ var _ = Describe("EcsResolver", func() {
addEcsOption(request.Req, ecsIP, ecsMaskIPv4) addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(req *Request) (*Response, error) { m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP)) Expect(req.ClientIP).Should(Equal(ecsIP))
return respondWith(mockAnswer), nil return respondWith(mockAnswer), nil
} }
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -107,13 +114,13 @@ var _ = Describe("EcsResolver", func() {
addEcsOption(request.Req, ecsIP, 24) addEcsOption(request.Req, ecsIP, 24)
m.ResolveFn = func(req *Request) (*Response, error) { m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP)) Expect(req.ClientIP).Should(Equal(origIP))
return respondWith(mockAnswer), nil return respondWith(mockAnswer), nil
} }
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -134,14 +141,14 @@ var _ = Describe("EcsResolver", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
request.ClientIP = origIP request.ClientIP = origIP
m.ResolveFn = func(req *Request) (*Response, error) { m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP)) Expect(req.ClientIP).Should(Equal(origIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil return respondWith(mockAnswer), nil
} }
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -154,13 +161,13 @@ var _ = Describe("EcsResolver", func() {
request := newRequest("example.com.", AAAA) request := newRequest("example.com.", AAAA)
request.ClientIP = net.ParseIP("2001:db8::68") request.ClientIP = net.ParseIP("2001:db8::68")
m.ResolveFn = func(req *Request) (*Response, error) { m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil return respondWith(mockAnswer), nil
} }
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util" "github.com/0xERR0R/blocky/util"
@ -25,12 +27,12 @@ func NewEDEResolver(cfg config.EDE) *EDEResolver {
// Resolve adds the reason as EDNS0 option to the response of the next resolver // Resolve adds the reason as EDNS0 option to the response of the next resolver
// if it is enabled in the configuration // if it is enabled in the configuration
func (r *EDEResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if !r.cfg.Enable { if !r.cfg.Enable {
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
resp, err := r.next.Resolve(request) resp, err := r.next.Resolve(ctx, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,6 +2,7 @@
package resolver package resolver
import ( import (
"context"
"errors" "errors"
"math" "math"
@ -24,6 +25,9 @@ var _ = Describe("EdeResolver", func() {
sutConfig config.EDE sutConfig config.EDE
m *mockResolver m *mockResolver
mockAnswer *dns.Msg mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -33,6 +37,9 @@ var _ = Describe("EdeResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg) mockAnswer = new(dns.Msg)
}) })
@ -57,7 +64,7 @@ var _ = Describe("EdeResolver", func() {
} }
}) })
It("shouldn't add EDE information", func() { It("shouldn't add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -89,7 +96,7 @@ var _ = Describe("EdeResolver", func() {
} }
It("should add EDE information", func() { It("should add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -115,7 +122,7 @@ var _ = Describe("EdeResolver", func() {
}) })
It("shouldn't add EDE information", func() { It("shouldn't add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -137,7 +144,7 @@ var _ = Describe("EdeResolver", func() {
}) })
It("should return it", func() { It("should return it", func() {
resp, err := sut.Resolve(newRequest("example.com", A)) resp, err := sut.Resolve(ctx, newRequest("example.com", A))
Expect(resp).To(BeNil()) Expect(resp).To(BeNil())
Expect(err).To(Equal(resolveErr)) Expect(err).To(Equal(resolveErr))
}) })

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -21,7 +23,7 @@ func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
} }
} }
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
qType := request.Req.Question[0].Qtype qType := request.Req.Question[0].Qtype
if r.cfg.QueryTypes.Contains(dns.Type(qType)) { if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
response := new(dns.Msg) response := new(dns.Msg)
@ -30,5 +32,5 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
} }
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/log"
@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() {
sutConfig config.FilteringConfig sutConfig config.FilteringConfig
m *mockResolver m *mockResolver
mockAnswer *dns.Msg mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg) mockAnswer = new(dns.Msg)
}) })
@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() {
} }
}) })
It("Should delegate to next resolver if request query has other type", func() { It("Should delegate to next resolver if request query has other type", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() {
Expect(m.Calls).Should(HaveLen(1)) Expect(m.Calls).Should(HaveLen(1))
}) })
It("Should return empty answer for defined query type", func() { It("Should return empty answer for defined query type", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() {
sutConfig = config.FilteringConfig{} sutConfig = config.FilteringConfig{}
}) })
It("Should return empty answer without error", func() { It("Should return empty answer without error", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))). Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"strings" "strings"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -22,7 +23,7 @@ func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver {
} }
} }
func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if r.IsEnabled() { if r.IsEnabled() {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0]) domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
if !strings.Contains(domainFromQuestion, ".") { if !strings.Contains(domainFromQuestion, ".") {
@ -33,5 +34,5 @@ func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, err
} }
} }
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/log"
@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig config.FQDNOnly sutConfig config.FQDNOnly
m *mockResolver m *mockResolver
mockAnswer *dns.Msg mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg) mockAnswer = new(dns.Msg)
}) })
@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig = config.FQDNOnly{Enable: true} sutConfig = config.FQDNOnly{Enable: true}
}) })
It("Should delegate to next resolver if request query is fqdn", func() { It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))). Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
Expect(m.Calls).Should(HaveLen(1)) Expect(m.Calls).Should(HaveLen(1))
}) })
It("Should return NXDOMAIN if request query is not fqdn", func() { It("Should return NXDOMAIN if request query is not fqdn", func() {
Expect(sut.Resolve(newRequest("example", AAAA))). Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig = config.FQDNOnly{Enable: false} sutConfig = config.FQDNOnly{Enable: false}
}) })
It("Should delegate to next resolver if request query is fqdn", func() { It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))). Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
Expect(m.Calls).Should(HaveLen(1)) Expect(m.Calls).Should(HaveLen(1))
}) })
It("Should delegate to next resolver if request query is not fqdn", func() { It("Should delegate to next resolver if request query is not fqdn", func() {
Expect(sut.Resolve(newRequest("example", AAAA))). Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),

View File

@ -109,9 +109,9 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil return nil
} }
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if !r.IsEnabled() { if !r.IsEnabled() {
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
reverseResp := r.handleReverseDNS(request) reverseResp := r.handleReverseDNS(request)
@ -134,7 +134,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver") r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg { func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {

View File

@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() {
tmpFile *TmpFile tmpFile *TmpFile
err error err error
resp *Response resp *Response
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
tmpDir = NewTmpFolder("HostsFileResolver") tmpDir = NewTmpFolder("HostsFileResolver")
Expect(tmpDir.Error).Should(Succeed()) Expect(tmpDir.Error).Should(Succeed())
DeferCleanup(tmpDir.Clean) DeferCleanup(tmpDir.Clean)
@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() {
}) })
JustBeforeEach(func() { JustBeforeEach(func() {
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap) sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() {
Expect(sut.hosts.isEmpty()).Should(BeTrue()) Expect(sut.hosts.isEmpty()).Should(BeTrue())
}) })
It("should go to next resolver on query", func() { It("should go to next resolver on query", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() {
sutConfig.Sources = make([]config.BytesSource, 0) sutConfig.Sources = make([]config.BytesSource, 0)
}) })
JustBeforeEach(func() { JustBeforeEach(func() {
err = sut.loadSources(context.Background()) err = sut.loadSources(ctx)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
}) })
It("should go to next resolver on query", func() { It("should go to next resolver on query", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() {
When("IPv4 mapping is defined for a host", func() { When("IPv4 mapping is defined for a host", func() {
It("defined ipv4 query should be resolved", func() { It("defined ipv4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ipv4host.", A))). Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() {
)) ))
}) })
It("defined ipv4 query for alias should be resolved", func() { It("defined ipv4 query for alias should be resolved", func() {
Expect(sut.Resolve(newRequest("router2.", A))). Expect(sut.Resolve(ctx, newRequest("router2.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() {
)) ))
}) })
It("ipv4 query should return NOERROR and empty result", func() { It("ipv4 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("does.not.exist.", A))). Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() {
When("IPv6 mapping is defined for a host", func() { When("IPv6 mapping is defined for a host", func() {
It("defined ipv6 query should be resolved", func() { It("defined ipv6 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ipv6host.", AAAA))). Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() {
)) ))
}) })
It("ipv6 query should return NOERROR and empty result", func() { It("ipv6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))). Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the domain is not known", func() { When("the domain is not known", func() {
It("calls the next resolver", func() { It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A)) resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the question type is not handled", func() { When("the question type is not handled", func() {
It("calls the next resolver", func() { It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("localhost.", MX)) resp, err = sut.Resolve(ctx, newRequest("localhost.", MX))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() {
When("Reverse DNS request is received", func() { When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() { It("should resolve the defined domain name", func() {
By("ipv4 with one hostname", func() { By("ipv4 with one hostname", func() {
Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))). Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() {
)) ))
}) })
By("ipv4 with aliases", func() { By("ipv4 with aliases", func() {
Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))). Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() {
)) ))
}) })
By("ipv6", func() { By("ipv6", func() {
Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))). Expect(sut.Resolve(ctx,
newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)),
).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() {
}) })
It("should ignore invalid PTR", func() { It("should ignore invalid PTR", func() {
resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR)) resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() {
When("filterLoopback is true", func() { When("filterLoopback is true", func() {
It("calls the next resolver", func() { It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR)) resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the IP is not known", func() { When("the IP is not known", func() {
It("calls the next resolver", func() { It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR)) resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())
@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() {
}) })
It("resolve the defined domain name", func() { It("resolve the defined domain name", func() {
Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))). Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE), HaveResponseType(ResponseTypeHOSTSFILE),
@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() {
Describe("Delegating to next resolver", func() { Describe("Delegating to next resolver", func() {
When("no hosts file is provided", func() { When("no hosts file is provided", func() {
It("should delegate to next resolver", func() { It("should delegate to next resolver", func() {
_, err = sut.Resolve(newRequest("example.com.", A)) _, err = sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
// delegate was executed // delegate was executed
m.AssertExpectations(GinkgoT()) m.AssertExpectations(GinkgoT())

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"strings" "strings"
"time" "time"
@ -25,8 +26,8 @@ type MetricsResolver struct {
} }
// Resolve resolves the passed request // Resolve resolves the passed request
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *MetricsResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
response, err := r.next.Resolve(request) response, err := r.next.Resolve(ctx, request)
if r.cfg.Enable { if r.cfg.Enable {
r.totalQueries.With(prometheus.Labels{ r.totalQueries.With(prometheus.Labels{

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"errors" "errors"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() {
var ( var (
sut *MetricsResolver sut *MetricsResolver
m *mockResolver m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewMetricsResolver(config.MetricsConfig{Enable: true}) sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
m = &mockResolver{} m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() {
Context("Recording request metrics", func() { Context("Recording request metrics", func() {
When("Request will be performed", func() { When("Request will be performed", func() {
It("Should record metrics", func() { It("Should record metrics", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() {
sut.Next(m) sut.Next(m)
}) })
It("Error should be recorded", func() { It("Error should be recorded", func() {
_, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client")) _, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util" "github.com/0xERR0R/blocky/util"
@ -62,6 +63,21 @@ func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response
return t return t
} }
func (t *MockUDPUpstreamServer) WithDelay(delay time.Duration) *MockUDPUpstreamServer {
answerFn := t.answerFn
if answerFn == nil {
panic("WithDelay must be called after a WithAnswer function")
}
t.answerFn = func(request *dns.Msg) *dns.Msg {
time.Sleep(delay)
return answerFn(request)
}
return t
}
func (t *MockUDPUpstreamServer) GetCallCount() int { func (t *MockUDPUpstreamServer) GetCallCount() int {
return int(atomic.LoadInt32(&t.callCount)) return int(atomic.LoadInt32(&t.callCount))
} }

View File

@ -23,7 +23,7 @@ type mockResolver struct {
mock.Mock mock.Mock
NextResolver NextResolver
ResolveFn func(req *model.Request) (*model.Response, error) ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error)
ResponseFn func(req *dns.Msg) *dns.Msg ResponseFn func(req *dns.Msg) *dns.Msg
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error) AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
} }
@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) {
r.Called() r.Called()
} }
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) { func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
args := r.Called(req) args := r.Called(req)
if r.ResolveFn != nil { if r.ResolveFn != nil {
return r.ResolveFn(req) return r.ResolveFn(ctx, req)
} }
if r.ResponseFn != nil { if r.ResponseFn != nil {

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/model"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -28,6 +30,6 @@ func (NoOpResolver) IsEnabled() bool {
func (NoOpResolver) LogConfig(*logrus.Entry) { func (NoOpResolver) LogConfig(*logrus.Entry) {
} }
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) { func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
return NoResponse, nil return NoResponse, nil
} }

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
. "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@ -8,7 +10,12 @@ import (
) )
var _ = Describe("NoOpResolver", func() { var _ = Describe("NoOpResolver", func() {
var sut *NoOpResolver var (
sut *NoOpResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() { Describe("Type", func() {
It("follows conventions", func() { It("follows conventions", func() {
@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewNoOpResolver() sut = NewNoOpResolver()
}) })
Describe("Resolving", func() { Describe("Resolving", func() {
It("returns no response", func() { It("returns no response", func() {
resp, err := sut.Resolve(newRequest("test.tld", A)) resp, err := sut.Resolve(ctx, newRequest("test.tld", A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).Should(Equal(NoResponse)) Expect(resp).Should(Equal(NoResponse))
}) })

View File

@ -53,17 +53,23 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
return status return status
} }
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) { func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
resp, err := r.resolver.Resolve(req) resp, err := r.resolver.Resolve(ctx, req)
if err != nil { if err != nil {
// Ignore `Canceled`: resolver lost the race, not an error // Ignore `Canceled`: resolver lost the race, not an error
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
r.lastErrorTime.Store(time.Now()) r.lastErrorTime.Store(time.Now())
} }
err = fmt.Errorf("%s: %w", r.resolver, err) return nil, fmt.Errorf("%s: %w", r.resolver, err)
} }
return resp, nil
}
func (r *upstreamResolverStatus) resolveToChan(ctx context.Context, req *model.Request, ch chan<- requestResponse) {
resp, err := r.resolve(ctx, req)
ch <- requestResponse{ ch <- requestResponse{
resolver: &r.resolver, resolver: &r.resolver,
response: resp, response: resp,
@ -78,10 +84,10 @@ type requestResponse struct {
} }
// testResolver sends a test query to verify the resolver is reachable and working // testResolver sends a test query to verify the resolver is reachable and working
func testResolver(r *UpstreamResolver) error { func testResolver(ctx context.Context, r *UpstreamResolver) error {
request := newRequest("github.com.", dns.Type(dns.TypeA)) request := newRequest("github.com.", dns.Type(dns.TypeA))
resp, err := r.Resolve(request) resp, err := r.Resolve(ctx, request)
if err != nil || resp.RType != model.ResponseTypeRESOLVED { if err != nil || resp.RType != model.ResponseTypeRESOLVED {
return fmt.Errorf("test resolve of upstream server failed: %w", err) return fmt.Errorf("test resolve of upstream server failed: %w", err)
} }
@ -91,11 +97,11 @@ func testResolver(r *UpstreamResolver) error {
// NewParallelBestResolver creates new resolver instance // NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver( func NewParallelBestResolver(
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ParallelBestResolver, error) { ) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType) logger := log.PrefixedLog(parallelResolverType)
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams) resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -150,26 +156,18 @@ func (r *ParallelBestResolver) String() string {
} }
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result // Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType) logger := log.WithPrefix(request.Log, parallelResolverType)
if len(r.resolvers) == 1 { if len(r.resolvers) == 1 {
logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver") resolver := r.resolvers[0]
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
return r.resolvers[0].resolver.Resolve(request) return resolver.resolve(ctx, request)
} }
ctx := context.Background() ctx, cancel := context.WithCancel(ctx)
defer cancel() // abort requests to resolvers that lost the race
// using context with timeout for random upstream strategy
if r.resolverCount == 1 {
var cancel context.CancelFunc
timeout := config.GetConfig().Upstreams.Timeout
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
defer cancel()
}
resolvers := pickRandom(r.resolvers, r.resolverCount) resolvers := pickRandom(r.resolvers, r.resolverCount)
ch := make(chan requestResponse, len(resolvers)) ch := make(chan requestResponse, len(resolvers))
@ -177,10 +175,10 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
for _, resolver := range resolvers { for _, resolver := range resolvers {
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver") logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
go resolver.resolve(request, ch) go resolver.resolveToChan(ctx, request, ch)
} }
response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers) response, collectedErrors := evaluateResponses(logger, ch, resolvers)
if response != nil { if response != nil {
return response, nil return response, nil
} }
@ -189,63 +187,51 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...)) return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
} }
return r.retryWithDifferent(logger, request, resolvers) return r.retryWithDifferent(ctx, logger, request, resolvers)
} }
func evaluateResponses( func evaluateResponses(
ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
) (*model.Response, []error) { ) (*model.Response, []error) {
collectedErrors := make([]error, 0, len(resolvers)) collectedErrors := make([]error, 0, len(resolvers))
for len(collectedErrors) < len(resolvers) { for len(collectedErrors) < len(resolvers) {
select { result := <-ch
case <-ctx.Done(): logger := logger.WithField("resolver", *result.resolver)
// this context currently only has a deadline when resolverCount == 1
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
logger.WithField("resolver", resolvers[0].resolver).
Debug("upstream exceeded timeout, trying other upstream")
resolvers[0].lastErrorTime.Store(time.Now())
}
case result := <-ch:
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil if result.err != nil {
} logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
continue
} }
logger.WithField("answer", util.AnswerToString(result.response.Res.Answer)).Debug("using response from resolver")
return result.response, nil
} }
return nil, collectedErrors return nil, collectedErrors
} }
func (r *ParallelBestResolver) retryWithDifferent( func (r *ParallelBestResolver) retryWithDifferent(
logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus, ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
) (*model.Response, error) { ) (*model.Response, error) {
// second try (if retryWithDifferentResolver == true) // second try (if retryWithDifferentResolver == true)
resolver := weightedRandom(r.resolvers, resolvers) resolver := weightedRandom(r.resolvers, resolvers)
logger.Debugf("using %s as second resolver", resolver.resolver) logger.Debugf("using %s as second resolver", resolver.resolver)
ch := make(chan requestResponse, 1) resp, err := resolver.resolve(ctx, request)
if err != nil {
resolver.resolve(request, ch) return nil, fmt.Errorf("resolution retry failed: %w", err)
result := <-ch
if result.err != nil {
return nil, fmt.Errorf("resolution retry failed: %w", result.err)
} }
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"resolver": *result.resolver, "resolver": *resolver,
"answer": util.AnswerToString(result.response.Res.Answer), "answer": util.AnswerToString(resp.Res.Answer),
}).Debug("using response from resolver") }).Debug("using response from resolver")
return result.response, nil return resp, nil
} }
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool // pickRandom picks n (resolverCount) different random resolvers from the given resolver pool

View File

@ -19,6 +19,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
const ( const (
verifyUpstreams = true verifyUpstreams = true
noVerifyUpstreams = false noVerifyUpstreams = false
timeout = 50 * time.Millisecond
) )
var ( var (
@ -40,6 +42,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
BeforeEach(func() { BeforeEach(func() {
old := config.GetConfig().Upstreams.Timeout
DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old })
config.GetConfig().Upstreams.Timeout = config.Duration(timeout)
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
ctx, cancelFn = context.WithCancel(context.Background()) ctx, cancelFn = context.WithCancel(context.Background())
@ -58,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Upstreams: upstreams, Upstreams: upstreams,
} }
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify)
}) })
Describe("IsEnabled", func() { Describe("IsEnabled", func() {
@ -98,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream.Start(), mockUpstream.Start(),
} }
_, err := NewParallelBestResolver(config.UpstreamGroup{ _, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName, Name: upstreamDefaultCfgName,
Upstreams: upstreams, Upstreams: upstreams,
}, },
@ -133,6 +139,43 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
}) })
When("upstream is too slow", func() {
BeforeEach(func() {
timeoutUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.1").
WithDelay(2 * timeout)
DeferCleanup(timeoutUpstream.Close)
upstreams = []config.Upstream{timeoutUpstream.Start()}
})
When("strict checking is enabled", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("strict checking is disabled", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
})
It("should start", func() {
Expect(err).Should(Succeed())
})
It("should not resolve", func() {
Expect(err).Should(Succeed())
request := newRequest("example.com.", A)
_, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
Expect(isTimeout(err)).Should(BeTrue())
})
})
})
Describe("Resolving result from fastest upstream resolver", func() { Describe("Resolving result from fastest upstream resolver", func() {
When("2 Upstream resolvers are defined", func() { When("2 Upstream resolvers are defined", func() {
When("one resolver is fast and another is slow", func() { When("one resolver is fast and another is slow", func() {
@ -140,21 +183,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(fastTestUpstream.Close) DeferCleanup(fastTestUpstream.Close)
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { slowTestUpstream := NewMockUDPUpstreamServer().
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123") WithAnswerRR("example.com 123 IN A 123.124.122.123").
time.Sleep(50 * time.Millisecond) WithDelay(timeout / 2)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(slowTestUpstream.Close) DeferCleanup(slowTestUpstream.Close)
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()} upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
}) })
It("Should use result from fastest one", func() { It("Should use result from fastest one", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -167,21 +205,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
When("one resolver is slow, but another returns an error", func() { When("one resolver is slow, but another returns an error", func() {
var slowTestUpstream *MockUDPUpstreamServer var slowTestUpstream *MockUDPUpstreamServer
BeforeEach(func() { BeforeEach(func() {
slowTestUpstream = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { slowTestUpstream = NewMockUDPUpstreamServer().
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123") WithAnswerRR("example.com 123 IN A 123.124.122.123").
time.Sleep(50 * time.Millisecond) WithDelay(timeout / 2)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(slowTestUpstream.Close) DeferCleanup(slowTestUpstream.Close)
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()} upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
}) })
It("Should use result from successful resolver", func() { It("Should use result from successful resolver", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"), BeDNSRecord("example.com.", A, "123.124.122.123"),
@ -202,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should return error", func() { It("Should return error", func() {
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
_, err = sut.Resolve(request) _, err = sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
}) })
@ -218,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should use result from defined resolver", func() { It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -242,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close) DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamGroup{ sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName, Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}, },
@ -268,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
By("perform 100 request, error upstream's weight will be reduced", func() { By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
_, _ = sut.Resolve(request) _, _ = sut.Resolve(ctx, request)
} }
}) })
@ -302,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() { It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(config.UpstreamGroup{ r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: "test", Name: "test",
Upstreams: []config.Upstream{{Host: "example.com"}}, Upstreams: []config.Upstream{{Host: "example.com"}},
}, b, verifyUpstreams) }, b, verifyUpstreams)
@ -313,11 +346,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
Describe("random resolver strategy", func() { Describe("random resolver strategy", func() {
const timeout = config.Duration(time.Second)
BeforeEach(func() { BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
config.GetConfig().Upstreams.Timeout = timeout
}) })
Describe("Name", func() { Describe("Name", func() {
@ -342,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
It("Should return result from either one", func() { It("Should return result from either one", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should(SatisfyAll( Should(SatisfyAll(
HaveTTL(BeNumerically("==", 123)), HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -356,24 +386,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
When("one upstream exceeds timeout", func() { When("one upstream exceeds timeout", func() {
BeforeEach(func() { BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { timeoutUpstream := NewMockUDPUpstreamServer().
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") WithAnswerRR("example.com 123 IN A 123.124.122.1").
time.Sleep(time.Duration(timeout) + 2*time.Second) WithDelay(2 * timeout)
DeferCleanup(timeoutUpstream.Close)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2") testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close) DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()} upstreams = []config.Upstream{timeoutUpstream.Start(), testUpstream2.Start()}
}) })
It("should ask a other random upstream and return its response", func() { It("should ask another random upstream and return its response", func() {
request := newRequest("example.com", A) request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should( Expect(sut.Resolve(ctx, request)).Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"), BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)), HaveTTL(BeNumerically("==", 123)),
@ -382,57 +407,25 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
)) ))
}) })
}) })
When("two upstreams exceed timeout", func() { When("all upstreams exceed timeout", func() {
BeforeEach(func() { BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { testUpstream1 := NewMockUDPUpstreamServer().
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1") WithAnswerRR("example.com 123 IN A 123.124.122.1").
time.Sleep(timeout.ToDuration() + 2*time.Second) WithDelay(2 * timeout)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close) DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { testUpstream2 := NewMockUDPUpstreamServer().
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2") WithAnswerRR("example.com 123 IN A 123.124.122.2").
time.Sleep(timeout.ToDuration() + 2*time.Second) WithDelay(2 * timeout)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream2.Close) DeferCleanup(testUpstream2.Close)
testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3") upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
DeferCleanup(testUpstream3.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()}
}) })
// These two tests are flaky -_- (maybe recreate the RandomResolver ) It("Should return error", func() {
It("should not return error (due to random selection the request could to through)", func() { request := newRequest("example.com.", A)
Eventually(func() error { _, err := sut.Resolve(ctx, request)
request := newRequest("example.com", A) Expect(err).Should(HaveOccurred())
_, err := sut.Resolve(request) Expect(isTimeout(err)).Should(BeTrue())
return err
}).WithTimeout(30 * time.Second).
Should(Not(HaveOccurred()))
})
It("should return error (because it can be possible that the two broken upstreams are chosen)", func() {
Eventually(func() error {
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
return err
}).WithTimeout(30 * time.Second).
Should(HaveOccurred())
}) })
}) })
}) })
@ -446,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
}) })
It("Should return error", func() { It("Should return error", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
_, err := sut.Resolve(request) _, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
}) })
}) })
@ -461,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should use result from defined resolver", func() { It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -485,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close) DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamGroup{ sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName, Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
}, },
@ -506,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
By("perform 100 request, error upstream's weight will be reduced", func() { By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
_, _ = sut.Resolve(request) _, _ = sut.Resolve(ctx, request)
} }
}) })

View File

@ -111,12 +111,12 @@ func (r *QueryLoggingResolver) doCleanUp() {
} }
// Resolve logs the query, duration and the result // Resolve logs the query, duration and the result
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, queryLoggingResolverType) logger := log.WithPrefix(request.Log, queryLoggingResolverType)
start := time.Now() start := time.Now()
resp, err := r.next.Resolve(request) resp, err := r.next.Resolve(ctx, request)
duration := time.Since(start).Milliseconds() duration := time.Since(start).Milliseconds()

View File

@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() {
m *mockResolver m *mockResolver
tmpDir *TmpFolder tmpDir *TmpFolder
mockAnswer *dns.Msg mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg) mockAnswer = new(dns.Msg)
tmpDir = NewTmpFolder("queryLoggingResolver") tmpDir = NewTmpFolder("queryLoggingResolver")
Expect(tmpDir.Error).Should(Succeed()) Expect(tmpDir.Error).Should(Succeed())
@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig.SetDefaults() // not called when using a struct literal sutConfig.SetDefaults() // not called when using a struct literal
} }
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewQueryLoggingResolver(ctx, sutConfig) sut = NewQueryLoggingResolver(ctx, sutConfig)
m = &mockResolver{} m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil) m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() {
} }
}) })
It("should process request without query logging", func() { It("should process request without query logging", func() {
Expect(sut.Resolve(newRequest("example.com", A))). Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}) })
It("should create a log file per client", func() { It("should create a log file per client", func() {
By("request from client 1", func() { By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() {
)) ))
}) })
By("request from client 2, has name with special chars, should be escaped", func() { By("request from client 2, has name with special chars, should be escaped", func() {
Expect(sut.Resolve(newRequestWithClient( Expect(sut.Resolve(ctx, newRequestWithClient(
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))). "example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
Should( Should(
SatisfyAll( SatisfyAll(
@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}) })
It("should create one log file for all clients", func() { It("should create one log file for all clients", func() {
By("request from client 1", func() { By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() {
)) ))
}) })
By("request from client 2, has name with special chars, should be escaped", func() { By("request from client 2, has name with special chars, should be escaped", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}) })
It("should create one log file", func() { It("should create one log file", func() {
By("request from client 1", func() { By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED), HaveResponseType(ResponseTypeRESOLVED),
@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() {
sut.writer = mockWriter sut.writer = mockWriter
Eventually(func() int { Eventually(func() int {
_, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1")) _, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
Expect(ierr).Should(Succeed()) Expect(ierr).Should(Succeed())
return len(sut.logChan) return len(sut.logChan)

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"time" "time"
@ -74,7 +75,7 @@ type Resolver interface {
Type() string Type() string
// Resolve performs resolution of a DNS request // Resolve performs resolution of a DNS request
Resolve(req *model.Request) (*model.Response, error) Resolve(ctx context.Context, req *model.Request) (*model.Response, error)
} }
// ChainedResolver represents a resolver, which can delegate result to the next one // ChainedResolver represents a resolver, which can delegate result to the next one
@ -216,13 +217,14 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
} }
func createResolvers( func createResolvers(
logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool, ctx context.Context, logger *logrus.Entry,
cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
) ([]Resolver, error) { ) ([]Resolver, error) {
resolvers := make([]Resolver, 0, len(cfg.Upstreams)) resolvers := make([]Resolver, 0, len(cfg.Upstreams))
hasValidResolvers := false hasValidResolvers := false
for _, u := range cfg.Upstreams { for _, u := range cfg.Upstreams {
resolver, err := NewUpstreamResolver(u, bootstrap, shoudVerifyUpstreams) resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams)
if err != nil { if err != nil {
logger.Warnf("upstream group %s: %v", cfg.Name, err) logger.Warnf("upstream group %s: %v", cfg.Name, err)
@ -230,7 +232,7 @@ func createResolvers(
} }
if shoudVerifyUpstreams { if shoudVerifyUpstreams {
err = testResolver(resolver) err = testResolver(ctx, resolver)
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} else { } else {

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -57,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
} }
// Resolve uses the inner resolver to resolve the rewritten query // Resolve uses the inner resolver to resolve the rewritten query
func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "rewriter_resolver") logger := log.WithPrefix(request.Log, "rewriter_resolver")
original := request.Req original := request.Req
@ -69,7 +70,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver") logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
response, err := r.inner.Resolve(request) response, err := r.inner.Resolve(ctx, request)
// Test for error after checking for fallbackUpstream // Test for error after checking for fallbackUpstream
// Revert the request: must be done before calling r.next // Revert the request: must be done before calling r.next
@ -80,7 +81,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain // Inner resolver had no answer, configuration requests fallback, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver") logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
if err != nil { if err != nil {
@ -91,7 +92,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
// Inner resolver had no response, continue with the normal chain // Inner resolver had no response, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver") logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
// Revert the rewrite in r.inner's response // Revert the rewrite in r.inner's response

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/model"
@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() {
return res return res
} }
resp, err := sut.Resolve(request) resp, err := sut.Resolve(context.Background(), request)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
if resp != mNextResponse { if resp != mNextResponse {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal)) Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() {
expectNilAnswer = true expectNilAnswer = true
// Make inner call the NoOpResolver // Make inner call the NoOpResolver
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request)) Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten // Inner should see fqdnRewritten
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten)) Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
return mInner.next.Resolve(req) return mInner.next.Resolve(ctx, req)
} }
// Resolver after RewriterResolver should see `fqdnOriginal` // Resolver after RewriterResolver should see `fqdnOriginal`
mNext.On("Resolve", mock.Anything) mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil return mNextResponse, nil
@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() {
expectNilAnswer = true expectNilAnswer = true
// Make inner return a nil Answer but not an empty Response // Make inner return a nil Answer but not an empty Response
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request)) Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten // Inner should see fqdnRewritten
@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() {
fqdnRewritten = sampleRewritten fqdnRewritten = sampleRewritten
// Make inner return a nil Answer but not an empty Response // Make inner return a nil Answer but not an empty Response
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request)) Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten // Inner should see fqdnRewritten
@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() {
// Resolver after RewriterResolver should see `fqdnOriginal` // Resolver after RewriterResolver should see `fqdnOriginal`
mNext.On("Resolve", mock.Anything) mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil return mNextResponse, nil

View File

@ -30,11 +30,11 @@ type StrictResolver struct {
// NewStrictResolver creates a new strict resolver instance // NewStrictResolver creates a new strict resolver instance
func NewStrictResolver( func NewStrictResolver(
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*StrictResolver, error) { ) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType) logger := log.PrefixedLog(strictResolverType)
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams) resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -76,44 +76,27 @@ func (r *StrictResolver) String() string {
} }
// Resolve sends the query request in a strict order to the upstream resolvers // Resolve sends the query request in a strict order to the upstream resolvers
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, strictResolverType) logger := log.WithPrefix(request.Log, strictResolverType)
// start with first resolver // start with first resolver
for i := range r.resolvers { for _, resolver := range r.resolvers {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
resolver := r.resolvers[i]
logger.Debugf("using %s as resolver", resolver.resolver) logger.Debugf("using %s as resolver", resolver.resolver)
ch := make(chan requestResponse, 1) resp, err := resolver.resolve(ctx, request)
if err != nil {
go resolver.resolve(request, ch) // log error and try next upstream
logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err)
select {
case <-ctx.Done():
// log debug/info that timeout exceeded, call `continue` to try next upstream
logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
continue continue
case result := <-ch:
if result.err != nil {
// log error & call `continue` to try next upstream
logger.Debug("resolution failed from resolver, cause: ", result.err)
continue
}
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
} }
logger.WithFields(logrus.Fields{
"resolver": *resolver,
"answer": util.AnswerToString(resp.Res.Answer),
}).Debug("using response from resolver")
return resp, nil
} }
return nil, errors.New("resolution was not successful, no resolver returned an answer in time") return nil, errors.New("resolution was not successful, no resolver returned an answer in time")

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"time" "time"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
err error err error
bootstrap *Bootstrap bootstrap *Bootstrap
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
upstreams = []config.Upstream{ upstreams = []config.Upstream{
{Host: "wrong"}, {Host: "wrong"},
{Host: "127.0.0.2"}, {Host: "127.0.0.2"},
@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
Name: upstreamDefaultCfgName, Name: upstreamDefaultCfgName,
Upstreams: upstreams, Upstreams: upstreams,
} }
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify) sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify)
}) })
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second) config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
mockUpstream.Start(), mockUpstream.Start(),
} }
_, err := NewStrictResolver(config.UpstreamGroup{ _, err := NewStrictResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName, Name: upstreamDefaultCfgName,
Upstreams: upstreams, Upstreams: upstreams,
}, },
@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
It("Should use result from first one", func() { It("Should use result from first one", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
It("should return response from next upstream", func() { It("should return response from next upstream", func() {
request := newRequest("example.com", A) request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should( Expect(sut.Resolve(ctx, request)).Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"), BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)), HaveTTL(BeNumerically("==", 123)),
@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
It("should return error", func() { It("should return error", func() {
request := newRequest("example.com", A) request := newRequest("example.com", A)
_, err := sut.Resolve(request) _, err := sut.Resolve(ctx, request)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })
@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
It("Should use result from second one", func() { It("Should use result from second one", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"), BeDNSRecord("example.com.", A, "123.124.122.123"),
@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
}) })
It("Should return error", func() { It("Should return error", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
_, err = sut.Resolve(request) _, err = sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
}) })
}) })
@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
It("Should use result from defined resolver", func() { It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A) request := newRequest("example.com.", A)
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"net" "net"
"strings" "strings"
@ -99,7 +100,7 @@ func NewSpecialUseDomainNamesResolver(cfg config.SUDN) *SpecialUseDomainNamesRes
} }
} }
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *SpecialUseDomainNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
handler := r.handler(request) handler := r.handler(request)
if handler != nil { if handler != nil {
resp := handler(request, r.cfg) resp := handler(request, r.cfg)
@ -108,7 +109,7 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
} }
} }
return r.next.Resolve(request) return r.next.Resolve(ctx, request)
} }
func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler { func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler {

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
sut *SpecialUseDomainNamesResolver sut *SpecialUseDomainNamesResolver
sutConfig config.SUDN sutConfig config.SUDN
m *mockResolver m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
) )
Describe("Type", func() { Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig, err = config.WithDefaults[config.SUDN]() sutConfig, err = config.WithDefaults[config.SUDN]()
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
}) })
@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
Describe("handlers", func() { Describe("handlers", func() {
It("should have correct response type", func() { It("should have correct response type", func() {
for domain, handler := range sudnHandlers { for domain, handler := range sudnHandlers {
resp, err := sut.Resolve(newRequest(domain, A)) resp, err := sut.Resolve(ctx, newRequest(domain, A))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
if handler == nil { if handler == nil {
@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
DescribeTable("handled domains", DescribeTable("handled domains",
func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) { func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) {
resp, err := sut.Resolve(newRequest(qName, qType)) resp, err := sut.Resolve(ctx, newRequest(qName, qType))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).Should(SatisfyAll( Expect(resp).Should(SatisfyAll(
HaveResponseType(ResponseTypeSPECIAL), HaveResponseType(ResponseTypeSPECIAL),
@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
DescribeTable("", DescribeTable("",
func(qType dns.Type, qName string, expectedRCode int) { func(qType dns.Type, qName string, expectedRCode int) {
resp, err := sut.Resolve(newRequest(qName, qType)) resp, err := sut.Resolve(ctx, newRequest(qName, qType))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).Should(HaveReturnCode(expectedRCode)) Expect(resp).Should(HaveReturnCode(expectedRCode))
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
}) })
It("should forward example.com", func() { It("should forward example.com", func() {
Expect(sut.Resolve(newRequest("example.com", A))). Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.145.123.145"), BeDNSRecord("example.com.", A, "123.145.123.145"),
@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
}) })
It("should forward home.arpa. IN DS", func() { It("should forward home.arpa. IN DS", func() {
Expect(sut.Resolve(newRequest("something.home.arpa.", DS))). Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))).
Should( Should(
SatisfyAll( SatisfyAll(
// setup code doesn't care about the question // setup code doesn't care about the question
@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
}) })
It("should forward non special use domains", func() { It("should forward non special use domains", func() {
resp, err := sut.Resolve(newRequest("something.not-special.", AAAA)) resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA))
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
}) })

View File

@ -2,6 +2,7 @@ package resolver
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -39,8 +40,9 @@ type UpstreamResolver struct {
type upstreamClient interface { type upstreamClient interface {
fmtURL(ip net.IP, port uint16, path string) string fmtURL(ip net.IP, port uint16, path string) string
callExternal(msg *dns.Msg, upstreamURL string, callExternal(
protocol model.RequestProtocol) (response *dns.Msg, rtt time.Duration, err error) ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error)
} }
type dnsUpstreamClient struct { type dnsUpstreamClient struct {
@ -53,8 +55,6 @@ type httpUpstreamClient struct {
} }
func createUpstreamClient(cfg config.Upstream) upstreamClient { func createUpstreamClient(cfg config.Upstream) upstreamClient {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
tlsConfig := tls.Config{ tlsConfig := tls.Config{
ServerName: cfg.Host, ServerName: cfg.Host,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
@ -73,7 +73,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
TLSHandshakeTimeout: defaultTLSHandshakeTimeout, TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
}, },
Timeout: timeout,
}, },
host: cfg.Host, host: cfg.Host,
} }
@ -83,7 +82,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
tcpClient: &dns.Client{ tcpClient: &dns.Client{
TLSConfig: &tlsConfig, TLSConfig: &tlsConfig,
Net: cfg.Net.String(), Net: cfg.Net.String(),
Timeout: timeout,
SingleInflight: true, SingleInflight: true,
}, },
} }
@ -92,12 +90,10 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
return &dnsUpstreamClient{ return &dnsUpstreamClient{
tcpClient: &dns.Client{ tcpClient: &dns.Client{
Net: "tcp", Net: "tcp",
Timeout: timeout,
SingleInflight: true, SingleInflight: true,
}, },
udpClient: &dns.Client{ udpClient: &dns.Client{
Net: "udp", Net: "udp",
Timeout: timeout,
SingleInflight: true, SingleInflight: true,
}, },
} }
@ -112,8 +108,8 @@ func (r *httpUpstreamClient) fmtURL(ip net.IP, port uint16, path string) string
return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path) return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path)
} }
func (r *httpUpstreamClient) callExternal(msg *dns.Msg, func (r *httpUpstreamClient) callExternal(
upstreamURL string, _ model.RequestProtocol, ctx context.Context, msg *dns.Msg, upstreamURL string, _ model.RequestProtocol,
) (*dns.Msg, time.Duration, error) { ) (*dns.Msg, time.Duration, error) {
start := time.Now() start := time.Now()
@ -122,7 +118,7 @@ func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
return nil, 0, fmt.Errorf("can't pack message: %w", err) return nil, 0, fmt.Errorf("can't pack message: %w", err)
} }
req, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("can't create the new request %w", err) return nil, 0, fmt.Errorf("can't create the new request %w", err)
} }
@ -169,16 +165,16 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
return net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) return net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
} }
func (r *dnsUpstreamClient) callExternal(msg *dns.Msg, func (r *dnsUpstreamClient) callExternal(
upstreamURL string, protocol model.RequestProtocol, ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) { ) (response *dns.Msg, rtt time.Duration, err error) {
if protocol == model.RequestProtocolTCP { if protocol == model.RequestProtocolTCP {
response, rtt, err = r.tcpClient.Exchange(msg, upstreamURL) response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
if err != nil { if err != nil && r.udpClient != nil {
// try UDP as fallback // try UDP as fallback
var opErr *net.OpError var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" && r.udpClient != nil { if errors.As(err, &opErr) && opErr.Op == "dial" {
return r.udpClient.Exchange(msg, upstreamURL) return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
} }
} }
@ -186,18 +182,20 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
} }
if r.udpClient != nil { if r.udpClient != nil {
return r.udpClient.Exchange(msg, upstreamURL) return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
} }
return r.tcpClient.Exchange(msg, upstreamURL) return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
} }
// NewUpstreamResolver creates new resolver instance // NewUpstreamResolver creates new resolver instance
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) { func NewUpstreamResolver(
ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool,
) (*UpstreamResolver, error) {
r := newUpstreamResolverUnchecked(upstream, bootstrap) r := newUpstreamResolverUnchecked(upstream, bootstrap)
if verify { if verify {
_, err := r.bootstrap.UpstreamIPs(r) _, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,50 +232,41 @@ func (r UpstreamResolver) String() string {
} }
// Resolve calls external resolver // Resolve calls external resolver
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) { func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
ips, err := r.bootstrap.UpstreamIPs(r) ips, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var ( var (
rtt time.Duration
resp *dns.Msg resp *dns.Msg
ip net.IP ip net.IP
) )
err = retry.Do( err = retry.Do(
func() error { func() error {
ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration())
defer cancel()
ip = ips.Current() ip = ips.Current()
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path) upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
var err error response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol) if err != nil {
if err == nil { return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
"upstream_ip": ip.String(),
"protocol": request.Protocol,
"net": r.upstream.Net,
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
return nil
} }
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err) resp = response
r.logResponse(request, response, ip, rtt)
return nil
}, },
retry.Context(ctx),
retry.Attempts(retryAttempts), retry.Attempts(retryAttempts),
retry.DelayType(retry.FixedDelay), retry.DelayType(retry.FixedDelay),
retry.Delay(1*time.Millisecond), retry.Delay(1*time.Millisecond),
retry.LastErrorOnly(true), retry.LastErrorOnly(true),
retry.RetryIf(func(err error) bool { retry.RetryIf(isTimeout),
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}),
retry.OnRetry(func(n uint, err error) { retry.OnRetry(func(n uint, err error) {
r.log().WithFields(logrus.Fields{ r.log().WithFields(logrus.Fields{
"upstream": r.upstream.String(), "upstream": r.upstream.String(),
@ -289,8 +278,31 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
ips.Next() ips.Next()
})) }))
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
// Make the error more user friendly than just "context deadline exceeded"
err = fmt.Errorf("timeout (%w)", err)
}
return nil, err return nil, err
} }
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
} }
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
"upstream_ip": ip.String(),
"protocol": request.Protocol,
"net": r.upstream.Net,
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
}
func isTimeout(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
var ( var (
sut *UpstreamResolver sut *UpstreamResolver
sutConfig config.Upstream sutConfig config.Upstream
ctx context.Context
cancelFn context.CancelFunc
) )
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.Upstream{Host: "localhost"} sutConfig = config.Upstream{Host: "localhost"}
}) })
@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start() upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil) sut := newUpstreamResolverUnchecked(upstream, nil)
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start() upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil) sut := newUpstreamResolverUnchecked(upstream, nil)
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
HaveNoAnswer(), HaveNoAnswer(),
@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start() upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil) sut := newUpstreamResolverUnchecked(upstream, nil)
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
}) })
}) })
@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 2) atomic.StoreInt32(&attemptsWithTimeout, 2)
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -146,12 +153,37 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
By("3 attempts with timeout -> should return error", func() { By("3 attempts with timeout -> should return error", func() {
atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 3) atomic.StoreInt32(&attemptsWithTimeout, 3)
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("i/o timeout")) Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
}) })
}) })
}) })
When("user request is TCP", func() {
When("TCP upstream connection fails", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutConfig = mockUpstream.Start()
})
It("should retry with UDP", func() {
req := newRequest("example.com.", A)
req.Protocol = RequestProtocolTCP
Expect(sut.Resolve(ctx, req)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
))
})
})
})
}) })
Describe("Using Dns over HTTP (DOH) upstream", func() { Describe("Using Dns over HTTP (DOH) upstream", func() {
@ -185,7 +217,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}) })
When("Configured DOH resolver can resolve query", func() { When("Configured DOH resolver can resolve query", func() {
It("should return answer from DNS upstream", func() { It("should return answer from DNS upstream", func() {
Expect(sut.Resolve(newRequest("example.com.", A))). Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"), BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -203,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
} }
}) })
It("should return error", func() { It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500")) Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
}) })
@ -215,7 +247,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
} }
}) })
It("should return error", func() { It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should( Expect(err.Error()).Should(
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'")) ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
@ -228,7 +260,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
} }
}) })
It("should return error", func() { It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("can't unpack message")) Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
}) })
@ -241,7 +273,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}, systemResolverBootstrap) }, systemResolverBootstrap)
}) })
It("should return error", func() { It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A)) _, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred()) Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(Or( Expect(err.Error()).Should(Or(
ContainSubstring("no such host"), ContainSubstring("no such host"),

View File

@ -1,6 +1,7 @@
package resolver package resolver
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
@ -64,7 +65,7 @@ func (r *UpstreamTreeResolver) String() string {
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", ")) return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
} }
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) { func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, upstreamTreeResolverType) logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
group := r.upstreamGroupByClient(request) group := r.upstreamGroupByClient(request)
@ -72,7 +73,7 @@ func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response,
// delegate request to group resolver // delegate request to group resolver
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver") logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
return r.branches[group].Resolve(request) return r.branches[group].Resolve(ctx, request)
} }
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string { func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {

View File

@ -1,6 +1,8 @@
package resolver package resolver
import ( import (
"context"
"github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest" . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/log"
@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
}) })
When("client specific resolvers are defined", func() { When("client specific resolvers are defined", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() { BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{ sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{}}, upstreamDefaultCfgName: {config.Upstream{}},
"laptop": {config.Upstream{}}, "laptop": {config.Upstream{}},
@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use default if client name or IP don't match", func() { It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "default"), BeDNSRecord("example.com.", A, "default"),
@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches exact", func() { It("Should use client specific resolver if client name matches exact", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop") request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"), BeDNSRecord("example.com.", A, "laptop"),
@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches with wildcard", func() { It("Should use client specific resolver if client name matches with wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m") request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "client-*-m"), BeDNSRecord("example.com.", A, "client-*-m"),
@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches with range wildcard", func() { It("Should use client specific resolver if client name matches with range wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7") request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "client[0-9]"), BeDNSRecord("example.com.", A, "client[0-9]"),
@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client IP matches", func() { It("Should use client specific resolver if client IP matches", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname") request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"), BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name (containing IP) matches", func() { It("Should use client specific resolver if client name (containing IP) matches", func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33") request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"), BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() { It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname") request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "10.43.8.67/28"), BeDNSRecord("example.com.", A, "10.43.8.67/28"),
@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use exact IP match before client name match", func() { It("Should use exact IP match before client name match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop") request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"), BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client name match before CIDR match", func() { It("Should use client name match before CIDR match", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop") request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"), BeDNSRecord("example.com.", A, "laptop"),
@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1") request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
request.Log = logger request.Log = logger
Expect(sut.Resolve(request)). Expect(sut.Resolve(ctx, request)).
Should( Should(
SatisfyAll( SatisfyAll(
SatisfyAny( SatisfyAny(

View File

@ -389,7 +389,7 @@ func createQueryResolver(
bootstrap *resolver.Bootstrap, bootstrap *resolver.Bootstrap,
redisClient *redis.Client, redisClient *redis.Client,
) (r resolver.ChainedResolver, err error) { ) (r resolver.ChainedResolver, err error) {
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap) upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap)
if uErr != nil { if uErr != nil {
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr) return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
} }
@ -398,7 +398,9 @@ func createQueryResolver(
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap) blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream) condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(
ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream,
)
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap) hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
err = multierror.Append( err = multierror.Append(
@ -433,6 +435,7 @@ func createQueryResolver(
} }
func createUpstreamBranches( func createUpstreamBranches(
ctx context.Context,
cfg *config.Config, cfg *config.Config,
bootstrap *resolver.Bootstrap, bootstrap *resolver.Bootstrap,
) (map[string]resolver.Resolver, error) { ) (map[string]resolver.Resolver, error) {
@ -453,11 +456,11 @@ func createUpstreamBranches(
switch cfg.Upstreams.Strategy { switch cfg.Upstreams.Strategy {
case config.UpstreamStrategyParallelBest: case config.UpstreamStrategyParallelBest:
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyStrict: case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyRandom: case config.UpstreamStrategyRandom:
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream) upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
} }
upstreamBranches[group] = upstream upstreamBranches[group] = upstream
@ -643,7 +646,7 @@ func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
r := createResolverRequest(w, request) r := createResolverRequest(w, request)
response, err := s.queryResolver.Resolve(r) response, err := s.queryResolver.Resolve(context.Background(), r)
if err != nil { if err != nil {
logger().Error("error on processing request:", err) logger().Error("error on processing request:", err)

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"html/template" "html/template"
@ -149,7 +150,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg) r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
resResponse, err := s.queryResolver.Resolve(r) resResponse, err := s.queryResolver.Resolve(req.Context(), r)
if err != nil { if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw) logAndResponseWithError(err, "unable to process query: ", rw)
@ -192,11 +193,11 @@ func extractIP(r *http.Request) string {
return hostPort return hostPort
} }
func (s *Server) Query(question string, qType dns.Type) (*model.Response, error) { func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType) dnsRequest := util.NewMsgWithQuestion(question, qType)
r := createResolverRequest(nil, dnsRequest) r := createResolverRequest(nil, dnsRequest)
return s.queryResolver.Resolve(r) return s.queryResolver.Resolve(ctx, r)
} }
func createHTTPSRouter(cfg *config.Config) *chi.Mux { func createHTTPSRouter(cfg *config.Config) *chi.Mux {

View File

@ -76,10 +76,9 @@ var _ = BeforeSuite(func() {
clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
var clientName string var clientName string
client := mockClientName.Load()
if client != nil { if name, ok := mockClientName.Load().(string); ok {
clientName = mockClientName.Load().(string) clientName = name
} }
response, err := util.NewMsgWithAnswer( response, err := util.NewMsgWithAnswer(
@ -118,8 +117,7 @@ var _ = BeforeSuite(func() {
youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com") youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com")
Expect(youtubeFile.Error).Should(Succeed()) Expect(youtubeFile.Error).Should(Succeed())
// create server cfg := &config.Config{
sut, err = NewServer(ctx, &config.Config{
CustomDNS: config.CustomDNSConfig{ CustomDNS: config.CustomDNSConfig{
CustomTTL: config.Duration(3600 * time.Second), CustomTTL: config.Duration(3600 * time.Second),
Mapping: config.CustomDNSMapping{ Mapping: config.CustomDNSMapping{
@ -160,7 +158,8 @@ var _ = BeforeSuite(func() {
BlockTTL: config.Duration(6 * time.Hour), BlockTTL: config.Duration(6 * time.Hour),
}, },
Upstreams: config.UpstreamsConfig{ Upstreams: config.UpstreamsConfig{
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}}, Timeout: config.Duration(250 * time.Millisecond),
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
}, },
ClientLookup: config.ClientLookupConfig{ ClientLookup: config.ClientLookupConfig{
Upstream: upstreamClient, Upstream: upstreamClient,
@ -178,16 +177,19 @@ var _ = BeforeSuite(func() {
Enable: true, Enable: true,
Path: "/metrics", Path: "/metrics",
}, },
}) }
// Hacky but needed to update the global since we still have code that reads it
*config.GetConfig() = *cfg
// create server
sut, err = NewServer(ctx, cfg)
Expect(err).Should(Succeed()) Expect(err).Should(Succeed())
errChan := make(chan error, 10) errChan := make(chan error, 10)
// start server // start server
go func() { go sut.Start(ctx, errChan)
sut.Start(ctx, errChan)
}()
DeferCleanup(sut.Stop) DeferCleanup(sut.Stop)
Consistently(errChan, "1s").ShouldNot(Receive()) Consistently(errChan, "1s").ShouldNot(Receive())
@ -697,7 +699,7 @@ var _ = Describe("Running DNS server", func() {
Describe("NewServer with strict upstream strategy", func() { Describe("NewServer with strict upstream strategy", func() {
It("successfully returns upstream branches", func() { It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{ branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.UpstreamsConfig{ Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyStrict, Strategy: config.UpstreamStrategyStrict,
Groups: config.UpstreamGroups{ Groups: config.UpstreamGroups{
@ -715,7 +717,7 @@ var _ = Describe("Running DNS server", func() {
Describe("NewServer with random upstream strategy", func() { Describe("NewServer with random upstream strategy", func() {
It("successfully returns upstream branches", func() { It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{ branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.UpstreamsConfig{ Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyRandom, Strategy: config.UpstreamStrategyRandom,
Groups: config.UpstreamGroups{ Groups: config.UpstreamGroups{