refactor: make use of contexts in more places

- `CacheControl.FlushCaches`
- `Querier.Query`
- `Resolver.Resolve`

Besides all the API churn, this leads to `ParallelBestResolver`,
`StrictResolver` and `UpstreamResolver` simplification: timeouts only
need to be setup in one place, `UpstreamResolver`.

We also benefit from using HTTP request contexts, so if the client
closes the connection we stop processing on our side.
This commit is contained in:
ThinkChaos 2023-11-19 15:47:50 -05:00
parent e4ebc16ccc
commit eae99ec550
52 changed files with 796 additions and 627 deletions

View File

@ -40,11 +40,11 @@ type ListRefresher interface {
}
type Querier interface {
Query(question string, qType dns.Type) (*model.Response, error)
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
}
type CacheControl interface {
FlushCaches()
FlushCaches(ctx context.Context)
}
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
return ListRefresh200Response{}, nil
}
func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) {
func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
qType := dns.Type(dns.StringToType[request.Body.Type])
if qType == dns.Type(dns.TypeNone) {
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
}
resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType)
resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
if err != nil {
return nil, err
}
@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
}), nil
}
func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
_ CacheFlushRequestObject,
) (CacheFlushResponseObject, error) {
i.cacheControl.FlushCaches()
i.cacheControl.FlushCaches(ctx)
return CacheFlush200Response{}, nil
}

View File

@ -5,7 +5,6 @@ import (
"errors"
"time"
// . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus)
}
func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
args := m.Called(question, qType)
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(ctx, question, qType)
return args.Get(0).(*model.Response), args.Error(1)
}
func (m *CacheControlMock) FlushCaches() {
_ = m.Called()
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
_ = m.Called(ctx)
}
var _ = Describe("API implementation tests", func() {
@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{}
@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
)
Expect(err).Should(Succeed())
querierMock.On("Query", "google.com.", A).Return(&model.Response{
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
Res: queryResponse,
Reason: "reason",
}, nil)
resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com", Type: "A",
},
@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
})
It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com",
Type: "WRONGTYPE",
@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
listRefreshMock.On("RefreshLists").Return(nil)
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp200 ListRefresh200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 500 on failure", func() {
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp500 ListRefresh500TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp500))
@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
duration := "3s"
grroups := "gr1,gr2"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &duration,
Groups: &grroups,
@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 400 on failure", func() {
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp400 DisableBlocking400TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp400))
@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 400 on wrong duration parameter", func() {
wrongDuration := "4sds"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &wrongDuration,
},
@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
blockingControlMock.On("EnableBlocking").Return()
resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp200 EnableBlocking200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
AutoEnableInSec: 47,
})
resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
Expect(err).Should(Succeed())
var resp200 BlockingStatus200JSONResponse
Expect(resp).Should(BeAssignableToTypeOf(resp200))
@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
Describe("Cache API", func() {
When("Cache flush is called", func() {
It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
cacheControlMock.On("FlushCaches", ctx).Return()
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
Expect(err).Should(Succeed())
var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))

View File

@ -37,7 +37,7 @@ type Options struct {
// OnExpirationCallback will be called just before an element gets expired and will
// be removed from cache. This function can return new value and TTL to leave the
// element in the cache or nil to remove it
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)
// OnCacheHitCallback will be called on cache get if entry was found
type OnCacheHitCallback func(key string)
@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
return nil, 0
},
onCacheHit: func(key string) {},
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
var keysToDelete []string
for _, key := range expiredKeys {
newVal, newTTL := e.preExpirationFn(key)
newVal, newTTL := e.preExpirationFn(context.Background(), key)
if newVal != nil {
e.Put(key, newVal, newTTL)
} else {

View File

@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
Describe("preExpiration function", func() {
When("function is defined", func() {
It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "v2"
return &v2, time.Second
@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
})
It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "val2"
return &v2, time.Second
@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
})
It("should delete the key if function returns nil", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)

View File

@ -25,11 +25,11 @@ type cacheValue[T any] struct {
type OnEntryReloadedCallback func(key string)
// ReloadEntryFn reloads a prefetched entry by key
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)
type PrefetchingOptions[T any] struct {
Options
ReloadFn func(cacheKey string) (*T, time.Duration)
ReloadFn ReloadEntryFn[T]
PrefetchThreshold int
PrefetchExpires time.Duration
PrefetchMaxItemsCount int
@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
ctx context.Context, cacheKey string,
) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey)
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey)

View File

@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond
@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"
return &v, 50 * time.Millisecond

View File

@ -154,16 +154,16 @@ func NewBlockingResolver(ctx context.Context,
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
CleanupInterval: defaultBlockingCleanUpInterval,
}, func(key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(key)
}, func(ctx context.Context, key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(ctx, key)
})
if res.redisClient != nil {
setupRedisEnabledSubscriber(ctx, res)
go res.redisSubscriber(ctx)
}
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
go res.initFQDNIPCache()
go res.initFQDNIPCache(ctx)
})
if err != nil {
@ -173,29 +173,27 @@ func NewBlockingResolver(ctx context.Context,
return res, nil
}
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) {
go func() {
for {
select {
case em := <-c.redisClient.EnabledChannel:
if em != nil {
c.log().Debug("Received state from redis: ", em)
func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
for {
select {
case em := <-r.redisClient.EnabledChannel:
if em != nil {
r.log().Debug("Received state from redis: ", em)
if em.State {
c.internalEnableBlocking()
} else {
err := c.internalDisableBlocking(em.Duration, em.Groups)
if err != nil {
c.log().Warn("Blocking couldn't be disabled:", err)
}
if em.State {
r.internalEnableBlocking()
} else {
err := r.internalDisableBlocking(em.Duration, em.Groups)
if err != nil {
r.log().Warn("Blocking couldn't be disabled:", err)
}
}
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}()
}
}
// RefreshLists triggers the refresh of all black and white lists in the cache
@ -358,7 +356,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
return false
}
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []string,
request *model.Request, logger *logrus.Entry,
) (bool, *model.Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
@ -371,7 +369,7 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
logger.WithField("groups", groups).Debugf("domain is whitelisted")
resp, err := r.next.Resolve(request)
resp, err := r.next.Resolve(ctx, request)
return true, resp, err
}
@ -393,18 +391,18 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
}
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "blacklist_resolver")
groupsToCheck := r.groupsToCheckForClient(request)
if len(groupsToCheck) > 0 {
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
handled, resp, err := r.handleBlacklist(ctx, groupsToCheck, request, logger)
if handled {
return resp, err
}
}
respFromNext, err := r.next.Resolve(request)
respFromNext, err := r.next.Resolve(ctx, request)
if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
for _, rr := range respFromNext.Res.Answer {
@ -574,7 +572,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
}
}
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP, time.Duration) {
func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
var result []net.IP
@ -582,7 +580,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
var ttl time.Duration
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(&model.Request{
resp, err := r.next.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
Log: prefixedLog,
})
@ -606,12 +604,12 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
return &result, ttl
}
func (r *BlockingResolver) initFQDNIPCache() {
func (r *BlockingResolver) initFQDNIPCache(ctx context.Context) {
identifiers := maps.Keys(r.clientGroupsBlock)
for _, identifier := range identifiers {
if isFQDN(identifier) {
iPs, ttl := r.queryForFQIdentifierIPs(identifier)
iPs, ttl := r.queryForFQIdentifierIPs(ctx, identifier)
r.fqdnIPCache.Put(identifier, iPs, ttl)
}
}

View File

@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}
Bus().Publish(ApplicationStarted, "")
Eventually(func(g Gomega) {
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
Should(And(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 60)),
@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Domain is on the black list", func() {
It("should block request", func() {
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
Should(
SatisfyAll(
@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("client name is defined in client groups block", func() {
It("should block the A query if domain is on the black list (single)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the A query if domain is on the black list (multipart 1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the A query if domain is on the black list (multipart 2)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the A query if domain is on the black list (merged)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
Should(
SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the AAAA query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", AAAA, "::"),
@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the HTTPS query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
Should(HaveReturnCode(dns.RcodeNameError))
})
It("should block the MX query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
Should(HaveReturnCode(dns.RcodeNameError))
})
})
When("Client ip is defined in client groups block", func() {
It("should block the query if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
m.AssertExpectations(GinkgoT())
})
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
It("should block the query for 10.43.8.64 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block the query for 10.43.8.79 if domain is on the black list", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Client has multiple names and for each name a client group block definition exists", func() {
It("should block query if domain is in one group", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
It("should block query if domain is in another group too", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
Should(
SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("Client name matches wildcard", func() {
It("should block query if domain is in one group", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Default group is defined", func() {
It("should block domains from default group for each client", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should return NXDOMAIN if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should return answer with specified TTL", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should return custom IP with specified TTL", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should return ipv4 address for A query if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should return ipv6 address for AAAA query if query is blocked", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should use fallback for ipv6 and return zero ip", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", AAAA, "::"),
@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
})
It("should block query, if lookup result contains blacklisted IP", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "0.0.0.0"),
@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
)
})
It("should block query, if lookup result contains blacklisted IP", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", AAAA, "::"),
@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
})
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "0.0.0.0"),
@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}
})
It("Should not be blocked", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should block everything else except domains on the white list with default group", func() {
By("querying domain on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("querying another domain, which is not on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("google.com.", A, "0.0.0.0"),
@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("querying another domain, which is not on the whitelist", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
Should(
SatisfyAll(
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist group 1", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("querying another domain, which is in the whitelist group 1", func() {
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
})
It("should not block if DNS answer contains IP from the white list", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.145.123.145"),
@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("domain is not on the black list", func() {
It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}
})
It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking is called", func() {
It("no query should be blocked", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (defaultGroup)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (group1)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again (defaultGroup)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for all groups is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
// wait 1 sec
Eventually(enabled, "1s").Should(Receive(BeTrue()))
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for one group is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() {
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
))
})
By("Perform query to ensure that the blocking status is active (group1)", func() {
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
// now is blocking disabled, query the url again
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
// wait 1 sec
Eventually(enabled, "1s").Should(Receive(BeTrue()))
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
BeDNSRecord("domain1.com.", A, "0.0.0.0"),

View File

@ -106,14 +106,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
return b, nil
}
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.upstream.Host
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil
}
ips, err := b.resolveUpstream(r, hostname)
ips, err := b.resolveUpstream(ctx, r, hostname)
if err != nil {
return nil, err
}
@ -121,10 +121,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
return newIPSet(ips), nil
}
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
// Use system resolver if no bootstrap is configured
if b.resolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), b.timeout)
ctx, cancel := context.WithTimeout(ctx, b.timeout)
defer cancel()
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
@ -135,7 +135,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
return ips, nil
}
return b.resolve(host, b.connectIPVersion.QTypes())
return b.resolve(ctx, host, b.connectIPVersion.QTypes())
}
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
@ -175,7 +175,7 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
}
// Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes)
ips, err := b.resolve(ctx, host, qTypes)
if err != nil {
logger.Errorf("resolve error: %s", err)
@ -192,11 +192,11 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
return b.dialer.DialContext(ctx, network, addrWithIP)
}
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
func (b *Bootstrap) resolve(ctx context.Context, hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
ips = make([]net.IP, 0, len(qTypes))
for _, qType := range qTypes {
qIPs, qErr := b.resolveType(hostname, qType)
qIPs, qErr := b.resolveType(ctx, hostname, qType)
if qErr != nil {
err = multierror.Append(err, qErr)
@ -213,7 +213,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
return
}
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.Type) (ips []net.IP, err error) {
if ip := net.ParseIP(hostname); ip != nil {
return []net.IP{ip}, nil
}
@ -223,7 +223,7 @@ func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP,
Log: b.log,
}
rsp, err := b.resolver.Resolve(&req)
rsp, err := b.resolver.Resolve(ctx, &req)
if err != nil {
return nil, err
}

View File

@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
},
}
_, err := sut.resolveUpstream(nil, "example.com")
_, err := sut.resolveUpstream(ctx, nil, "example.com")
Expect(err).ShouldNot(Succeed())
Expect(usedSystemResolver).Should(Receive(BeTrue()))
})
@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("called from bootstrap.upstream", func() {
It("uses hardcoded IPs", func() {
ips, err := sut.resolveUpstream(bootstrapUpstream, "host")
ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host")
Expect(err).Should(Succeed())
Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs))
@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
When("hostname is an IP", func() {
It("returns immediately", func() {
ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes())
ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes())
Expect(err).Should(Succeed())
Expect(ips).Should(ContainElement(net.IPv4zero))
@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("localhost", []dns.Type{AAAA})
ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA})
Expect(err).Should(Succeed())
Expect(ips).Should(HaveLen(1))
@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
ips, err := sut.resolve("localhost", []dns.Type{A})
ips, err := sut.resolve(ctx, "localhost", []dns.Type{A})
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A})
ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A})
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host"))
@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
r := newUpstreamResolverUnchecked(upstream, sut)
rsp, err := r.Resolve(mainReq)
rsp, err := r.Resolve(ctx, mainReq)
Expect(err).Should(Succeed())
Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1))
Expect(rsp.Res.Question[0].Name).Should(Equal("example.com."))
@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
// implicit expectation of 0 bootstrapUpstream.Resolve calls
_, err = t.DialContext(context.Background(), "ip", "!bad-addr!")
_, err = t.DialContext(ctx, "ip", "!bad-addr!")
Expect(err).ShouldNot(Succeed())
})
@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
t := sut.NewHTTPTransport()
_, err = t.DialContext(context.Background(), "ip", "abc:123")
_, err = t.DialContext(ctx, "ip", "abc:123")
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
t := sut.NewHTTPTransport()
_, err = t.DialContext(context.Background(), "ip", "abc:123")
_, err = t.DialContext(ctx, "ip", "abc:123")
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(ContainSubstring("no such host"))
@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
Describe("resolve", func() {
AfterEach(func() {
_, err := sut.resolveUpstream(nil, "example.com")
_, err := sut.resolveUpstream(ctx, nil, "example.com")
Expect(err).Should(Succeed())
m.AssertExpectations(GinkgoT())
@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
AfterEach(func() {
t := sut.NewHTTPTransport()
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0")
conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0")
Expect(err).Should(Succeed())
Expect(conn).Should(Equal(aMockConn))
@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
})
It("uses both", func() {
_, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)})
_, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)})
Expect(err).To(Succeed())

View File

@ -59,7 +59,7 @@ func newCachingResolver(ctx context.Context,
configureCaches(ctx, c, &cfg)
if c.redisClient != nil {
setupRedisCacheSubscriber(ctx, c)
go c.redisSubscriber(ctx)
c.redisClient.GetRedisCache()
}
@ -105,14 +105,14 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
}
}
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Duration) {
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log()
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
req := newRequest(dns.Fqdn(domainName), qType, logger)
response, err := r.next.Resolve(req)
response, err := r.next.Resolve(ctx, req)
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
@ -132,22 +132,20 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Durat
return nil, 0
}
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) {
go func() {
for {
select {
case rc := <-c.redisClient.CacheChannel:
if rc != nil {
c.log().Debug("Received key from redis: ", rc.Key)
ttl := c.adjustTTLs(rc.Response.Res.Answer)
c.putInCache(rc.Key, rc.Response, ttl, false)
}
case <-ctx.Done():
return
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
for {
select {
case rc := <-r.redisClient.CacheChannel:
if rc != nil {
r.log().Debug("Received key from redis: ", rc.Key)
ttl := r.adjustTTLs(rc.Response.Res.Answer)
r.putInCache(rc.Key, rc.Response, ttl, false)
}
case <-ctx.Done():
return
}
}()
}
}
// LogConfig implements `config.Configurable`.
@ -159,13 +157,13 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
// Resolve checks if the current query should use the cache and if the result is already in
// the cache and returns it or delegates to the next resolver
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")
if !r.IsEnabled() || !isRequestCacheable(request) {
logger.Debug("skip cache")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
for _, question := range request.Req.Question {
@ -191,7 +189,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
}
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(request)
response, err = r.next.Resolve(ctx, request)
if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer)
@ -319,7 +317,7 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
}
}
func (r *CachingResolver) FlushCaches() {
func (r *CachingResolver) FlushCaches(context.Context) {
r.log().Debug("flush caches")
r.resultCache.Clear()
}

View File

@ -114,7 +114,7 @@ var _ = Describe("CachingResolver", func() {
})).Should(Succeed())
// first request
_, _ = sut.Resolve(newRequest("example.com.", A))
_, _ = sut.Resolve(ctx, newRequest("example.com.", A))
// Domain is not prefetched
Expect(domainPrefetched).ShouldNot(Receive())
@ -124,7 +124,7 @@ var _ = Describe("CachingResolver", func() {
// now query again > threshold
for i := 0; i < prefetchThreshold+1; i++ {
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed())
}
@ -132,7 +132,7 @@ var _ = Describe("CachingResolver", func() {
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
// and it should hit from prefetch cache
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -156,7 +156,7 @@ var _ = Describe("CachingResolver", func() {
})
It("should cache response and use response's TTL for multiple records", func() {
By("first request", func() {
result, err := sut.Resolve(newRequest("example.com.", A))
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed())
Expect(result).
Should(
@ -176,7 +176,7 @@ var _ = Describe("CachingResolver", func() {
By("second request", func() {
Eventually(func(g Gomega) {
result, err := sut.Resolve(newRequest("example.com.", A))
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
g.Expect(err).Should(Succeed())
g.Expect(result).
Should(
@ -218,7 +218,7 @@ var _ = Describe("CachingResolver", func() {
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
totalCacheCount <- d
})
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -239,7 +239,7 @@ var _ = Describe("CachingResolver", func() {
domain <- true
})
g.Expect(sut.Resolve(newRequest("example.com.", A))).
g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -264,7 +264,7 @@ var _ = Describe("CachingResolver", func() {
It("should cache response and use min caching time as TTL", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -276,7 +276,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", A)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -299,7 +301,7 @@ var _ = Describe("CachingResolver", func() {
It("should cache response and use min caching time as TTL", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -310,7 +312,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -344,7 +348,7 @@ var _ = Describe("CachingResolver", func() {
It("Shouldn't cache any responses", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -355,7 +359,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -378,7 +384,7 @@ var _ = Describe("CachingResolver", func() {
})
It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -389,7 +395,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -417,7 +425,7 @@ var _ = Describe("CachingResolver", func() {
})
It("should cache response and return 0 TTL if entry is expired", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -430,7 +438,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)).
Eventually(sut.Resolve, "2s").
WithContext(ctx).
WithArguments(newRequest("example.com.", A)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
@ -461,7 +471,7 @@ var _ = Describe("CachingResolver", func() {
})
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
@ -472,7 +482,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED NEGATIVE"),
@ -495,7 +507,7 @@ var _ = Describe("CachingResolver", func() {
It("response shouldn't be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeNameError),
@ -506,7 +518,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReason(""),
@ -529,7 +543,7 @@ var _ = Describe("CachingResolver", func() {
It("response should be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -540,7 +554,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("example.com.", AAAA)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"),
@ -563,7 +579,7 @@ var _ = Describe("CachingResolver", func() {
})
It("Should be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", MX))).
Expect(sut.Resolve(ctx, newRequest("google.de.", MX))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -575,7 +591,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("google.de.", MX)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"),
@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() {
})
It("Should not be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))).
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -611,7 +629,7 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))).
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() {
})
It("Should not be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))).
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -645,7 +663,7 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))).
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -671,7 +689,7 @@ var _ = Describe("CachingResolver", func() {
})
It("Should not be cached", func() {
By("first request", func() {
Expect(sut.Resolve(newRequest("google.de.", A))).
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
Should(SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
@ -688,7 +706,9 @@ var _ = Describe("CachingResolver", func() {
})
By("second request", func() {
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(newRequest("google.de.", A)).
Should(SatisfyAll(
HaveResponseType(ResponseTypeCACHED),
HaveReason("CACHED"),
@ -750,7 +770,7 @@ var _ = Describe("CachingResolver", func() {
})
It("put in redis", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(HaveResponseType(ResponseTypeRESOLVED))
Eventually(func() []string {
@ -772,7 +792,9 @@ var _ = Describe("CachingResolver", func() {
}
redisClient.CacheChannel <- redisMockMsg
Eventually(sut.Resolve).WithArguments(request).
Eventually(sut.Resolve).
WithContext(ctx).
WithArguments(request).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeCACHED),

View File

@ -32,7 +32,7 @@ func NewClientNamesResolver(ctx context.Context,
) (cr *ClientNamesResolver, err error) {
var r Resolver
if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
@ -59,17 +59,17 @@ func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
}
// Resolve tries to resolve the client name from the ip address
func (r *ClientNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
clientNames := r.getClientNames(request)
func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
clientNames := r.getClientNames(ctx, request)
request.ClientNames = clientNames
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
// returns names of client
func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model.Request) []string {
if request.RequestClientID != "" {
return []string{request.RequestClientID}
}
@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
return cpy
}
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
r.cache.Put(ip.String(), &names, time.Hour)
@ -111,7 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
}
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
func (r *ClientNamesResolver) resolveClientNames(
ctx context.Context, ip net.IP, logger *logrus.Entry,
) (result []string) {
// try client mapping first
result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 {
@ -124,7 +126,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
reverse, _ := dns.ReverseAddr(ip.String())
resp, err := r.externalResolver.Resolve(&model.Request{
resp, err := r.externalResolver.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
Log: logger,
})

View File

@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
sut *ClientNamesResolver
sutConfig config.ClientLookupConfig
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
JustBeforeEach(func() {
var err error
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use clientID if set", func() {
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should use IP as fallback if clientID not set", func() {
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve defined name with ipv4 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve defined name with ipv6 address", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should resolve multiple names defined names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() {
By("first request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
By("second request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
By("third request", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve all client names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should resolve the client name depending to defined order", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should resolve no names", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
})
It("should use fallback for client name", func() {
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"strings"
"github.com/0xERR0R/blocky/config"
@ -23,14 +24,14 @@ type ConditionalUpstreamResolver struct {
// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
@ -48,7 +49,9 @@ func NewConditionalUpstreamResolver(
return &r, nil
}
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
func (r *ConditionalUpstreamResolver) processRequest(
ctx context.Context, request *model.Request,
) (bool, *model.Response, error) {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion
@ -56,7 +59,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
// try with domain with and without sub-domains
for len(domain) > 0 {
if resolver, found := r.mapping[domain]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
return true, resp, err
}
@ -68,7 +71,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
}
}
} else if resolver, found := r.mapping["."]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
return true, resp, err
}
@ -77,11 +80,11 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
}
// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "conditional_resolver")
if len(r.mapping) > 0 {
resolved, resp, err := r.processRequest(request)
resolved, resp, err := r.processRequest(ctx, request)
if resolved {
return resp, err
}
@ -89,17 +92,17 @@ func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Re
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
func (r *ConditionalUpstreamResolver) internalResolve(reso Resolver, doFQ, do string,
func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso Resolver, doFQ, do string,
req *model.Request,
) (*model.Response, error) {
// internal request resolution
logger := log.WithPrefix(req.Log, "conditional_resolver")
req.Req.Question[0].Name = dns.Fqdn(doFQ)
response, err := reso.Resolve(req)
response, err := reso.Resolve(ctx, req)
if err == nil {
response.Reason = "CONDITIONAL"

View File

@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
var (
sut *ConditionalUpstreamResolver
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
DeferCleanup(refuseTestUpstream.Close)
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{
"fritz.box": {fbTestUpstream.Start()},
@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
Describe("Resolve conditional DNS queries via defined DNS server", func() {
When("conditional resolver returns error code", func() {
It("Should be returned without changes", func() {
Expect(sut.Resolve(newRequest("refused.domain.", A))).
Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("Query is exact equal defined condition in mapping", func() {
Context("first mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() {
Expect(sut.Resolve(newRequest("fritz.box.", A))).
Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("fritz.box.", A, "123.124.122.122"),
@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
Context("last mapping entry", func() {
It("Should resolve the IP of conditional DNS", func() {
Expect(sut.Resolve(newRequest("other.box.", A))).
Expect(sut.Resolve(ctx, newRequest("other.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("other.box.", A, "192.192.192.192"),
@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
When("Query is a subdomain of defined condition in mapping", func() {
It("Should resolve the IP of subdomain", func() {
Expect(sut.Resolve(newRequest("test.fritz.box.", A))).
Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))).
Should(
SatisfyAll(
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
})
When("Query is not fqdn and . condition is defined in mapping", func() {
It("Should resolve the IP of .", func() {
Expect(sut.Resolve(newRequest("test.", A))).
Expect(sut.Resolve(ctx, newRequest("test.", A))).
Should(
SatisfyAll(
BeDNSRecord("test.", A, "168.168.168.168"),
@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
Describe("Delegation to next resolver", func() {
When("Query doesn't match defined mapping", func() {
It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequest("google.com.", A))).
Expect(sut.Resolve(ctx, newRequest("google.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
When("upstream is invalid", func() {
It("errors during construction", func() {
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
Mapping: config.ConditionalUpstreamMapping{
Upstreams: map[string][]config.Upstream{
".": {config.Upstream{Host: "example.com"}},

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"net"
"strings"
@ -123,7 +124,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
}
// Resolve uses internal mapping to resolve the query
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
reverseResp := r.handleReverseDNS(request)
@ -140,5 +141,5 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"net"
"time"
@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() {
sut *CustomDNSResolver
m *mockResolver
cfg config.CustomDNSConfig
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
cfg = config.CustomDNSConfig{
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
"custom.domain": {net.ParseIP("192.168.143.123")},
@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is true", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
It("defined ip4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("custom.domain.", A))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
Should(
SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("TXT query for defined mapping should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("ip6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() {
Context("filterUnmappedTypes is false", func() {
BeforeEach(func() { cfg.FilterUnmappedTypes = false })
It("defined ip4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("custom.domain.", A))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
Should(
SatisfyAll(
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
})
It("TXT query for defined mapping should be delegated to next resolver", func() {
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() {
m.AssertExpectations(GinkgoT())
})
It("ip6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() {
})
When("Ip 6 mapping is defined for custom domain ", func() {
It("ip6 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))).
Should(
SatisfyAll(
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() {
When("Multiple IPs are defined for custom domain ", func() {
It("all IPs for the current type should be returned", func() {
By("IPv6 query", func() {
Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))).
Should(
SatisfyAll(
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() {
})
By("IPv4 query", func() {
Expect(sut.Resolve(newRequest("multiple.ips.", A))).
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() {
When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() {
By("ipv4", func() {
Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))).
Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
WithTransform(ToAnswer, SatisfyAll(
@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() {
})
By("ipv6", func() {
Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
PTR))).
Should(
SatisfyAll(
@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() {
})
When("Domain mapping is defined", func() {
It("subdomain must also match", func() {
Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))).
Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))).
Should(
SatisfyAll(
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() {
Describe("Delegating to next resolver", func() {
When("no mapping for domain exist", func() {
It("should delegate to next resolver", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"net"
@ -46,7 +47,7 @@ func NewECSResolver(cfg config.ECS) ChainedResolver {
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver
// and sets the client IP from the EDNS0 option to the request if this option is enabled
func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if r.cfg.IsEnabled() {
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
@ -67,7 +68,7 @@ func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
}
}
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
// setSubnet appends the subnet information to the request as EDNS0 option

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"net"
"github.com/0xERR0R/blocky/config"
@ -25,6 +26,9 @@ var _ = Describe("EcsResolver", func() {
err error
origIP net.IP
ecsIP net.IP
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -34,6 +38,9 @@ var _ = Describe("EcsResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
err = defaults.Set(&sutConfig)
Expect(err).ShouldNot(HaveOccurred())
@ -86,13 +93,13 @@ var _ = Describe("EcsResolver", func() {
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(req *Request) (*Response, error) {
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -107,13 +114,13 @@ var _ = Describe("EcsResolver", func() {
addEcsOption(request.Req, ecsIP, 24)
m.ResolveFn = func(req *Request) (*Response, error) {
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -134,14 +141,14 @@ var _ = Describe("EcsResolver", func() {
request := newRequest("example.com.", A)
request.ClientIP = origIP
m.ResolveFn = func(req *Request) (*Response, error) {
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -154,13 +161,13 @@ var _ = Describe("EcsResolver", func() {
request := newRequest("example.com.", AAAA)
request.ClientIP = net.ParseIP("2001:db8::68")
m.ResolveFn = func(req *Request) (*Response, error) {
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
HaveNoAnswer(),

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@ -25,12 +27,12 @@ func NewEDEResolver(cfg config.EDE) *EDEResolver {
// Resolve adds the reason as EDNS0 option to the response of the next resolver
// if it is enabled in the configuration
func (r *EDEResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if !r.cfg.Enable {
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
resp, err := r.next.Resolve(request)
resp, err := r.next.Resolve(ctx, request)
if err != nil {
return nil, err
}

View File

@ -2,6 +2,7 @@
package resolver
import (
"context"
"errors"
"math"
@ -24,6 +25,9 @@ var _ = Describe("EdeResolver", func() {
sutConfig config.EDE
m *mockResolver
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -33,6 +37,9 @@ var _ = Describe("EdeResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg)
})
@ -57,7 +64,7 @@ var _ = Describe("EdeResolver", func() {
}
})
It("shouldn't add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -89,7 +96,7 @@ var _ = Describe("EdeResolver", func() {
}
It("should add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -115,7 +122,7 @@ var _ = Describe("EdeResolver", func() {
})
It("shouldn't add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -137,7 +144,7 @@ var _ = Describe("EdeResolver", func() {
})
It("should return it", func() {
resp, err := sut.Resolve(newRequest("example.com", A))
resp, err := sut.Resolve(ctx, newRequest("example.com", A))
Expect(resp).To(BeNil())
Expect(err).To(Equal(resolveErr))
})

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
@ -21,7 +23,7 @@ func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
}
}
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
qType := request.Req.Question[0].Qtype
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
response := new(dns.Msg)
@ -30,5 +32,5 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
}
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() {
sutConfig config.FilteringConfig
m *mockResolver
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg)
})
@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() {
}
})
It("Should delegate to next resolver if request query has other type", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() {
Expect(m.Calls).Should(HaveLen(1))
})
It("Should return empty answer for defined query type", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() {
sutConfig = config.FilteringConfig{}
})
It("Should return empty answer without error", func() {
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"strings"
"github.com/0xERR0R/blocky/config"
@ -22,7 +23,7 @@ func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver {
}
}
func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if r.IsEnabled() {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
if !strings.Contains(domainFromQuestion, ".") {
@ -33,5 +34,5 @@ func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, err
}
}
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig config.FQDNOnly
m *mockResolver
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg)
})
@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig = config.FQDNOnly{Enable: true}
})
It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
Expect(m.Calls).Should(HaveLen(1))
})
It("Should return NXDOMAIN if request query is not fqdn", func() {
Expect(sut.Resolve(newRequest("example", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
sutConfig = config.FQDNOnly{Enable: false}
})
It("Should delegate to next resolver if request query is fqdn", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
Expect(m.Calls).Should(HaveLen(1))
})
It("Should delegate to next resolver if request query is not fqdn", func() {
Expect(sut.Resolve(newRequest("example", AAAA))).
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),

View File

@ -109,9 +109,9 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil
}
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
if !r.IsEnabled() {
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
reverseResp := r.handleReverseDNS(request)
@ -134,7 +134,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {

View File

@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() {
tmpFile *TmpFile
err error
resp *Response
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
tmpDir = NewTmpFolder("HostsFileResolver")
Expect(tmpDir.Error).Should(Succeed())
DeferCleanup(tmpDir.Clean)
@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() {
})
JustBeforeEach(func() {
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
Expect(err).Should(Succeed())
@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() {
Expect(sut.hosts.isEmpty()).Should(BeTrue())
})
It("should go to next resolver on query", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() {
sutConfig.Sources = make([]config.BytesSource, 0)
})
JustBeforeEach(func() {
err = sut.loadSources(context.Background())
err = sut.loadSources(ctx)
Expect(err).Should(Succeed())
})
It("should go to next resolver on query", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() {
When("IPv4 mapping is defined for a host", func() {
It("defined ipv4 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ipv4host.", A))).
Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() {
))
})
It("defined ipv4 query for alias should be resolved", func() {
Expect(sut.Resolve(newRequest("router2.", A))).
Expect(sut.Resolve(ctx, newRequest("router2.", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() {
))
})
It("ipv4 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("does.not.exist.", A))).
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() {
When("IPv6 mapping is defined for a host", func() {
It("defined ipv6 query should be resolved", func() {
Expect(sut.Resolve(newRequest("ipv6host.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() {
))
})
It("ipv6 query should return NOERROR and empty result", func() {
Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))).
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the domain is not known", func() {
It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A))
resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the question type is not handled", func() {
It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("localhost.", MX))
resp, err = sut.Resolve(ctx, newRequest("localhost.", MX))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() {
When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() {
By("ipv4 with one hostname", func() {
Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))).
Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() {
))
})
By("ipv4 with aliases", func() {
Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))).
Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() {
))
})
By("ipv6", func() {
Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))).
Expect(sut.Resolve(ctx,
newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)),
).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() {
})
It("should ignore invalid PTR", func() {
resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() {
When("filterLoopback is true", func() {
It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR))
resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() {
When("the IP is not known", func() {
It("calls the next resolver", func() {
resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR))
resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() {
})
It("resolve the defined domain name", func() {
Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))).
Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() {
Describe("Delegating to next resolver", func() {
When("no hosts file is provided", func() {
It("should delegate to next resolver", func() {
_, err = sut.Resolve(newRequest("example.com.", A))
_, err = sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(Succeed())
// delegate was executed
m.AssertExpectations(GinkgoT())

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"strings"
"time"
@ -25,8 +26,8 @@ type MetricsResolver struct {
}
// Resolve resolves the passed request
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
response, err := r.next.Resolve(request)
func (r *MetricsResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
response, err := r.next.Resolve(ctx, request)
if r.cfg.Enable {
r.totalQueries.With(prometheus.Labels{

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"errors"
"github.com/0xERR0R/blocky/config"
@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() {
var (
sut *MetricsResolver
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() {
Context("Recording request metrics", func() {
When("Request will be performed", func() {
It("Should record metrics", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() {
sut.Next(m)
})
It("Error should be recorded", func() {
_, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))
_, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))
Expect(err).Should(HaveOccurred())

View File

@ -4,6 +4,7 @@ import (
"net"
"strings"
"sync/atomic"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
@ -62,6 +63,21 @@ func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response
return t
}
func (t *MockUDPUpstreamServer) WithDelay(delay time.Duration) *MockUDPUpstreamServer {
answerFn := t.answerFn
if answerFn == nil {
panic("WithDelay must be called after a WithAnswer function")
}
t.answerFn = func(request *dns.Msg) *dns.Msg {
time.Sleep(delay)
return answerFn(request)
}
return t
}
func (t *MockUDPUpstreamServer) GetCallCount() int {
return int(atomic.LoadInt32(&t.callCount))
}

View File

@ -23,7 +23,7 @@ type mockResolver struct {
mock.Mock
NextResolver
ResolveFn func(req *model.Request) (*model.Response, error)
ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error)
ResponseFn func(req *dns.Msg) *dns.Msg
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
}
@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) {
r.Called()
}
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
args := r.Called(req)
if r.ResolveFn != nil {
return r.ResolveFn(req)
return r.ResolveFn(ctx, req)
}
if r.ResponseFn != nil {

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/model"
"github.com/sirupsen/logrus"
)
@ -28,6 +30,6 @@ func (NoOpResolver) IsEnabled() bool {
func (NoOpResolver) LogConfig(*logrus.Entry) {
}
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) {
func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
return NoResponse, nil
}

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
@ -8,7 +10,12 @@ import (
)
var _ = Describe("NoOpResolver", func() {
var sut *NoOpResolver
var (
sut *NoOpResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
It("follows conventions", func() {
@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewNoOpResolver()
})
Describe("Resolving", func() {
It("returns no response", func() {
resp, err := sut.Resolve(newRequest("test.tld", A))
resp, err := sut.Resolve(ctx, newRequest("test.tld", A))
Expect(err).Should(Succeed())
Expect(resp).Should(Equal(NoResponse))
})

View File

@ -53,17 +53,23 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
return status
}
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
resp, err := r.resolver.Resolve(req)
func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
resp, err := r.resolver.Resolve(ctx, req)
if err != nil {
// Ignore `Canceled`: resolver lost the race, not an error
if !errors.Is(err, context.Canceled) {
r.lastErrorTime.Store(time.Now())
}
err = fmt.Errorf("%s: %w", r.resolver, err)
return nil, fmt.Errorf("%s: %w", r.resolver, err)
}
return resp, nil
}
func (r *upstreamResolverStatus) resolveToChan(ctx context.Context, req *model.Request, ch chan<- requestResponse) {
resp, err := r.resolve(ctx, req)
ch <- requestResponse{
resolver: &r.resolver,
response: resp,
@ -78,10 +84,10 @@ type requestResponse struct {
}
// testResolver sends a test query to verify the resolver is reachable and working
func testResolver(r *UpstreamResolver) error {
func testResolver(ctx context.Context, r *UpstreamResolver) error {
request := newRequest("github.com.", dns.Type(dns.TypeA))
resp, err := r.Resolve(request)
resp, err := r.Resolve(ctx, request)
if err != nil || resp.RType != model.ResponseTypeRESOLVED {
return fmt.Errorf("test resolve of upstream server failed: %w", err)
}
@ -91,11 +97,11 @@ func testResolver(r *UpstreamResolver) error {
// NewParallelBestResolver creates new resolver instance
func NewParallelBestResolver(
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ParallelBestResolver, error) {
logger := log.PrefixedLog(parallelResolverType)
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
@ -150,26 +156,18 @@ func (r *ParallelBestResolver) String() string {
}
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, parallelResolverType)
if len(r.resolvers) == 1 {
logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver")
resolver := r.resolvers[0]
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
return r.resolvers[0].resolver.Resolve(request)
return resolver.resolve(ctx, request)
}
ctx := context.Background()
// using context with timeout for random upstream strategy
if r.resolverCount == 1 {
var cancel context.CancelFunc
timeout := config.GetConfig().Upstreams.Timeout
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
defer cancel()
}
ctx, cancel := context.WithCancel(ctx)
defer cancel() // abort requests to resolvers that lost the race
resolvers := pickRandom(r.resolvers, r.resolverCount)
ch := make(chan requestResponse, len(resolvers))
@ -177,10 +175,10 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
for _, resolver := range resolvers {
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
go resolver.resolve(request, ch)
go resolver.resolveToChan(ctx, request, ch)
}
response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers)
response, collectedErrors := evaluateResponses(logger, ch, resolvers)
if response != nil {
return response, nil
}
@ -189,63 +187,51 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
}
return r.retryWithDifferent(logger, request, resolvers)
return r.retryWithDifferent(ctx, logger, request, resolvers)
}
func evaluateResponses(
ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
) (*model.Response, []error) {
collectedErrors := make([]error, 0, len(resolvers))
for len(collectedErrors) < len(resolvers) {
select {
case <-ctx.Done():
// this context currently only has a deadline when resolverCount == 1
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
logger.WithField("resolver", resolvers[0].resolver).
Debug("upstream exceeded timeout, trying other upstream")
resolvers[0].lastErrorTime.Store(time.Now())
}
case result := <-ch:
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
} else {
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
result := <-ch
logger := logger.WithField("resolver", *result.resolver)
return result.response, nil
}
if result.err != nil {
logger.Debug("resolution failed from resolver, cause: ", result.err)
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
continue
}
logger.WithField("answer", util.AnswerToString(result.response.Res.Answer)).Debug("using response from resolver")
return result.response, nil
}
return nil, collectedErrors
}
func (r *ParallelBestResolver) retryWithDifferent(
logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
) (*model.Response, error) {
// second try (if retryWithDifferentResolver == true)
resolver := weightedRandom(r.resolvers, resolvers)
logger.Debugf("using %s as second resolver", resolver.resolver)
ch := make(chan requestResponse, 1)
resolver.resolve(request, ch)
result := <-ch
if result.err != nil {
return nil, fmt.Errorf("resolution retry failed: %w", result.err)
resp, err := resolver.resolve(ctx, request)
if err != nil {
return nil, fmt.Errorf("resolution retry failed: %w", err)
}
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
"resolver": *resolver,
"answer": util.AnswerToString(resp.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
return resp, nil
}
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool

View File

@ -19,6 +19,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
const (
verifyUpstreams = true
noVerifyUpstreams = false
timeout = 50 * time.Millisecond
)
var (
@ -40,6 +42,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
BeforeEach(func() {
old := config.GetConfig().Upstreams.Timeout
DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old })
config.GetConfig().Upstreams.Timeout = config.Duration(timeout)
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
ctx, cancelFn = context.WithCancel(context.Background())
@ -58,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
Upstreams: upstreams,
}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify)
})
Describe("IsEnabled", func() {
@ -98,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream.Start(),
}
_, err := NewParallelBestResolver(config.UpstreamGroup{
_, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
@ -133,6 +139,43 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
})
When("upstream is too slow", func() {
BeforeEach(func() {
timeoutUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.1").
WithDelay(2 * timeout)
DeferCleanup(timeoutUpstream.Close)
upstreams = []config.Upstream{timeoutUpstream.Start()}
})
When("strict checking is enabled", func() {
BeforeEach(func() {
sutVerify = verifyUpstreams
})
It("should fail to start", func() {
Expect(err).Should(HaveOccurred())
})
})
When("strict checking is disabled", func() {
BeforeEach(func() {
sutVerify = noVerifyUpstreams
})
It("should start", func() {
Expect(err).Should(Succeed())
})
It("should not resolve", func() {
Expect(err).Should(Succeed())
request := newRequest("example.com.", A)
_, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
Expect(isTimeout(err)).Should(BeTrue())
})
})
})
Describe("Resolving result from fastest upstream resolver", func() {
When("2 Upstream resolvers are defined", func() {
When("one resolver is fast and another is slow", func() {
@ -140,21 +183,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(fastTestUpstream.Close)
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
return response
})
slowTestUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.123").
WithDelay(timeout / 2)
DeferCleanup(slowTestUpstream.Close)
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
})
It("Should use result from fastest one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -167,21 +205,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
When("one resolver is slow, but another returns an error", func() {
var slowTestUpstream *MockUDPUpstreamServer
BeforeEach(func() {
slowTestUpstream = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
return response
})
slowTestUpstream = NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.123").
WithDelay(timeout / 2)
DeferCleanup(slowTestUpstream.Close)
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
Expect(err).Should(Succeed())
})
It("Should use result from successful resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
@ -202,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should return error", func() {
Expect(err).Should(Succeed())
request := newRequest("example.com.", A)
_, err = sut.Resolve(request)
_, err = sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
})
@ -218,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -242,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
@ -268,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ {
request := newRequest("example.com.", A)
_, _ = sut.Resolve(request)
_, _ = sut.Resolve(ctx, request)
}
})
@ -302,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
r, err := NewParallelBestResolver(config.UpstreamGroup{
r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: "test",
Upstreams: []config.Upstream{{Host: "example.com"}},
}, b, verifyUpstreams)
@ -313,11 +346,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
Describe("random resolver strategy", func() {
const timeout = config.Duration(time.Second)
BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
config.GetConfig().Upstreams.Timeout = timeout
})
Describe("Name", func() {
@ -342,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
It("Should return result from either one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(SatisfyAll(
HaveTTL(BeNumerically("==", 123)),
HaveResponseType(ResponseTypeRESOLVED),
@ -356,24 +386,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
When("one upstream exceeds timeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(time.Duration(timeout) + 2*time.Second)
Expect(err).To(Succeed())
return response
})
DeferCleanup(testUpstream1.Close)
timeoutUpstream := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.1").
WithDelay(2 * timeout)
DeferCleanup(timeoutUpstream.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
DeferCleanup(testUpstream2.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
upstreams = []config.Upstream{timeoutUpstream.Start(), testUpstream2.Start()}
})
It("should ask a other random upstream and return its response", func() {
It("should ask another random upstream and return its response", func() {
request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should(
Expect(sut.Resolve(ctx, request)).Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)),
@ -382,57 +407,25 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
))
})
})
When("two upstreams exceed timeout", func() {
When("all upstreams exceed timeout", func() {
BeforeEach(func() {
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
time.Sleep(timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
testUpstream1 := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.1").
WithDelay(2 * timeout)
DeferCleanup(testUpstream1.Close)
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
time.Sleep(timeout.ToDuration() + 2*time.Second)
Expect(err).To(Succeed())
return response
})
testUpstream2 := NewMockUDPUpstreamServer().
WithAnswerRR("example.com 123 IN A 123.124.122.2").
WithDelay(2 * timeout)
DeferCleanup(testUpstream2.Close)
testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3")
DeferCleanup(testUpstream3.Close)
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()}
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
})
// These two tests are flaky -_- (maybe recreate the RandomResolver )
It("should not return error (due to random selection the request could to through)", func() {
Eventually(func() error {
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
return err
}).WithTimeout(30 * time.Second).
Should(Not(HaveOccurred()))
})
It("should return error (because it can be possible that the two broken upstreams are chosen)", func() {
Eventually(func() error {
sutConfig := config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
return err
}).WithTimeout(30 * time.Second).
Should(HaveOccurred())
It("Should return error", func() {
request := newRequest("example.com.", A)
_, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
Expect(isTimeout(err)).Should(BeTrue())
})
})
})
@ -446,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
})
It("Should return error", func() {
request := newRequest("example.com.", A)
_, err := sut.Resolve(request)
_, err := sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
})
})
@ -461,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -485,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream2.Close)
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
},
@ -506,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
By("perform 100 request, error upstream's weight will be reduced", func() {
for i := 0; i < 100; i++ {
request := newRequest("example.com.", A)
_, _ = sut.Resolve(request)
_, _ = sut.Resolve(ctx, request)
}
})

View File

@ -111,12 +111,12 @@ func (r *QueryLoggingResolver) doCleanUp() {
}
// Resolve logs the query, duration and the result
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
start := time.Now()
resp, err := r.next.Resolve(request)
resp, err := r.next.Resolve(ctx, request)
duration := time.Since(start).Milliseconds()

View File

@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() {
m *mockResolver
tmpDir *TmpFolder
mockAnswer *dns.Msg
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
mockAnswer = new(dns.Msg)
tmpDir = NewTmpFolder("queryLoggingResolver")
Expect(tmpDir.Error).Should(Succeed())
@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() {
sutConfig.SetDefaults() // not called when using a struct literal
}
ctx, cancelFn := context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sut = NewQueryLoggingResolver(ctx, sutConfig)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() {
}
})
It("should process request without query logging", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() {
})
It("should create a log file per client", func() {
By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() {
))
})
By("request from client 2, has name with special chars, should be escaped", func() {
Expect(sut.Resolve(newRequestWithClient(
Expect(sut.Resolve(ctx, newRequestWithClient(
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
Should(
SatisfyAll(
@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() {
})
It("should create one log file for all clients", func() {
By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() {
))
})
By("request from client 2, has name with special chars, should be escaped", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() {
})
It("should create one log file", func() {
By("request from client 1", func() {
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
Should(
SatisfyAll(
HaveResponseType(ResponseTypeRESOLVED),
@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() {
sut.writer = mockWriter
Eventually(func() int {
_, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
_, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
Expect(ierr).Should(Succeed())
return len(sut.logChan)

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"net"
"time"
@ -74,7 +75,7 @@ type Resolver interface {
Type() string
// Resolve performs resolution of a DNS request
Resolve(req *model.Request) (*model.Response, error)
Resolve(ctx context.Context, req *model.Request) (*model.Response, error)
}
// ChainedResolver represents a resolver, which can delegate result to the next one
@ -216,13 +217,14 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
}
func createResolvers(
logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
ctx context.Context, logger *logrus.Entry,
cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
) ([]Resolver, error) {
resolvers := make([]Resolver, 0, len(cfg.Upstreams))
hasValidResolvers := false
for _, u := range cfg.Upstreams {
resolver, err := NewUpstreamResolver(u, bootstrap, shoudVerifyUpstreams)
resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams)
if err != nil {
logger.Warnf("upstream group %s: %v", cfg.Name, err)
@ -230,7 +232,7 @@ func createResolvers(
}
if shoudVerifyUpstreams {
err = testResolver(resolver)
err = testResolver(ctx, resolver)
if err != nil {
logger.Warn(err)
} else {

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"strings"
@ -57,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
}
// Resolve uses the inner resolver to resolve the rewritten query
func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "rewriter_resolver")
original := request.Req
@ -69,7 +70,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
response, err := r.inner.Resolve(request)
response, err := r.inner.Resolve(ctx, request)
// Test for error after checking for fallbackUpstream
// Revert the request: must be done before calling r.next
@ -80,7 +81,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
if err != nil {
@ -91,7 +92,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
// Inner resolver had no response, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
// Revert the rewrite in r.inner's response

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() {
return res
}
resp, err := sut.Resolve(request)
resp, err := sut.Resolve(context.Background(), request)
Expect(err).Should(Succeed())
if resp != mNextResponse {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() {
expectNilAnswer = true
// Make inner call the NoOpResolver
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
return mInner.next.Resolve(req)
return mInner.next.Resolve(ctx, req)
}
// Resolver after RewriterResolver should see `fqdnOriginal`
mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil
@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() {
expectNilAnswer = true
// Make inner return a nil Answer but not an empty Response
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten
@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() {
fqdnRewritten = sampleRewritten
// Make inner return a nil Answer but not an empty Response
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req).Should(Equal(request))
// Inner should see fqdnRewritten
@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() {
// Resolver after RewriterResolver should see `fqdnOriginal`
mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil

View File

@ -30,11 +30,11 @@ type StrictResolver struct {
// NewStrictResolver creates a new strict resolver instance
func NewStrictResolver(
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*StrictResolver, error) {
logger := log.PrefixedLog(strictResolverType)
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
@ -76,44 +76,27 @@ func (r *StrictResolver) String() string {
}
// Resolve sends the query request in a strict order to the upstream resolvers
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, strictResolverType)
// start with first resolver
for i := range r.resolvers {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
resolver := r.resolvers[i]
for _, resolver := range r.resolvers {
logger.Debugf("using %s as resolver", resolver.resolver)
ch := make(chan requestResponse, 1)
go resolver.resolve(request, ch)
select {
case <-ctx.Done():
// log debug/info that timeout exceeded, call `continue` to try next upstream
logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
resp, err := resolver.resolve(ctx, request)
if err != nil {
// log error and try next upstream
logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err)
continue
case result := <-ch:
if result.err != nil {
// log error & call `continue` to try next upstream
logger.Debug("resolution failed from resolver, cause: ", result.err)
continue
}
logger.WithFields(logrus.Fields{
"resolver": *result.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
}
logger.WithFields(logrus.Fields{
"resolver": *resolver,
"answer": util.AnswerToString(resp.Res.Answer),
}).Debug("using response from resolver")
return resp, nil
}
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"time"
"github.com/0xERR0R/blocky/config"
@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
err error
bootstrap *Bootstrap
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
upstreams = []config.Upstream{
{Host: "wrong"},
{Host: "127.0.0.2"},
@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
}
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify)
})
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
mockUpstream.Start(),
}
_, err := NewStrictResolver(config.UpstreamGroup{
_, err := NewStrictResolver(ctx, config.UpstreamGroup{
Name: upstreamDefaultCfgName,
Upstreams: upstreams,
},
@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
It("Should use result from first one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
It("should return response from next upstream", func() {
request := newRequest("example.com", A)
Expect(sut.Resolve(request)).Should(
Expect(sut.Resolve(ctx, request)).Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.2"),
HaveTTL(BeNumerically("==", 123)),
@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
It("should return error", func() {
request := newRequest("example.com", A)
_, err := sut.Resolve(request)
_, err := sut.Resolve(ctx, request)
Expect(err).To(HaveOccurred())
})
})
@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
It("Should use result from second one", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.123"),
@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
})
It("Should return error", func() {
request := newRequest("example.com.", A)
_, err = sut.Resolve(request)
_, err = sut.Resolve(ctx, request)
Expect(err).Should(HaveOccurred())
})
})
@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
It("Should use result from defined resolver", func() {
request := newRequest("example.com.", A)
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"net"
"strings"
@ -99,7 +100,7 @@ func NewSpecialUseDomainNamesResolver(cfg config.SUDN) *SpecialUseDomainNamesRes
}
}
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *SpecialUseDomainNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
handler := r.handler(request)
if handler != nil {
resp := handler(request, r.cfg)
@ -108,7 +109,7 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
}
}
return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}
func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler {

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"github.com/0xERR0R/blocky/config"
@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
sut *SpecialUseDomainNamesResolver
sutConfig config.SUDN
m *mockResolver
ctx context.Context
cancelFn context.CancelFunc
)
Describe("Type", func() {
@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
BeforeEach(func() {
var err error
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig, err = config.WithDefaults[config.SUDN]()
Expect(err).Should(Succeed())
})
@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
Describe("handlers", func() {
It("should have correct response type", func() {
for domain, handler := range sudnHandlers {
resp, err := sut.Resolve(newRequest(domain, A))
resp, err := sut.Resolve(ctx, newRequest(domain, A))
Expect(err).Should(Succeed())
if handler == nil {
@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
DescribeTable("handled domains",
func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) {
resp, err := sut.Resolve(newRequest(qName, qType))
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
Expect(err).Should(Succeed())
Expect(resp).Should(SatisfyAll(
HaveResponseType(ResponseTypeSPECIAL),
@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
DescribeTable("",
func(qType dns.Type, qName string, expectedRCode int) {
resp, err := sut.Resolve(newRequest(qName, qType))
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
Expect(err).Should(Succeed())
Expect(resp).Should(HaveReturnCode(expectedRCode))
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
})
It("should forward example.com", func() {
Expect(sut.Resolve(newRequest("example.com", A))).
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.145.123.145"),
@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
})
It("should forward home.arpa. IN DS", func() {
Expect(sut.Resolve(newRequest("something.home.arpa.", DS))).
Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))).
Should(
SatisfyAll(
// setup code doesn't care about the question
@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
})
It("should forward non special use domains", func() {
resp, err := sut.Resolve(newRequest("something.not-special.", AAAA))
resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
})

View File

@ -2,6 +2,7 @@ package resolver
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
@ -39,8 +40,9 @@ type UpstreamResolver struct {
type upstreamClient interface {
fmtURL(ip net.IP, port uint16, path string) string
callExternal(msg *dns.Msg, upstreamURL string,
protocol model.RequestProtocol) (response *dns.Msg, rtt time.Duration, err error)
callExternal(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error)
}
type dnsUpstreamClient struct {
@ -53,8 +55,6 @@ type httpUpstreamClient struct {
}
func createUpstreamClient(cfg config.Upstream) upstreamClient {
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
tlsConfig := tls.Config{
ServerName: cfg.Host,
MinVersion: tls.VersionTLS12,
@ -73,7 +73,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
ForceAttemptHTTP2: true,
},
Timeout: timeout,
},
host: cfg.Host,
}
@ -83,7 +82,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
tcpClient: &dns.Client{
TLSConfig: &tlsConfig,
Net: cfg.Net.String(),
Timeout: timeout,
SingleInflight: true,
},
}
@ -92,12 +90,10 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
return &dnsUpstreamClient{
tcpClient: &dns.Client{
Net: "tcp",
Timeout: timeout,
SingleInflight: true,
},
udpClient: &dns.Client{
Net: "udp",
Timeout: timeout,
SingleInflight: true,
},
}
@ -112,8 +108,8 @@ func (r *httpUpstreamClient) fmtURL(ip net.IP, port uint16, path string) string
return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path)
}
func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
upstreamURL string, _ model.RequestProtocol,
func (r *httpUpstreamClient) callExternal(
ctx context.Context, msg *dns.Msg, upstreamURL string, _ model.RequestProtocol,
) (*dns.Msg, time.Duration, error) {
start := time.Now()
@ -122,7 +118,7 @@ func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
return nil, 0, fmt.Errorf("can't pack message: %w", err)
}
req, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
if err != nil {
return nil, 0, fmt.Errorf("can't create the new request %w", err)
}
@ -169,16 +165,16 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
return net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
}
func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
upstreamURL string, protocol model.RequestProtocol,
func (r *dnsUpstreamClient) callExternal(
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
) (response *dns.Msg, rtt time.Duration, err error) {
if protocol == model.RequestProtocolTCP {
response, rtt, err = r.tcpClient.Exchange(msg, upstreamURL)
if err != nil {
response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
if err != nil && r.udpClient != nil {
// try UDP as fallback
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" && r.udpClient != nil {
return r.udpClient.Exchange(msg, upstreamURL)
if errors.As(err, &opErr) && opErr.Op == "dial" {
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
}
}
@ -186,18 +182,20 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
}
if r.udpClient != nil {
return r.udpClient.Exchange(msg, upstreamURL)
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
}
return r.tcpClient.Exchange(msg, upstreamURL)
return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
}
// NewUpstreamResolver creates new resolver instance
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) {
func NewUpstreamResolver(
ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool,
) (*UpstreamResolver, error) {
r := newUpstreamResolverUnchecked(upstream, bootstrap)
if verify {
_, err := r.bootstrap.UpstreamIPs(r)
_, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil {
return nil, err
}
@ -234,50 +232,41 @@ func (r UpstreamResolver) String() string {
}
// Resolve calls external resolver
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) {
ips, err := r.bootstrap.UpstreamIPs(r)
func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
ips, err := r.bootstrap.UpstreamIPs(ctx, r)
if err != nil {
return nil, err
}
var (
rtt time.Duration
resp *dns.Msg
ip net.IP
)
err = retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration())
defer cancel()
ip = ips.Current()
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
var err error
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol)
if err == nil {
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
"upstream_ip": ip.String(),
"protocol": request.Protocol,
"net": r.upstream.Net,
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
return nil
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
if err != nil {
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
}
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
resp = response
r.logResponse(request, response, ip, rtt)
return nil
},
retry.Context(ctx),
retry.Attempts(retryAttempts),
retry.DelayType(retry.FixedDelay),
retry.Delay(1*time.Millisecond),
retry.LastErrorOnly(true),
retry.RetryIf(func(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}),
retry.RetryIf(isTimeout),
retry.OnRetry(func(n uint, err error) {
r.log().WithFields(logrus.Fields{
"upstream": r.upstream.String(),
@ -289,8 +278,31 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
ips.Next()
}))
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
// Make the error more user friendly than just "context deadline exceeded"
err = fmt.Errorf("timeout (%w)", err)
}
return nil, err
}
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
}
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
r.log().WithFields(logrus.Fields{
"answer": util.AnswerToString(resp.Answer),
"return_code": dns.RcodeToString[resp.Rcode],
"upstream": r.upstream.String(),
"upstream_ip": ip.String(),
"protocol": request.Protocol,
"net": r.upstream.Net,
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
}
func isTimeout(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"crypto/tls"
"fmt"
"net/http"
@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
var (
sut *UpstreamResolver
sutConfig config.Upstream
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.Upstream{Host: "localhost"}
})
@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
upstream := mockUpstream.Start()
sut := newUpstreamResolverUnchecked(upstream, nil)
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
})
})
@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 2)
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -146,12 +153,37 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
By("3 attempts with timeout -> should return error", func() {
atomic.StoreInt32(&counter, 0)
atomic.StoreInt32(&attemptsWithTimeout, 3)
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
})
})
})
When("user request is TCP", func() {
When("TCP upstream connection fails", func() {
BeforeEach(func() {
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
DeferCleanup(mockUpstream.Close)
sutConfig = mockUpstream.Start()
})
It("should retry with UDP", func() {
req := newRequest("example.com.", A)
req.Protocol = RequestProtocolTCP
Expect(sut.Resolve(ctx, req)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveTTL(BeNumerically("==", 123)),
))
})
})
})
})
Describe("Using Dns over HTTP (DOH) upstream", func() {
@ -185,7 +217,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
When("Configured DOH resolver can resolve query", func() {
It("should return answer from DNS upstream", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "123.124.122.122"),
@ -203,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
})
@ -215,7 +247,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
@ -228,7 +260,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
})
@ -241,7 +273,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}, systemResolverBootstrap)
})
It("should return error", func() {
_, err := sut.Resolve(newRequest("example.com.", A))
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(Or(
ContainSubstring("no such host"),

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"strings"
@ -64,7 +65,7 @@ func (r *UpstreamTreeResolver) String() string {
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
}
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
group := r.upstreamGroupByClient(request)
@ -72,7 +73,7 @@ func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response,
// delegate request to group resolver
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
return r.branches[group].Resolve(request)
return r.branches[group].Resolve(ctx, request)
}
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {

View File

@ -1,6 +1,8 @@
package resolver
import (
"context"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
})
When("client specific resolvers are defined", func() {
var (
ctx context.Context
cancelFn context.CancelFunc
)
BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
upstreamDefaultCfgName: {config.Upstream{}},
"laptop": {config.Upstream{}},
@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use default if client name or IP don't match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "default"),
@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches exact", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches with wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client-*-m"),
@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name matches with range wildcard", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "client[0-9]"),
@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client IP matches", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client name (containing IP) matches", func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use exact IP match before client name match", func() {
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "192.168.178.33"),
@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
It("Should use client name match before CIDR match", func() {
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "laptop"),
@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
request.Log = logger
Expect(sut.Resolve(request)).
Expect(sut.Resolve(ctx, request)).
Should(
SatisfyAll(
SatisfyAny(

View File

@ -389,7 +389,7 @@ func createQueryResolver(
bootstrap *resolver.Bootstrap,
redisClient *redis.Client,
) (r resolver.ChainedResolver, err error) {
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap)
upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap)
if uErr != nil {
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
}
@ -398,7 +398,9 @@ func createQueryResolver(
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(
ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream,
)
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
err = multierror.Append(
@ -433,6 +435,7 @@ func createQueryResolver(
}
func createUpstreamBranches(
ctx context.Context,
cfg *config.Config,
bootstrap *resolver.Bootstrap,
) (map[string]resolver.Resolver, error) {
@ -453,11 +456,11 @@ func createUpstreamBranches(
switch cfg.Upstreams.Strategy {
case config.UpstreamStrategyParallelBest:
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyStrict:
upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
case config.UpstreamStrategyRandom:
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
}
upstreamBranches[group] = upstream
@ -643,7 +646,7 @@ func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
r := createResolverRequest(w, request)
response, err := s.queryResolver.Resolve(r)
response, err := s.queryResolver.Resolve(context.Background(), r)
if err != nil {
logger().Error("error on processing request:", err)

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"encoding/base64"
"fmt"
"html/template"
@ -149,7 +150,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
resResponse, err := s.queryResolver.Resolve(r)
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
@ -192,11 +193,11 @@ func extractIP(r *http.Request) string {
return hostPort
}
func (s *Server) Query(question string, qType dns.Type) (*model.Response, error) {
func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
r := createResolverRequest(nil, dnsRequest)
return s.queryResolver.Resolve(r)
return s.queryResolver.Resolve(ctx, r)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {

View File

@ -76,10 +76,9 @@ var _ = BeforeSuite(func() {
clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
var clientName string
client := mockClientName.Load()
if client != nil {
clientName = mockClientName.Load().(string)
if name, ok := mockClientName.Load().(string); ok {
clientName = name
}
response, err := util.NewMsgWithAnswer(
@ -118,8 +117,7 @@ var _ = BeforeSuite(func() {
youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com")
Expect(youtubeFile.Error).Should(Succeed())
// create server
sut, err = NewServer(ctx, &config.Config{
cfg := &config.Config{
CustomDNS: config.CustomDNSConfig{
CustomTTL: config.Duration(3600 * time.Second),
Mapping: config.CustomDNSMapping{
@ -160,7 +158,8 @@ var _ = BeforeSuite(func() {
BlockTTL: config.Duration(6 * time.Hour),
},
Upstreams: config.UpstreamsConfig{
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
Timeout: config.Duration(250 * time.Millisecond),
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
},
ClientLookup: config.ClientLookupConfig{
Upstream: upstreamClient,
@ -178,16 +177,19 @@ var _ = BeforeSuite(func() {
Enable: true,
Path: "/metrics",
},
})
}
// Hacky but needed to update the global since we still have code that reads it
*config.GetConfig() = *cfg
// create server
sut, err = NewServer(ctx, cfg)
Expect(err).Should(Succeed())
errChan := make(chan error, 10)
// start server
go func() {
sut.Start(ctx, errChan)
}()
go sut.Start(ctx, errChan)
DeferCleanup(sut.Stop)
Consistently(errChan, "1s").ShouldNot(Receive())
@ -697,7 +699,7 @@ var _ = Describe("Running DNS server", func() {
Describe("NewServer with strict upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{
branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyStrict,
Groups: config.UpstreamGroups{
@ -715,7 +717,7 @@ var _ = Describe("Running DNS server", func() {
Describe("NewServer with random upstream strategy", func() {
It("successfully returns upstream branches", func() {
branches, err := createUpstreamBranches(&config.Config{
branches, err := createUpstreamBranches(context.Background(), &config.Config{
Upstreams: config.UpstreamsConfig{
Strategy: config.UpstreamStrategyRandom,
Groups: config.UpstreamGroups{