mirror of https://github.com/0xERR0R/blocky.git
refactor: make use of contexts in more places
- `CacheControl.FlushCaches` - `Querier.Query` - `Resolver.Resolve` Besides all the API churn, this leads to `ParallelBestResolver`, `StrictResolver` and `UpstreamResolver` simplification: timeouts only need to be setup in one place, `UpstreamResolver`. We also benefit from using HTTP request contexts, so if the client closes the connection we stop processing on our side.
This commit is contained in:
parent
e4ebc16ccc
commit
eae99ec550
|
@ -40,11 +40,11 @@ type ListRefresher interface {
|
|||
}
|
||||
|
||||
type Querier interface {
|
||||
Query(question string, qType dns.Type) (*model.Response, error)
|
||||
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
|
||||
}
|
||||
|
||||
type CacheControl interface {
|
||||
FlushCaches()
|
||||
FlushCaches(ctx context.Context)
|
||||
}
|
||||
|
||||
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
||||
|
@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
|
|||
return ListRefresh200Response{}, nil
|
||||
}
|
||||
|
||||
func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) {
|
||||
func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
|
||||
qType := dns.Type(dns.StringToType[request.Body.Type])
|
||||
if qType == dns.Type(dns.TypeNone) {
|
||||
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
|
||||
}
|
||||
|
||||
resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType)
|
||||
resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
|
|||
}), nil
|
||||
}
|
||||
|
||||
func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
|
||||
func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
|
||||
_ CacheFlushRequestObject,
|
||||
) (CacheFlushResponseObject, error) {
|
||||
i.cacheControl.FlushCaches()
|
||||
i.cacheControl.FlushCaches(ctx)
|
||||
|
||||
return CacheFlush200Response{}, nil
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
// . "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
|
|||
return args.Get(0).(BlockingStatus)
|
||||
}
|
||||
|
||||
func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
|
||||
args := m.Called(question, qType)
|
||||
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
|
||||
args := m.Called(ctx, question, qType)
|
||||
|
||||
return args.Get(0).(*model.Response), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *CacheControlMock) FlushCaches() {
|
||||
_ = m.Called()
|
||||
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
|
||||
_ = m.Called(ctx)
|
||||
}
|
||||
|
||||
var _ = Describe("API implementation tests", func() {
|
||||
|
@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
|
|||
listRefreshMock *ListRefreshMock
|
||||
cacheControlMock *CacheControlMock
|
||||
sut *OpenAPIInterfaceImpl
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
blockingControlMock = &BlockingControlMock{}
|
||||
querierMock = &QuerierMock{}
|
||||
listRefreshMock = &ListRefreshMock{}
|
||||
|
@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
|
|||
)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
querierMock.On("Query", "google.com.", A).Return(&model.Response{
|
||||
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
|
||||
Res: queryResponse,
|
||||
Reason: "reason",
|
||||
}, nil)
|
||||
|
||||
resp, err := sut.Query(context.Background(), QueryRequestObject{
|
||||
resp, err := sut.Query(ctx, QueryRequestObject{
|
||||
Body: &ApiQueryRequest{
|
||||
Query: "google.com", Type: "A",
|
||||
},
|
||||
|
@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
})
|
||||
|
||||
It("should return 400 on wrong parameter", func() {
|
||||
resp, err := sut.Query(context.Background(), QueryRequestObject{
|
||||
resp, err := sut.Query(ctx, QueryRequestObject{
|
||||
Body: &ApiQueryRequest{
|
||||
Query: "google.com",
|
||||
Type: "WRONGTYPE",
|
||||
|
@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
It("should return 200 on success", func() {
|
||||
listRefreshMock.On("RefreshLists").Return(nil)
|
||||
|
||||
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
|
||||
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp200 ListRefresh200Response
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||
|
@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
It("should return 500 on failure", func() {
|
||||
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))
|
||||
|
||||
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
|
||||
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp500 ListRefresh500TextResponse
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp500))
|
||||
|
@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
duration := "3s"
|
||||
grroups := "gr1,gr2"
|
||||
|
||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
|
||||
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
|
||||
Params: DisableBlockingParams{
|
||||
Duration: &duration,
|
||||
Groups: &grroups,
|
||||
|
@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
|
||||
It("should return 400 on failure", func() {
|
||||
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
|
||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
|
||||
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp400 DisableBlocking400TextResponse
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp400))
|
||||
|
@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
|
||||
It("should return 400 on wrong duration parameter", func() {
|
||||
wrongDuration := "4sds"
|
||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
|
||||
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
|
||||
Params: DisableBlockingParams{
|
||||
Duration: &wrongDuration,
|
||||
},
|
||||
|
@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
It("should return 200 on success", func() {
|
||||
blockingControlMock.On("EnableBlocking").Return()
|
||||
|
||||
resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
|
||||
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp200 EnableBlocking200Response
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||
|
@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
AutoEnableInSec: 47,
|
||||
})
|
||||
|
||||
resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
|
||||
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp200 BlockingStatus200JSONResponse
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||
|
@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
|
|||
Describe("Cache API", func() {
|
||||
When("Cache flush is called", func() {
|
||||
It("should return 200 on success", func() {
|
||||
cacheControlMock.On("FlushCaches").Return()
|
||||
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
|
||||
cacheControlMock.On("FlushCaches", ctx).Return()
|
||||
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
|
||||
Expect(err).Should(Succeed())
|
||||
var resp200 CacheFlush200Response
|
||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||
|
|
|
@ -37,7 +37,7 @@ type Options struct {
|
|||
// OnExpirationCallback will be called just before an element gets expired and will
|
||||
// be removed from cache. This function can return new value and TTL to leave the
|
||||
// element in the cache or nil to remove it
|
||||
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
|
||||
type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)
|
||||
|
||||
// OnCacheHitCallback will be called on cache get if entry was found
|
||||
type OnCacheHitCallback func(key string)
|
||||
|
@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
|
|||
l, _ := lru.New(defaultSize)
|
||||
c := &ExpiringLRUCache[T]{
|
||||
cleanUpInterval: defaultCleanUpInterval,
|
||||
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
|
||||
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
|
||||
return nil, 0
|
||||
},
|
||||
onCacheHit: func(key string) {},
|
||||
|
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
|
|||
var keysToDelete []string
|
||||
|
||||
for _, key := range expiredKeys {
|
||||
newVal, newTTL := e.preExpirationFn(key)
|
||||
newVal, newTTL := e.preExpirationFn(context.Background(), key)
|
||||
if newVal != nil {
|
||||
e.Put(key, newVal, newTTL)
|
||||
} else {
|
||||
|
|
|
@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
Describe("preExpiration function", func() {
|
||||
When("function is defined", func() {
|
||||
It("should update the value and TTL if function returns values", func() {
|
||||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||
v2 := "v2"
|
||||
|
||||
return &v2, time.Second
|
||||
|
@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
|
||||
It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
|
||||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||
v2 := "val2"
|
||||
|
||||
return &v2, time.Second
|
||||
|
@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
|
|||
})
|
||||
|
||||
It("should delete the key if function returns nil", func() {
|
||||
fn := func(key string) (val *string, ttl time.Duration) {
|
||||
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||
return nil, 0
|
||||
}
|
||||
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||
|
|
|
@ -25,11 +25,11 @@ type cacheValue[T any] struct {
|
|||
type OnEntryReloadedCallback func(key string)
|
||||
|
||||
// ReloadEntryFn reloads a prefetched entry by key
|
||||
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
|
||||
type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)
|
||||
|
||||
type PrefetchingOptions[T any] struct {
|
||||
Options
|
||||
ReloadFn func(cacheKey string) (*T, time.Duration)
|
||||
ReloadFn ReloadEntryFn[T]
|
||||
PrefetchThreshold int
|
||||
PrefetchExpires time.Duration
|
||||
PrefetchMaxItemsCount int
|
||||
|
@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
|
|||
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
|
||||
}
|
||||
|
||||
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
|
||||
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
|
||||
ctx context.Context, cacheKey string,
|
||||
) (val *cacheValue[T], ttl time.Duration) {
|
||||
if e.shouldPrefetch(cacheKey) {
|
||||
loadedVal, ttl := e.reloadFn(cacheKey)
|
||||
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
|
||||
if loadedVal != nil {
|
||||
if e.onPrefetchEntryReloaded != nil {
|
||||
e.onPrefetchEntryReloaded(cacheKey)
|
||||
|
|
|
@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
|
@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
|
@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
Options: Options{
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
},
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
|
@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
|||
},
|
||||
PrefetchThreshold: 2,
|
||||
PrefetchExpires: 100 * time.Millisecond,
|
||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
||||
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||
v := "v2"
|
||||
|
||||
return &v, 50 * time.Millisecond
|
||||
|
|
|
@ -154,16 +154,16 @@ func NewBlockingResolver(ctx context.Context,
|
|||
|
||||
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
|
||||
CleanupInterval: defaultBlockingCleanUpInterval,
|
||||
}, func(key string) (val *[]net.IP, ttl time.Duration) {
|
||||
return res.queryForFQIdentifierIPs(key)
|
||||
}, func(ctx context.Context, key string) (val *[]net.IP, ttl time.Duration) {
|
||||
return res.queryForFQIdentifierIPs(ctx, key)
|
||||
})
|
||||
|
||||
if res.redisClient != nil {
|
||||
setupRedisEnabledSubscriber(ctx, res)
|
||||
go res.redisSubscriber(ctx)
|
||||
}
|
||||
|
||||
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
|
||||
go res.initFQDNIPCache()
|
||||
go res.initFQDNIPCache(ctx)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -173,29 +173,27 @@ func NewBlockingResolver(ctx context.Context,
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case em := <-c.redisClient.EnabledChannel:
|
||||
if em != nil {
|
||||
c.log().Debug("Received state from redis: ", em)
|
||||
func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case em := <-r.redisClient.EnabledChannel:
|
||||
if em != nil {
|
||||
r.log().Debug("Received state from redis: ", em)
|
||||
|
||||
if em.State {
|
||||
c.internalEnableBlocking()
|
||||
} else {
|
||||
err := c.internalDisableBlocking(em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
c.log().Warn("Blocking couldn't be disabled:", err)
|
||||
}
|
||||
if em.State {
|
||||
r.internalEnableBlocking()
|
||||
} else {
|
||||
err := r.internalDisableBlocking(em.Duration, em.Groups)
|
||||
if err != nil {
|
||||
r.log().Warn("Blocking couldn't be disabled:", err)
|
||||
}
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshLists triggers the refresh of all black and white lists in the cache
|
||||
|
@ -358,7 +356,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
|
|||
return false
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||
func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []string,
|
||||
request *model.Request, logger *logrus.Entry,
|
||||
) (bool, *model.Response, error) {
|
||||
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
||||
|
@ -371,7 +369,7 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
|||
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
|
||||
logger.WithField("groups", groups).Debugf("domain is whitelisted")
|
||||
|
||||
resp, err := r.next.Resolve(request)
|
||||
resp, err := r.next.Resolve(ctx, request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
@ -393,18 +391,18 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
|||
}
|
||||
|
||||
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
||||
func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "blacklist_resolver")
|
||||
groupsToCheck := r.groupsToCheckForClient(request)
|
||||
|
||||
if len(groupsToCheck) > 0 {
|
||||
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
||||
handled, resp, err := r.handleBlacklist(ctx, groupsToCheck, request, logger)
|
||||
if handled {
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
respFromNext, err := r.next.Resolve(request)
|
||||
respFromNext, err := r.next.Resolve(ctx, request)
|
||||
|
||||
if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
|
||||
for _, rr := range respFromNext.Res.Answer {
|
||||
|
@ -574,7 +572,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP, time.Duration) {
|
||||
func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
|
||||
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
|
||||
|
||||
var result []net.IP
|
||||
|
@ -582,7 +580,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
|
|||
var ttl time.Duration
|
||||
|
||||
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
resp, err := r.next.Resolve(&model.Request{
|
||||
resp, err := r.next.Resolve(ctx, &model.Request{
|
||||
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
|
||||
Log: prefixedLog,
|
||||
})
|
||||
|
@ -606,12 +604,12 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
|
|||
return &result, ttl
|
||||
}
|
||||
|
||||
func (r *BlockingResolver) initFQDNIPCache() {
|
||||
func (r *BlockingResolver) initFQDNIPCache(ctx context.Context) {
|
||||
identifiers := maps.Keys(r.clientGroupsBlock)
|
||||
|
||||
for _, identifier := range identifiers {
|
||||
if isFQDN(identifier) {
|
||||
iPs, ttl := r.queryForFQIdentifierIPs(identifier)
|
||||
iPs, ttl := r.queryForFQIdentifierIPs(ctx, identifier)
|
||||
r.fqdnIPCache.Put(identifier, iPs, ttl)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
}
|
||||
Bus().Publish(ApplicationStarted, "")
|
||||
Eventually(func(g Gomega) {
|
||||
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
|
||||
g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
|
||||
Should(And(
|
||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||
HaveTTL(BeNumerically("==", 60)),
|
||||
|
@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("Domain is on the black list", func() {
|
||||
It("should block request", func() {
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("client name is defined in client groups block", func() {
|
||||
It("should block the A query if domain is on the black list (single)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the A query if domain is on the black list (multipart 1)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the A query if domain is on the black list (multipart 2)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the A query if domain is on the black list (merged)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||
|
@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the AAAA query if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", AAAA, "::"),
|
||||
|
@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the HTTPS query if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
|
||||
Should(HaveReturnCode(dns.RcodeNameError))
|
||||
})
|
||||
It("should block the MX query if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
|
||||
Should(HaveReturnCode(dns.RcodeNameError))
|
||||
})
|
||||
})
|
||||
|
||||
When("Client ip is defined in client groups block", func() {
|
||||
It("should block the query if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
||||
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
||||
It("should block the query for 10.43.8.64 if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block the query for 10.43.8.79 if domain is on the black list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Client has multiple names and for each name a client group block definition exists", func() {
|
||||
It("should block query if domain is in one group", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
It("should block query if domain is in another group too", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||
|
@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("Client name matches wildcard", func() {
|
||||
It("should block query if domain is in one group", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Default group is defined", func() {
|
||||
It("should block domains from default group for each client", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should return NXDOMAIN if query is blocked", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should return answer with specified TTL", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should return custom IP with specified TTL", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
||||
|
@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should return ipv4 address for A query if query is blocked", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
||||
|
@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should return ipv6 address for AAAA query if query is blocked", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||
|
@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
It("should use fallback for ipv6 and return zero ip", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", AAAA, "::"),
|
||||
|
@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
})
|
||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
||||
|
@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
)
|
||||
})
|
||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", AAAA, "::"),
|
||||
|
@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
|
||||
})
|
||||
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
||||
|
@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("Should not be blocked", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
It("should block everything else except domains on the white list with default group", func() {
|
||||
By("querying domain on the whitelist", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("querying another domain, which is not on the whitelist", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("google.com.", A, "0.0.0.0"),
|
||||
|
@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
It("should block everything else except domains on the white list "+
|
||||
"if multiple white list only groups are defined", func() {
|
||||
By("querying domain on the whitelist", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("querying another domain, which is not on the whitelist", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||
|
@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
It("should block everything else except domains on the white list "+
|
||||
"if multiple white list only groups are defined", func() {
|
||||
By("querying domain on the whitelist group 1", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("querying another domain, which is in the whitelist group 1", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
})
|
||||
It("should not block if DNS answer contains IP from the white list", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
||||
|
@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("domain is not on the black list", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("should delegate to next resolver", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("Disable blocking is called", func() {
|
||||
It("no query should be blocked", func() {
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
By("perform the same query again (defaultGroup)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
By("perform the same query again (group1)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
By("perform the same query again (defaultGroup)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("Disable blocking for all groups is called with a duration parameter", func() {
|
||||
It("No query should be blocked only for passed amount of time", func() {
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
// wait 1 sec
|
||||
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
When("Disable blocking for one group is called with a duration parameter", func() {
|
||||
It("No query should be blocked only for passed amount of time", func() {
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
))
|
||||
})
|
||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
// wait 1 sec
|
||||
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||
|
@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||
|
|
|
@ -106,14 +106,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
|
|||
return b, nil
|
||||
}
|
||||
|
||||
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
||||
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
|
||||
hostname := r.upstream.Host
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
|
||||
return newIPSet([]net.IP{ip}), nil
|
||||
}
|
||||
|
||||
ips, err := b.resolveUpstream(r, hostname)
|
||||
ips, err := b.resolveUpstream(ctx, r, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -121,10 +121,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
|||
return newIPSet(ips), nil
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
||||
func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
|
||||
// Use system resolver if no bootstrap is configured
|
||||
if b.resolver == nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), b.timeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, b.timeout)
|
||||
defer cancel()
|
||||
|
||||
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
|
||||
|
@ -135,7 +135,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
return b.resolve(host, b.connectIPVersion.QTypes())
|
||||
return b.resolve(ctx, host, b.connectIPVersion.QTypes())
|
||||
}
|
||||
|
||||
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
|
||||
|
@ -175,7 +175,7 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
|
|||
}
|
||||
|
||||
// Resolve the host with the bootstrap DNS
|
||||
ips, err := b.resolve(host, qTypes)
|
||||
ips, err := b.resolve(ctx, host, qTypes)
|
||||
if err != nil {
|
||||
logger.Errorf("resolve error: %s", err)
|
||||
|
||||
|
@ -192,11 +192,11 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
|
|||
return b.dialer.DialContext(ctx, network, addrWithIP)
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||
func (b *Bootstrap) resolve(ctx context.Context, hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||
ips = make([]net.IP, 0, len(qTypes))
|
||||
|
||||
for _, qType := range qTypes {
|
||||
qIPs, qErr := b.resolveType(hostname, qType)
|
||||
qIPs, qErr := b.resolveType(ctx, hostname, qType)
|
||||
if qErr != nil {
|
||||
err = multierror.Append(err, qErr)
|
||||
|
||||
|
@ -213,7 +213,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
|
|||
return
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
|
||||
func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.Type) (ips []net.IP, err error) {
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return []net.IP{ip}, nil
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP,
|
|||
Log: b.log,
|
||||
}
|
||||
|
||||
rsp, err := b.resolver.Resolve(&req)
|
||||
rsp, err := b.resolver.Resolve(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
},
|
||||
}
|
||||
|
||||
_, err := sut.resolveUpstream(nil, "example.com")
|
||||
_, err := sut.resolveUpstream(ctx, nil, "example.com")
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(usedSystemResolver).Should(Receive(BeTrue()))
|
||||
})
|
||||
|
@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
When("called from bootstrap.upstream", func() {
|
||||
It("uses hardcoded IPs", func() {
|
||||
ips, err := sut.resolveUpstream(bootstrapUpstream, "host")
|
||||
ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs))
|
||||
|
@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
When("hostname is an IP", func() {
|
||||
It("returns immediately", func() {
|
||||
ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes())
|
||||
ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes())
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(ips).Should(ContainElement(net.IPv4zero))
|
||||
|
@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
||||
|
||||
ips, err := sut.resolve("localhost", []dns.Type{AAAA})
|
||||
ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA})
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(ips).Should(HaveLen(1))
|
||||
|
@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
|
||||
|
||||
ips, err := sut.resolve("localhost", []dns.Type{A})
|
||||
ips, err := sut.resolve(ctx, "localhost", []dns.Type{A})
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
||||
|
@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
||||
|
||||
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A})
|
||||
ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A})
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
||||
|
@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
r := newUpstreamResolverUnchecked(upstream, sut)
|
||||
|
||||
rsp, err := r.Resolve(mainReq)
|
||||
rsp, err := r.Resolve(ctx, mainReq)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1))
|
||||
Expect(rsp.Res.Question[0].Name).Should(Equal("example.com."))
|
||||
|
@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
// implicit expectation of 0 bootstrapUpstream.Resolve calls
|
||||
|
||||
_, err = t.DialContext(context.Background(), "ip", "!bad-addr!")
|
||||
_, err = t.DialContext(ctx, "ip", "!bad-addr!")
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
})
|
||||
|
||||
|
@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
t := sut.NewHTTPTransport()
|
||||
|
||||
_, err = t.DialContext(context.Background(), "ip", "abc:123")
|
||||
_, err = t.DialContext(ctx, "ip", "abc:123")
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
||||
|
@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
t := sut.NewHTTPTransport()
|
||||
|
||||
_, err = t.DialContext(context.Background(), "ip", "abc:123")
|
||||
_, err = t.DialContext(ctx, "ip", "abc:123")
|
||||
|
||||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
||||
|
@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
|
||||
Describe("resolve", func() {
|
||||
AfterEach(func() {
|
||||
_, err := sut.resolveUpstream(nil, "example.com")
|
||||
_, err := sut.resolveUpstream(ctx, nil, "example.com")
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
AfterEach(func() {
|
||||
t := sut.NewHTTPTransport()
|
||||
|
||||
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0")
|
||||
conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0")
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(conn).Should(Equal(aMockConn))
|
||||
|
||||
|
@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
})
|
||||
|
||||
It("uses both", func() {
|
||||
_, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)})
|
||||
_, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)})
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ func newCachingResolver(ctx context.Context,
|
|||
configureCaches(ctx, c, &cfg)
|
||||
|
||||
if c.redisClient != nil {
|
||||
setupRedisCacheSubscriber(ctx, c)
|
||||
go c.redisSubscriber(ctx)
|
||||
c.redisClient.GetRedisCache()
|
||||
}
|
||||
|
||||
|
@ -105,14 +105,14 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
|
|||
}
|
||||
}
|
||||
|
||||
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Duration) {
|
||||
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
|
||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||
logger := r.log()
|
||||
|
||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||
|
||||
req := newRequest(dns.Fqdn(domainName), qType, logger)
|
||||
response, err := r.next.Resolve(req)
|
||||
response, err := r.next.Resolve(ctx, req)
|
||||
|
||||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
|
@ -132,22 +132,20 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Durat
|
|||
return nil, 0
|
||||
}
|
||||
|
||||
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case rc := <-c.redisClient.CacheChannel:
|
||||
if rc != nil {
|
||||
c.log().Debug("Received key from redis: ", rc.Key)
|
||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
||||
c.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case rc := <-r.redisClient.CacheChannel:
|
||||
if rc != nil {
|
||||
r.log().Debug("Received key from redis: ", rc.Key)
|
||||
ttl := r.adjustTTLs(rc.Response.Res.Answer)
|
||||
r.putInCache(rc.Key, rc.Response, ttl, false)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// LogConfig implements `config.Configurable`.
|
||||
|
@ -159,13 +157,13 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
|||
|
||||
// Resolve checks if the current query should use the cache and if the result is already in
|
||||
// the cache and returns it or delegates to the next resolver
|
||||
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
|
||||
logger := log.WithPrefix(request.Log, "caching_resolver")
|
||||
|
||||
if !r.IsEnabled() || !isRequestCacheable(request) {
|
||||
logger.Debug("skip cache")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
for _, question := range request.Req.Question {
|
||||
|
@ -191,7 +189,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
|||
}
|
||||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
|
||||
response, err = r.next.Resolve(request)
|
||||
response, err = r.next.Resolve(ctx, request)
|
||||
|
||||
if err == nil {
|
||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||
|
@ -319,7 +317,7 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
|
|||
}
|
||||
}
|
||||
|
||||
func (r *CachingResolver) FlushCaches() {
|
||||
func (r *CachingResolver) FlushCaches(context.Context) {
|
||||
r.log().Debug("flush caches")
|
||||
r.resultCache.Clear()
|
||||
}
|
||||
|
|
|
@ -114,7 +114,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})).Should(Succeed())
|
||||
|
||||
// first request
|
||||
_, _ = sut.Resolve(newRequest("example.com.", A))
|
||||
_, _ = sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
|
||||
// Domain is not prefetched
|
||||
Expect(domainPrefetched).ShouldNot(Receive())
|
||||
|
@ -124,7 +124,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
// now query again > threshold
|
||||
for i := 0; i < prefetchThreshold+1; i++ {
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(Succeed())
|
||||
}
|
||||
|
||||
|
@ -132,7 +132,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
|
||||
|
||||
// and it should hit from prefetch cache
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -156,7 +156,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("should cache response and use response's TTL for multiple records", func() {
|
||||
By("first request", func() {
|
||||
result, err := sut.Resolve(newRequest("example.com.", A))
|
||||
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(result).
|
||||
Should(
|
||||
|
@ -176,7 +176,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
By("second request", func() {
|
||||
Eventually(func(g Gomega) {
|
||||
result, err := sut.Resolve(newRequest("example.com.", A))
|
||||
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
g.Expect(err).Should(Succeed())
|
||||
g.Expect(result).
|
||||
Should(
|
||||
|
@ -218,7 +218,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
|
||||
totalCacheCount <- d
|
||||
})
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -239,7 +239,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
domain <- true
|
||||
})
|
||||
|
||||
g.Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -264,7 +264,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("should cache response and use min caching time as TTL", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -276,7 +276,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", A)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -299,7 +301,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("should cache response and use min caching time as TTL", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -310,7 +312,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -344,7 +348,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("Shouldn't cache any responses", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -355,7 +359,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -378,7 +384,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -389,7 +395,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -417,7 +425,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("should cache response and return 0 TTL if entry is expired", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -430,7 +438,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)).
|
||||
Eventually(sut.Resolve, "2s").
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", A)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
@ -461,7 +471,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeNameError),
|
||||
|
@ -472,7 +482,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
HaveReason("CACHED NEGATIVE"),
|
||||
|
@ -495,7 +507,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("response shouldn't be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeNameError),
|
||||
|
@ -506,7 +518,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReason(""),
|
||||
|
@ -529,7 +543,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
It("response should be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -540,7 +554,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("example.com.", AAAA)).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
HaveReason("CACHED"),
|
||||
|
@ -563,7 +579,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("Should be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", MX))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", MX))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -575,7 +591,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("google.de.", MX)).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
HaveReason("CACHED"),
|
||||
|
@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("Should not be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -611,7 +629,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("Should not be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -645,7 +663,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -671,7 +689,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
It("Should not be cached", func() {
|
||||
By("first request", func() {
|
||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
|
@ -688,7 +706,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
By("second request", func() {
|
||||
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(newRequest("google.de.", A)).
|
||||
Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
HaveReason("CACHED"),
|
||||
|
@ -750,7 +770,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
})
|
||||
|
||||
It("put in redis", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(HaveResponseType(ResponseTypeRESOLVED))
|
||||
|
||||
Eventually(func() []string {
|
||||
|
@ -772,7 +792,9 @@ var _ = Describe("CachingResolver", func() {
|
|||
}
|
||||
redisClient.CacheChannel <- redisMockMsg
|
||||
|
||||
Eventually(sut.Resolve).WithArguments(request).
|
||||
Eventually(sut.Resolve).
|
||||
WithContext(ctx).
|
||||
WithArguments(request).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeCACHED),
|
||||
|
|
|
@ -32,7 +32,7 @@ func NewClientNamesResolver(ctx context.Context,
|
|||
) (cr *ClientNamesResolver, err error) {
|
||||
var r Resolver
|
||||
if !cfg.Upstream.IsDefault() {
|
||||
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
|
||||
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -59,17 +59,17 @@ func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
|
|||
}
|
||||
|
||||
// Resolve tries to resolve the client name from the ip address
|
||||
func (r *ClientNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
clientNames := r.getClientNames(request)
|
||||
func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
clientNames := r.getClientNames(ctx, request)
|
||||
|
||||
request.ClientNames = clientNames
|
||||
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
// returns names of client
|
||||
func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
||||
func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model.Request) []string {
|
||||
if request.RequestClientID != "" {
|
||||
return []string{request.RequestClientID}
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
|||
return cpy
|
||||
}
|
||||
|
||||
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
|
||||
names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
|
||||
|
||||
r.cache.Put(ip.String(), &names, time.Hour)
|
||||
|
||||
|
@ -111,7 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
|
|||
}
|
||||
|
||||
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
||||
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
||||
func (r *ClientNamesResolver) resolveClientNames(
|
||||
ctx context.Context, ip net.IP, logger *logrus.Entry,
|
||||
) (result []string) {
|
||||
// try client mapping first
|
||||
result = r.getNameFromIPMapping(ip, result)
|
||||
if len(result) > 0 {
|
||||
|
@ -124,7 +126,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
|||
|
||||
reverse, _ := dns.ReverseAddr(ip.String())
|
||||
|
||||
resp, err := r.externalResolver.Resolve(&model.Request{
|
||||
resp, err := r.externalResolver.Resolve(ctx, &model.Request{
|
||||
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
||||
Log: logger,
|
||||
})
|
||||
|
|
|
@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
sut *ClientNamesResolver
|
||||
sutConfig config.ClientLookupConfig
|
||||
m *mockResolver
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
JustBeforeEach(func() {
|
||||
var err error
|
||||
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
|
@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should use clientID if set", func() {
|
||||
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
It("should use IP as fallback if clientID not set", func() {
|
||||
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should resolve defined name with ipv4 address", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should resolve defined name with ipv6 address", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
It("should resolve multiple names defined names", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
It("should resolve client name", func() {
|
||||
By("first request", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
By("second request", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
By("third request", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should resolve all client names", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should resolve client name", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should resolve the client name depending to defined order", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
|
||||
It("should use fallback for client name", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
It("should use fallback for client name", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
It("should resolve no names", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
|||
})
|
||||
It("should use fallback for client name", func() {
|
||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -23,14 +24,14 @@ type ConditionalUpstreamResolver struct {
|
|||
|
||||
// NewConditionalUpstreamResolver returns new resolver instance
|
||||
func NewConditionalUpstreamResolver(
|
||||
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (*ConditionalUpstreamResolver, error) {
|
||||
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
||||
|
||||
for domain, upstreams := range cfg.Mapping.Upstreams {
|
||||
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
|
||||
|
||||
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
|
||||
r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -48,7 +49,9 @@ func NewConditionalUpstreamResolver(
|
|||
return &r, nil
|
||||
}
|
||||
|
||||
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
|
||||
func (r *ConditionalUpstreamResolver) processRequest(
|
||||
ctx context.Context, request *model.Request,
|
||||
) (bool, *model.Response, error) {
|
||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||
domain := domainFromQuestion
|
||||
|
||||
|
@ -56,7 +59,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
|||
// try with domain with and without sub-domains
|
||||
for len(domain) > 0 {
|
||||
if resolver, found := r.mapping[domain]; found {
|
||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
@ -68,7 +71,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
|||
}
|
||||
}
|
||||
} else if resolver, found := r.mapping["."]; found {
|
||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
@ -77,11 +80,11 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
|||
}
|
||||
|
||||
// Resolve uses the conditional resolver to resolve the query
|
||||
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "conditional_resolver")
|
||||
|
||||
if len(r.mapping) > 0 {
|
||||
resolved, resp, err := r.processRequest(request)
|
||||
resolved, resp, err := r.processRequest(ctx, request)
|
||||
if resolved {
|
||||
return resp, err
|
||||
}
|
||||
|
@ -89,17 +92,17 @@ func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Re
|
|||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
func (r *ConditionalUpstreamResolver) internalResolve(reso Resolver, doFQ, do string,
|
||||
func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso Resolver, doFQ, do string,
|
||||
req *model.Request,
|
||||
) (*model.Response, error) {
|
||||
// internal request resolution
|
||||
logger := log.WithPrefix(req.Log, "conditional_resolver")
|
||||
|
||||
req.Req.Question[0].Name = dns.Fqdn(doFQ)
|
||||
response, err := reso.Resolve(req)
|
||||
response, err := reso.Resolve(ctx, req)
|
||||
|
||||
if err == nil {
|
||||
response.Reason = "CONDITIONAL"
|
||||
|
|
|
@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
var (
|
||||
sut *ConditionalUpstreamResolver
|
||||
m *mockResolver
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
|
||||
|
||||
|
@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
DeferCleanup(refuseTestUpstream.Close)
|
||||
|
||||
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
|
||||
sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
|
||||
Mapping: config.ConditionalUpstreamMapping{
|
||||
Upstreams: map[string][]config.Upstream{
|
||||
"fritz.box": {fbTestUpstream.Start()},
|
||||
|
@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
Describe("Resolve conditional DNS queries via defined DNS server", func() {
|
||||
When("conditional resolver returns error code", func() {
|
||||
It("Should be returned without changes", func() {
|
||||
Expect(sut.Resolve(newRequest("refused.domain.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
When("Query is exact equal defined condition in mapping", func() {
|
||||
Context("first mapping entry", func() {
|
||||
It("Should resolve the IP of conditional DNS", func() {
|
||||
Expect(sut.Resolve(newRequest("fritz.box.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("fritz.box.", A, "123.124.122.122"),
|
||||
|
@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
Context("last mapping entry", func() {
|
||||
It("Should resolve the IP of conditional DNS", func() {
|
||||
Expect(sut.Resolve(newRequest("other.box.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("other.box.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("other.box.", A, "192.192.192.192"),
|
||||
|
@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
When("Query is a subdomain of defined condition in mapping", func() {
|
||||
It("Should resolve the IP of subdomain", func() {
|
||||
Expect(sut.Resolve(newRequest("test.fritz.box.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
|
||||
|
@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
})
|
||||
When("Query is not fqdn and . condition is defined in mapping", func() {
|
||||
It("Should resolve the IP of .", func() {
|
||||
Expect(sut.Resolve(newRequest("test.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("test.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("test.", A, "168.168.168.168"),
|
||||
|
@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
Describe("Delegation to next resolver", func() {
|
||||
When("Query doesn't match defined mapping", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
Expect(sut.Resolve(newRequest("google.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("google.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
|||
|
||||
When("upstream is invalid", func() {
|
||||
It("errors during construction", func() {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
|
||||
r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
|
||||
Mapping: config.ConditionalUpstreamMapping{
|
||||
Upstreams: map[string][]config.Upstream{
|
||||
".": {config.Upstream{Host: "example.com"}},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
|
@ -123,7 +124,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
|||
}
|
||||
|
||||
// Resolve uses internal mapping to resolve the query
|
||||
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
|
||||
|
||||
reverseResp := r.handleReverseDNS(request)
|
||||
|
@ -140,5 +141,5 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
sut *CustomDNSResolver
|
||||
m *mockResolver
|
||||
cfg config.CustomDNSConfig
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
cfg = config.CustomDNSConfig{
|
||||
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
|
||||
"custom.domain": {net.ParseIP("192.168.143.123")},
|
||||
|
@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
Context("filterUnmappedTypes is true", func() {
|
||||
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
|
||||
It("defined ip4 query should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
||||
|
@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
It("TXT query for defined mapping should return NOERROR and empty result", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
It("ip6 query should return NOERROR and empty result", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
Context("filterUnmappedTypes is false", func() {
|
||||
BeforeEach(func() { cfg.FilterUnmappedTypes = false })
|
||||
It("defined ip4 query should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
||||
|
@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||
})
|
||||
It("TXT query for defined mapping should be delegated to next resolver", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
It("ip6 query should return NOERROR and empty result", func() {
|
||||
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
When("Ip 6 mapping is defined for custom domain ", func() {
|
||||
It("ip6 query should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||
|
@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
When("Multiple IPs are defined for custom domain ", func() {
|
||||
It("all IPs for the current type should be returned", func() {
|
||||
By("IPv6 query", func() {
|
||||
Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||
|
@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
|
||||
By("IPv4 query", func() {
|
||||
Expect(sut.Resolve(newRequest("multiple.ips.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
|
@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
When("Reverse DNS request is received", func() {
|
||||
It("should resolve the defined domain name", func() {
|
||||
By("ipv4", func() {
|
||||
Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))).
|
||||
Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
WithTransform(ToAnswer, SatisfyAll(
|
||||
|
@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
|
||||
By("ipv6", func() {
|
||||
Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
PTR))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
})
|
||||
When("Domain mapping is defined", func() {
|
||||
It("subdomain must also match", func() {
|
||||
Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
|
||||
|
@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
|||
Describe("Delegating to next resolver", func() {
|
||||
When("no mapping for domain exist", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
|
@ -46,7 +47,7 @@ func NewECSResolver(cfg config.ECS) ChainedResolver {
|
|||
|
||||
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver
|
||||
// and sets the client IP from the EDNS0 option to the request if this option is enabled
|
||||
func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
if r.cfg.IsEnabled() {
|
||||
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
|
||||
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
|
||||
|
@ -67,7 +68,7 @@ func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
// setSubnet appends the subnet information to the request as EDNS0 option
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -25,6 +26,9 @@ var _ = Describe("EcsResolver", func() {
|
|||
err error
|
||||
origIP net.IP
|
||||
ecsIP net.IP
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -34,6 +38,9 @@ var _ = Describe("EcsResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
err = defaults.Set(&sutConfig)
|
||||
Expect(err).ShouldNot(HaveOccurred())
|
||||
|
||||
|
@ -86,13 +93,13 @@ var _ = Describe("EcsResolver", func() {
|
|||
|
||||
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
|
||||
|
||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
||||
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||
Expect(req.ClientIP).Should(Equal(ecsIP))
|
||||
|
||||
return respondWith(mockAnswer), nil
|
||||
}
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -107,13 +114,13 @@ var _ = Describe("EcsResolver", func() {
|
|||
|
||||
addEcsOption(request.Req, ecsIP, 24)
|
||||
|
||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
||||
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||
Expect(req.ClientIP).Should(Equal(origIP))
|
||||
|
||||
return respondWith(mockAnswer), nil
|
||||
}
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -134,14 +141,14 @@ var _ = Describe("EcsResolver", func() {
|
|||
request := newRequest("example.com.", A)
|
||||
request.ClientIP = origIP
|
||||
|
||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
||||
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||
Expect(req.ClientIP).Should(Equal(origIP))
|
||||
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
||||
|
||||
return respondWith(mockAnswer), nil
|
||||
}
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -154,13 +161,13 @@ var _ = Describe("EcsResolver", func() {
|
|||
request := newRequest("example.com.", AAAA)
|
||||
request.ClientIP = net.ParseIP("2001:db8::68")
|
||||
|
||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
||||
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
||||
|
||||
return respondWith(mockAnswer), nil
|
||||
}
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -25,12 +27,12 @@ func NewEDEResolver(cfg config.EDE) *EDEResolver {
|
|||
|
||||
// Resolve adds the reason as EDNS0 option to the response of the next resolver
|
||||
// if it is enabled in the configuration
|
||||
func (r *EDEResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
if !r.cfg.Enable {
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
resp, err := r.next.Resolve(request)
|
||||
resp, err := r.next.Resolve(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
|
@ -24,6 +25,9 @@ var _ = Describe("EdeResolver", func() {
|
|||
sutConfig config.EDE
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -33,6 +37,9 @@ var _ = Describe("EdeResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
||||
|
@ -57,7 +64,7 @@ var _ = Describe("EdeResolver", func() {
|
|||
}
|
||||
})
|
||||
It("shouldn't add EDE information", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -89,7 +96,7 @@ var _ = Describe("EdeResolver", func() {
|
|||
}
|
||||
|
||||
It("should add EDE information", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -115,7 +122,7 @@ var _ = Describe("EdeResolver", func() {
|
|||
})
|
||||
|
||||
It("shouldn't add EDE information", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -137,7 +144,7 @@ var _ = Describe("EdeResolver", func() {
|
|||
})
|
||||
|
||||
It("should return it", func() {
|
||||
resp, err := sut.Resolve(newRequest("example.com", A))
|
||||
resp, err := sut.Resolve(ctx, newRequest("example.com", A))
|
||||
Expect(resp).To(BeNil())
|
||||
Expect(err).To(Equal(resolveErr))
|
||||
})
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/miekg/dns"
|
||||
|
@ -21,7 +23,7 @@ func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
qType := request.Req.Question[0].Qtype
|
||||
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
|
||||
response := new(dns.Msg)
|
||||
|
@ -30,5 +32,5 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
|
||||
}
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() {
|
|||
sutConfig config.FilteringConfig
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
||||
|
@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() {
|
|||
}
|
||||
})
|
||||
It("Should delegate to next resolver if request query has other type", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() {
|
|||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
It("Should return empty answer for defined query type", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() {
|
|||
sutConfig = config.FilteringConfig{}
|
||||
})
|
||||
It("Should return empty answer without error", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -22,7 +23,7 @@ func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
if r.IsEnabled() {
|
||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||
if !strings.Contains(domainFromQuestion, ".") {
|
||||
|
@ -33,5 +34,5 @@ func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
}
|
||||
}
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
sutConfig config.FQDNOnly
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
mockAnswer = new(dns.Msg)
|
||||
})
|
||||
|
||||
|
@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
sutConfig = config.FQDNOnly{Enable: true}
|
||||
})
|
||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
It("Should return NXDOMAIN if request query is not fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
sutConfig = config.FQDNOnly{Enable: false}
|
||||
})
|
||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
|||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
It("Should delegate to next resolver if request query is not fqdn", func() {
|
||||
Expect(sut.Resolve(newRequest("example", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
|
|
@ -109,9 +109,9 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
if !r.IsEnabled() {
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
reverseResp := r.handleReverseDNS(request)
|
||||
|
@ -134,7 +134,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
|
||||
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {
|
||||
|
|
|
@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
tmpFile *TmpFile
|
||||
err error
|
||||
resp *Response
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
tmpDir = NewTmpFolder("HostsFileResolver")
|
||||
Expect(tmpDir.Error).Should(Succeed())
|
||||
DeferCleanup(tmpDir.Clean)
|
||||
|
@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
|
@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
Expect(sut.hosts.isEmpty()).Should(BeTrue())
|
||||
})
|
||||
It("should go to next resolver on query", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
sutConfig.Sources = make([]config.BytesSource, 0)
|
||||
})
|
||||
JustBeforeEach(func() {
|
||||
err = sut.loadSources(context.Background())
|
||||
err = sut.loadSources(ctx)
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("should go to next resolver on query", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("IPv4 mapping is defined for a host", func() {
|
||||
It("defined ipv4 query should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("ipv4host.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
))
|
||||
})
|
||||
It("defined ipv4 query for alias should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("router2.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("router2.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
))
|
||||
})
|
||||
It("ipv4 query should return NOERROR and empty result", func() {
|
||||
Expect(sut.Resolve(newRequest("does.not.exist.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("IPv6 mapping is defined for a host", func() {
|
||||
It("defined ipv6 query should be resolved", func() {
|
||||
Expect(sut.Resolve(newRequest("ipv6host.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
))
|
||||
})
|
||||
It("ipv6 query should return NOERROR and empty result", func() {
|
||||
Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))).
|
||||
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("the domain is not known", func() {
|
||||
It("calls the next resolver", func() {
|
||||
resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A))
|
||||
resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("the question type is not handled", func() {
|
||||
It("calls the next resolver", func() {
|
||||
resp, err = sut.Resolve(newRequest("localhost.", MX))
|
||||
resp, err = sut.Resolve(ctx, newRequest("localhost.", MX))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
When("Reverse DNS request is received", func() {
|
||||
It("should resolve the defined domain name", func() {
|
||||
By("ipv4 with one hostname", func() {
|
||||
Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))).
|
||||
Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
))
|
||||
})
|
||||
By("ipv4 with aliases", func() {
|
||||
Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))).
|
||||
Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
))
|
||||
})
|
||||
By("ipv6", func() {
|
||||
Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))).
|
||||
Expect(sut.Resolve(ctx,
|
||||
newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)),
|
||||
).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
})
|
||||
|
||||
It("should ignore invalid PTR", func() {
|
||||
resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
|
||||
resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("filterLoopback is true", func() {
|
||||
It("calls the next resolver", func() {
|
||||
resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR))
|
||||
resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
|
||||
When("the IP is not known", func() {
|
||||
It("calls the next resolver", func() {
|
||||
resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR))
|
||||
resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
})
|
||||
|
||||
It("resolve the defined domain name", func() {
|
||||
Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))).
|
||||
Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||
|
@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() {
|
|||
Describe("Delegating to next resolver", func() {
|
||||
When("no hosts file is provided", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
_, err = sut.Resolve(newRequest("example.com.", A))
|
||||
_, err = sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(Succeed())
|
||||
// delegate was executed
|
||||
m.AssertExpectations(GinkgoT())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -25,8 +26,8 @@ type MetricsResolver struct {
|
|||
}
|
||||
|
||||
// Resolve resolves the passed request
|
||||
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
response, err := r.next.Resolve(request)
|
||||
func (r *MetricsResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
response, err := r.next.Resolve(ctx, request)
|
||||
|
||||
if r.cfg.Enable {
|
||||
r.totalQueries.With(prometheus.Labels{
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() {
|
|||
var (
|
||||
sut *MetricsResolver
|
||||
m *mockResolver
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||
|
@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() {
|
|||
Context("Recording request metrics", func() {
|
||||
When("Request will be performed", func() {
|
||||
It("Should record metrics", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() {
|
|||
sut.Next(m)
|
||||
})
|
||||
It("Error should be recorded", func() {
|
||||
_, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))
|
||||
_, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
@ -62,6 +63,21 @@ func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response
|
|||
return t
|
||||
}
|
||||
|
||||
func (t *MockUDPUpstreamServer) WithDelay(delay time.Duration) *MockUDPUpstreamServer {
|
||||
answerFn := t.answerFn
|
||||
if answerFn == nil {
|
||||
panic("WithDelay must be called after a WithAnswer function")
|
||||
}
|
||||
|
||||
t.answerFn = func(request *dns.Msg) *dns.Msg {
|
||||
time.Sleep(delay)
|
||||
|
||||
return answerFn(request)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *MockUDPUpstreamServer) GetCallCount() int {
|
||||
return int(atomic.LoadInt32(&t.callCount))
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ type mockResolver struct {
|
|||
mock.Mock
|
||||
NextResolver
|
||||
|
||||
ResolveFn func(req *model.Request) (*model.Response, error)
|
||||
ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error)
|
||||
ResponseFn func(req *dns.Msg) *dns.Msg
|
||||
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
|
||||
}
|
||||
|
@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) {
|
|||
r.Called()
|
||||
}
|
||||
|
||||
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
|
||||
func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
args := r.Called(req)
|
||||
|
||||
if r.ResolveFn != nil {
|
||||
return r.ResolveFn(req)
|
||||
return r.ResolveFn(ctx, req)
|
||||
}
|
||||
|
||||
if r.ResponseFn != nil {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -28,6 +30,6 @@ func (NoOpResolver) IsEnabled() bool {
|
|||
func (NoOpResolver) LogConfig(*logrus.Entry) {
|
||||
}
|
||||
|
||||
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) {
|
||||
func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
|
||||
return NoResponse, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
@ -8,7 +10,12 @@ import (
|
|||
)
|
||||
|
||||
var _ = Describe("NoOpResolver", func() {
|
||||
var sut *NoOpResolver
|
||||
var (
|
||||
sut *NoOpResolver
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
It("follows conventions", func() {
|
||||
|
@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut = NewNoOpResolver()
|
||||
})
|
||||
|
||||
Describe("Resolving", func() {
|
||||
It("returns no response", func() {
|
||||
resp, err := sut.Resolve(newRequest("test.tld", A))
|
||||
resp, err := sut.Resolve(ctx, newRequest("test.tld", A))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).Should(Equal(NoResponse))
|
||||
})
|
||||
|
|
|
@ -53,17 +53,23 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
|
|||
return status
|
||||
}
|
||||
|
||||
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
|
||||
resp, err := r.resolver.Resolve(req)
|
||||
func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
resp, err := r.resolver.Resolve(ctx, req)
|
||||
if err != nil {
|
||||
// Ignore `Canceled`: resolver lost the race, not an error
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
r.lastErrorTime.Store(time.Now())
|
||||
}
|
||||
|
||||
err = fmt.Errorf("%s: %w", r.resolver, err)
|
||||
return nil, fmt.Errorf("%s: %w", r.resolver, err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (r *upstreamResolverStatus) resolveToChan(ctx context.Context, req *model.Request, ch chan<- requestResponse) {
|
||||
resp, err := r.resolve(ctx, req)
|
||||
|
||||
ch <- requestResponse{
|
||||
resolver: &r.resolver,
|
||||
response: resp,
|
||||
|
@ -78,10 +84,10 @@ type requestResponse struct {
|
|||
}
|
||||
|
||||
// testResolver sends a test query to verify the resolver is reachable and working
|
||||
func testResolver(r *UpstreamResolver) error {
|
||||
func testResolver(ctx context.Context, r *UpstreamResolver) error {
|
||||
request := newRequest("github.com.", dns.Type(dns.TypeA))
|
||||
|
||||
resp, err := r.Resolve(request)
|
||||
resp, err := r.Resolve(ctx, request)
|
||||
if err != nil || resp.RType != model.ResponseTypeRESOLVED {
|
||||
return fmt.Errorf("test resolve of upstream server failed: %w", err)
|
||||
}
|
||||
|
@ -91,11 +97,11 @@ func testResolver(r *UpstreamResolver) error {
|
|||
|
||||
// NewParallelBestResolver creates new resolver instance
|
||||
func NewParallelBestResolver(
|
||||
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (*ParallelBestResolver, error) {
|
||||
logger := log.PrefixedLog(parallelResolverType)
|
||||
|
||||
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -150,26 +156,18 @@ func (r *ParallelBestResolver) String() string {
|
|||
}
|
||||
|
||||
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
||||
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, parallelResolverType)
|
||||
|
||||
if len(r.resolvers) == 1 {
|
||||
logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver")
|
||||
resolver := r.resolvers[0]
|
||||
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
|
||||
|
||||
return r.resolvers[0].resolver.Resolve(request)
|
||||
return resolver.resolve(ctx, request)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// using context with timeout for random upstream strategy
|
||||
if r.resolverCount == 1 {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
timeout := config.GetConfig().Upstreams.Timeout
|
||||
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
|
||||
defer cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel() // abort requests to resolvers that lost the race
|
||||
|
||||
resolvers := pickRandom(r.resolvers, r.resolverCount)
|
||||
ch := make(chan requestResponse, len(resolvers))
|
||||
|
@ -177,10 +175,10 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
for _, resolver := range resolvers {
|
||||
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
|
||||
|
||||
go resolver.resolve(request, ch)
|
||||
go resolver.resolveToChan(ctx, request, ch)
|
||||
}
|
||||
|
||||
response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers)
|
||||
response, collectedErrors := evaluateResponses(logger, ch, resolvers)
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
|
@ -189,63 +187,51 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
|
||||
}
|
||||
|
||||
return r.retryWithDifferent(logger, request, resolvers)
|
||||
return r.retryWithDifferent(ctx, logger, request, resolvers)
|
||||
}
|
||||
|
||||
func evaluateResponses(
|
||||
ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
|
||||
logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
|
||||
) (*model.Response, []error) {
|
||||
collectedErrors := make([]error, 0, len(resolvers))
|
||||
|
||||
for len(collectedErrors) < len(resolvers) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// this context currently only has a deadline when resolverCount == 1
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
logger.WithField("resolver", resolvers[0].resolver).
|
||||
Debug("upstream exceeded timeout, trying other upstream")
|
||||
resolvers[0].lastErrorTime.Store(time.Now())
|
||||
}
|
||||
case result := <-ch:
|
||||
if result.err != nil {
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
|
||||
} else {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
result := <-ch
|
||||
logger := logger.WithField("resolver", *result.resolver)
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
if result.err != nil {
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithField("answer", util.AnswerToString(result.response.Res.Answer)).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
|
||||
return nil, collectedErrors
|
||||
}
|
||||
|
||||
func (r *ParallelBestResolver) retryWithDifferent(
|
||||
logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
|
||||
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
|
||||
) (*model.Response, error) {
|
||||
// second try (if retryWithDifferentResolver == true)
|
||||
resolver := weightedRandom(r.resolvers, resolvers)
|
||||
logger.Debugf("using %s as second resolver", resolver.resolver)
|
||||
|
||||
ch := make(chan requestResponse, 1)
|
||||
|
||||
resolver.resolve(request, ch)
|
||||
|
||||
result := <-ch
|
||||
if result.err != nil {
|
||||
return nil, fmt.Errorf("resolution retry failed: %w", result.err)
|
||||
resp, err := resolver.resolve(ctx, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolution retry failed: %w", err)
|
||||
}
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
"resolver": *resolver,
|
||||
"answer": util.AnswerToString(resp.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
|
||||
|
|
|
@ -19,6 +19,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
const (
|
||||
verifyUpstreams = true
|
||||
noVerifyUpstreams = false
|
||||
|
||||
timeout = 50 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,6 +42,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
old := config.GetConfig().Upstreams.Timeout
|
||||
DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old })
|
||||
|
||||
config.GetConfig().Upstreams.Timeout = config.Duration(timeout)
|
||||
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
|
||||
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
|
@ -58,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
Upstreams: upstreams,
|
||||
}
|
||||
|
||||
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
|
||||
sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
|
@ -98,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream.Start(),
|
||||
}
|
||||
|
||||
_, err := NewParallelBestResolver(config.UpstreamGroup{
|
||||
_, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
|
@ -133,6 +139,43 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
})
|
||||
|
||||
When("upstream is too slow", func() {
|
||||
BeforeEach(func() {
|
||||
timeoutUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||
WithDelay(2 * timeout)
|
||||
DeferCleanup(timeoutUpstream.Close)
|
||||
|
||||
upstreams = []config.Upstream{timeoutUpstream.Start()}
|
||||
})
|
||||
|
||||
When("strict checking is enabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = verifyUpstreams
|
||||
})
|
||||
It("should fail to start", func() {
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
When("strict checking is disabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutVerify = noVerifyUpstreams
|
||||
})
|
||||
It("should start", func() {
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("should not resolve", func() {
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
request := newRequest("example.com.", A)
|
||||
_, err := sut.Resolve(ctx, request)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(isTimeout(err)).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Resolving result from fastest upstream resolver", func() {
|
||||
When("2 Upstream resolvers are defined", func() {
|
||||
When("one resolver is fast and another is slow", func() {
|
||||
|
@ -140,21 +183,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(fastTestUpstream.Close)
|
||||
|
||||
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
slowTestUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.123").
|
||||
WithDelay(timeout / 2)
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
|
||||
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
|
||||
})
|
||||
It("Should use result from fastest one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -167,21 +205,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
When("one resolver is slow, but another returns an error", func() {
|
||||
var slowTestUpstream *MockUDPUpstreamServer
|
||||
BeforeEach(func() {
|
||||
slowTestUpstream = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
slowTestUpstream = NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.123").
|
||||
WithDelay(timeout / 2)
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
It("Should use result from successful resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||
|
@ -202,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("Should return error", func() {
|
||||
Expect(err).Should(Succeed())
|
||||
request := newRequest("example.com.", A)
|
||||
_, err = sut.Resolve(request)
|
||||
_, err = sut.Resolve(ctx, request)
|
||||
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
|
@ -218,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -242,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream2.Close)
|
||||
|
||||
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
|
||||
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||
},
|
||||
|
@ -268,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
By("perform 100 request, error upstream's weight will be reduced", func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
request := newRequest("example.com.", A)
|
||||
_, _ = sut.Resolve(request)
|
||||
_, _ = sut.Resolve(ctx, request)
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -302,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("errors during construction", func() {
|
||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||
|
||||
r, err := NewParallelBestResolver(config.UpstreamGroup{
|
||||
r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||
Name: "test",
|
||||
Upstreams: []config.Upstream{{Host: "example.com"}},
|
||||
}, b, verifyUpstreams)
|
||||
|
@ -313,11 +346,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
|
||||
Describe("random resolver strategy", func() {
|
||||
const timeout = config.Duration(time.Second)
|
||||
|
||||
BeforeEach(func() {
|
||||
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
|
||||
config.GetConfig().Upstreams.Timeout = timeout
|
||||
})
|
||||
|
||||
Describe("Name", func() {
|
||||
|
@ -342,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
It("Should return result from either one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(SatisfyAll(
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -356,24 +386,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
When("one upstream exceeds timeout", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(time.Duration(timeout) + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
timeoutUpstream := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||
WithDelay(2 * timeout)
|
||||
DeferCleanup(timeoutUpstream.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||
upstreams = []config.Upstream{timeoutUpstream.Start(), testUpstream2.Start()}
|
||||
})
|
||||
It("should ask a other random upstream and return its response", func() {
|
||||
It("should ask another random upstream and return its response", func() {
|
||||
request := newRequest("example.com", A)
|
||||
Expect(sut.Resolve(request)).Should(
|
||||
Expect(sut.Resolve(ctx, request)).Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
|
@ -382,57 +407,25 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
))
|
||||
})
|
||||
})
|
||||
When("two upstreams exceed timeout", func() {
|
||||
When("all upstreams exceed timeout", func() {
|
||||
BeforeEach(func() {
|
||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
||||
time.Sleep(timeout.ToDuration() + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
testUpstream1 := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||
WithDelay(2 * timeout)
|
||||
DeferCleanup(testUpstream1.Close)
|
||||
|
||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
||||
time.Sleep(timeout.ToDuration() + 2*time.Second)
|
||||
|
||||
Expect(err).To(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
testUpstream2 := NewMockUDPUpstreamServer().
|
||||
WithAnswerRR("example.com 123 IN A 123.124.122.2").
|
||||
WithDelay(2 * timeout)
|
||||
DeferCleanup(testUpstream2.Close)
|
||||
|
||||
testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3")
|
||||
DeferCleanup(testUpstream3.Close)
|
||||
|
||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()}
|
||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||
})
|
||||
// These two tests are flaky -_- (maybe recreate the RandomResolver )
|
||||
It("should not return error (due to random selection the request could to through)", func() {
|
||||
Eventually(func() error {
|
||||
request := newRequest("example.com", A)
|
||||
_, err := sut.Resolve(request)
|
||||
|
||||
return err
|
||||
}).WithTimeout(30 * time.Second).
|
||||
Should(Not(HaveOccurred()))
|
||||
})
|
||||
It("should return error (because it can be possible that the two broken upstreams are chosen)", func() {
|
||||
Eventually(func() error {
|
||||
sutConfig := config.UpstreamGroup{
|
||||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: upstreams,
|
||||
}
|
||||
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
|
||||
|
||||
request := newRequest("example.com", A)
|
||||
_, err := sut.Resolve(request)
|
||||
|
||||
return err
|
||||
}).WithTimeout(30 * time.Second).
|
||||
Should(HaveOccurred())
|
||||
It("Should return error", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
_, err := sut.Resolve(ctx, request)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(isTimeout(err)).Should(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -446,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
})
|
||||
It("Should return error", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
_, err := sut.Resolve(request)
|
||||
_, err := sut.Resolve(ctx, request)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -461,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -485,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream2.Close)
|
||||
|
||||
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
|
||||
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||
},
|
||||
|
@ -506,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
By("perform 100 request, error upstream's weight will be reduced", func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
request := newRequest("example.com.", A)
|
||||
_, _ = sut.Resolve(request)
|
||||
_, _ = sut.Resolve(ctx, request)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -111,12 +111,12 @@ func (r *QueryLoggingResolver) doCleanUp() {
|
|||
}
|
||||
|
||||
// Resolve logs the query, duration and the result
|
||||
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := r.next.Resolve(request)
|
||||
resp, err := r.next.Resolve(ctx, request)
|
||||
|
||||
duration := time.Since(start).Milliseconds()
|
||||
|
||||
|
|
|
@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
m *mockResolver
|
||||
tmpDir *TmpFolder
|
||||
mockAnswer *dns.Msg
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
mockAnswer = new(dns.Msg)
|
||||
tmpDir = NewTmpFolder("queryLoggingResolver")
|
||||
Expect(tmpDir.Error).Should(Succeed())
|
||||
|
@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
sutConfig.SetDefaults() // not called when using a struct literal
|
||||
}
|
||||
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sut = NewQueryLoggingResolver(ctx, sutConfig)
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
|
||||
|
@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
}
|
||||
})
|
||||
It("should process request without query logging", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
It("should create a log file per client", func() {
|
||||
By("request from client 1", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
))
|
||||
})
|
||||
By("request from client 2, has name with special chars, should be escaped", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient(
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient(
|
||||
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
It("should create one log file for all clients", func() {
|
||||
By("request from client 1", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
))
|
||||
})
|
||||
By("request from client 2, has name with special chars, should be escaped", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
})
|
||||
It("should create one log file", func() {
|
||||
By("request from client 1", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
|
@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
|||
sut.writer = mockWriter
|
||||
|
||||
Eventually(func() int {
|
||||
_, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
|
||||
_, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
|
||||
Expect(ierr).Should(Succeed())
|
||||
|
||||
return len(sut.logChan)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -74,7 +75,7 @@ type Resolver interface {
|
|||
Type() string
|
||||
|
||||
// Resolve performs resolution of a DNS request
|
||||
Resolve(req *model.Request) (*model.Response, error)
|
||||
Resolve(ctx context.Context, req *model.Request) (*model.Response, error)
|
||||
}
|
||||
|
||||
// ChainedResolver represents a resolver, which can delegate result to the next one
|
||||
|
@ -216,13 +217,14 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
|
|||
}
|
||||
|
||||
func createResolvers(
|
||||
logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
|
||||
ctx context.Context, logger *logrus.Entry,
|
||||
cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
|
||||
) ([]Resolver, error) {
|
||||
resolvers := make([]Resolver, 0, len(cfg.Upstreams))
|
||||
hasValidResolvers := false
|
||||
|
||||
for _, u := range cfg.Upstreams {
|
||||
resolver, err := NewUpstreamResolver(u, bootstrap, shoudVerifyUpstreams)
|
||||
resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams)
|
||||
if err != nil {
|
||||
logger.Warnf("upstream group %s: %v", cfg.Name, err)
|
||||
|
||||
|
@ -230,7 +232,7 @@ func createResolvers(
|
|||
}
|
||||
|
||||
if shoudVerifyUpstreams {
|
||||
err = testResolver(resolver)
|
||||
err = testResolver(ctx, resolver)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
|
@ -57,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
|
|||
}
|
||||
|
||||
// Resolve uses the inner resolver to resolve the rewritten query
|
||||
func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, "rewriter_resolver")
|
||||
|
||||
original := request.Req
|
||||
|
@ -69,7 +70,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
|
||||
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
|
||||
|
||||
response, err := r.inner.Resolve(request)
|
||||
response, err := r.inner.Resolve(ctx, request)
|
||||
// Test for error after checking for fallbackUpstream
|
||||
|
||||
// Revert the request: must be done before calling r.next
|
||||
|
@ -80,7 +81,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -91,7 +92,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
// Inner resolver had no response, continue with the normal chain
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
// Revert the rewrite in r.inner's response
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
|
@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
return res
|
||||
}
|
||||
|
||||
resp, err := sut.Resolve(request)
|
||||
resp, err := sut.Resolve(context.Background(), request)
|
||||
Expect(err).Should(Succeed())
|
||||
if resp != mNextResponse {
|
||||
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
|
@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() {
|
|||
expectNilAnswer = true
|
||||
|
||||
// Make inner call the NoOpResolver
|
||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
Expect(req).Should(Equal(request))
|
||||
|
||||
// Inner should see fqdnRewritten
|
||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
|
||||
|
||||
return mInner.next.Resolve(req)
|
||||
return mInner.next.Resolve(ctx, req)
|
||||
}
|
||||
|
||||
// Resolver after RewriterResolver should see `fqdnOriginal`
|
||||
mNext.On("Resolve", mock.Anything)
|
||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
|
||||
return mNextResponse, nil
|
||||
|
@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
expectNilAnswer = true
|
||||
|
||||
// Make inner return a nil Answer but not an empty Response
|
||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
Expect(req).Should(Equal(request))
|
||||
|
||||
// Inner should see fqdnRewritten
|
||||
|
@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
fqdnRewritten = sampleRewritten
|
||||
|
||||
// Make inner return a nil Answer but not an empty Response
|
||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
Expect(req).Should(Equal(request))
|
||||
|
||||
// Inner should see fqdnRewritten
|
||||
|
@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
|
||||
// Resolver after RewriterResolver should see `fqdnOriginal`
|
||||
mNext.On("Resolve", mock.Anything)
|
||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
|
||||
return mNextResponse, nil
|
||||
|
|
|
@ -30,11 +30,11 @@ type StrictResolver struct {
|
|||
|
||||
// NewStrictResolver creates a new strict resolver instance
|
||||
func NewStrictResolver(
|
||||
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||
) (*StrictResolver, error) {
|
||||
logger := log.PrefixedLog(strictResolverType)
|
||||
|
||||
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -76,44 +76,27 @@ func (r *StrictResolver) String() string {
|
|||
}
|
||||
|
||||
// Resolve sends the query request in a strict order to the upstream resolvers
|
||||
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, strictResolverType)
|
||||
|
||||
// start with first resolver
|
||||
for i := range r.resolvers {
|
||||
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
resolver := r.resolvers[i]
|
||||
for _, resolver := range r.resolvers {
|
||||
logger.Debugf("using %s as resolver", resolver.resolver)
|
||||
|
||||
ch := make(chan requestResponse, 1)
|
||||
|
||||
go resolver.resolve(request, ch)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// log debug/info that timeout exceeded, call `continue` to try next upstream
|
||||
logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
|
||||
resp, err := resolver.resolve(ctx, request)
|
||||
if err != nil {
|
||||
// log error and try next upstream
|
||||
logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err)
|
||||
|
||||
continue
|
||||
case result := <-ch:
|
||||
if result.err != nil {
|
||||
// log error & call `continue` to try next upstream
|
||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *result.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"resolver": *resolver,
|
||||
"answer": util.AnswerToString(resp.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
err error
|
||||
|
||||
bootstrap *Bootstrap
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
upstreams = []config.Upstream{
|
||||
{Host: "wrong"},
|
||||
{Host: "127.0.0.2"},
|
||||
|
@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: upstreams,
|
||||
}
|
||||
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
|
||||
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify)
|
||||
})
|
||||
|
||||
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
|
||||
|
@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
mockUpstream.Start(),
|
||||
}
|
||||
|
||||
_, err := NewStrictResolver(config.UpstreamGroup{
|
||||
_, err := NewStrictResolver(ctx, config.UpstreamGroup{
|
||||
Name: upstreamDefaultCfgName,
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
|
@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
It("Should use result from first one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
It("should return response from next upstream", func() {
|
||||
request := newRequest("example.com", A)
|
||||
Expect(sut.Resolve(request)).Should(
|
||||
Expect(sut.Resolve(ctx, request)).Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
|
@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
It("should return error", func() {
|
||||
request := newRequest("example.com", A)
|
||||
_, err := sut.Resolve(request)
|
||||
_, err := sut.Resolve(ctx, request)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
It("Should use result from second one", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||
|
@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
})
|
||||
It("Should return error", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
_, err = sut.Resolve(request)
|
||||
_, err = sut.Resolve(ctx, request)
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
|||
It("Should use result from defined resolver", func() {
|
||||
request := newRequest("example.com.", A)
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
|
@ -99,7 +100,7 @@ func NewSpecialUseDomainNamesResolver(cfg config.SUDN) *SpecialUseDomainNamesRes
|
|||
}
|
||||
}
|
||||
|
||||
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *SpecialUseDomainNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
handler := r.handler(request)
|
||||
if handler != nil {
|
||||
resp := handler(request, r.cfg)
|
||||
|
@ -108,7 +109,7 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
|
|||
}
|
||||
}
|
||||
|
||||
return r.next.Resolve(request)
|
||||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
sut *SpecialUseDomainNamesResolver
|
||||
sutConfig config.SUDN
|
||||
m *mockResolver
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
|
@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
BeforeEach(func() {
|
||||
var err error
|
||||
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sutConfig, err = config.WithDefaults[config.SUDN]()
|
||||
Expect(err).Should(Succeed())
|
||||
})
|
||||
|
@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
Describe("handlers", func() {
|
||||
It("should have correct response type", func() {
|
||||
for domain, handler := range sudnHandlers {
|
||||
resp, err := sut.Resolve(newRequest(domain, A))
|
||||
resp, err := sut.Resolve(ctx, newRequest(domain, A))
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
if handler == nil {
|
||||
|
@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
|
||||
DescribeTable("handled domains",
|
||||
func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) {
|
||||
resp, err := sut.Resolve(newRequest(qName, qType))
|
||||
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).Should(SatisfyAll(
|
||||
HaveResponseType(ResponseTypeSPECIAL),
|
||||
|
@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
|
||||
DescribeTable("",
|
||||
func(qType dns.Type, qName string, expectedRCode int) {
|
||||
resp, err := sut.Resolve(newRequest(qName, qType))
|
||||
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).Should(HaveReturnCode(expectedRCode))
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
||||
|
@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
})
|
||||
|
||||
It("should forward example.com", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
||||
|
@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
})
|
||||
|
||||
It("should forward home.arpa. IN DS", func() {
|
||||
Expect(sut.Resolve(newRequest("something.home.arpa.", DS))).
|
||||
Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
// setup code doesn't care about the question
|
||||
|
@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
|||
})
|
||||
|
||||
It("should forward non special use domains", func() {
|
||||
resp, err := sut.Resolve(newRequest("something.not-special.", AAAA))
|
||||
resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
||||
})
|
||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -39,8 +40,9 @@ type UpstreamResolver struct {
|
|||
|
||||
type upstreamClient interface {
|
||||
fmtURL(ip net.IP, port uint16, path string) string
|
||||
callExternal(msg *dns.Msg, upstreamURL string,
|
||||
protocol model.RequestProtocol) (response *dns.Msg, rtt time.Duration, err error)
|
||||
callExternal(
|
||||
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
|
||||
) (response *dns.Msg, rtt time.Duration, err error)
|
||||
}
|
||||
|
||||
type dnsUpstreamClient struct {
|
||||
|
@ -53,8 +55,6 @@ type httpUpstreamClient struct {
|
|||
}
|
||||
|
||||
func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
|
||||
|
||||
tlsConfig := tls.Config{
|
||||
ServerName: cfg.Host,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
|
@ -73,7 +73,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
|||
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
Timeout: timeout,
|
||||
},
|
||||
host: cfg.Host,
|
||||
}
|
||||
|
@ -83,7 +82,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
|||
tcpClient: &dns.Client{
|
||||
TLSConfig: &tlsConfig,
|
||||
Net: cfg.Net.String(),
|
||||
Timeout: timeout,
|
||||
SingleInflight: true,
|
||||
},
|
||||
}
|
||||
|
@ -92,12 +90,10 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
|||
return &dnsUpstreamClient{
|
||||
tcpClient: &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: timeout,
|
||||
SingleInflight: true,
|
||||
},
|
||||
udpClient: &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: timeout,
|
||||
SingleInflight: true,
|
||||
},
|
||||
}
|
||||
|
@ -112,8 +108,8 @@ func (r *httpUpstreamClient) fmtURL(ip net.IP, port uint16, path string) string
|
|||
return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path)
|
||||
}
|
||||
|
||||
func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
|
||||
upstreamURL string, _ model.RequestProtocol,
|
||||
func (r *httpUpstreamClient) callExternal(
|
||||
ctx context.Context, msg *dns.Msg, upstreamURL string, _ model.RequestProtocol,
|
||||
) (*dns.Msg, time.Duration, error) {
|
||||
start := time.Now()
|
||||
|
||||
|
@ -122,7 +118,7 @@ func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
|
|||
return nil, 0, fmt.Errorf("can't pack message: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("can't create the new request %w", err)
|
||||
}
|
||||
|
@ -169,16 +165,16 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
|
|||
return net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
|
||||
}
|
||||
|
||||
func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
|
||||
upstreamURL string, protocol model.RequestProtocol,
|
||||
func (r *dnsUpstreamClient) callExternal(
|
||||
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
|
||||
) (response *dns.Msg, rtt time.Duration, err error) {
|
||||
if protocol == model.RequestProtocolTCP {
|
||||
response, rtt, err = r.tcpClient.Exchange(msg, upstreamURL)
|
||||
if err != nil {
|
||||
response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||
if err != nil && r.udpClient != nil {
|
||||
// try UDP as fallback
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" && r.udpClient != nil {
|
||||
return r.udpClient.Exchange(msg, upstreamURL)
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" {
|
||||
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,18 +182,20 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
|
|||
}
|
||||
|
||||
if r.udpClient != nil {
|
||||
return r.udpClient.Exchange(msg, upstreamURL)
|
||||
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||
}
|
||||
|
||||
return r.tcpClient.Exchange(msg, upstreamURL)
|
||||
return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||
}
|
||||
|
||||
// NewUpstreamResolver creates new resolver instance
|
||||
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) {
|
||||
func NewUpstreamResolver(
|
||||
ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool,
|
||||
) (*UpstreamResolver, error) {
|
||||
r := newUpstreamResolverUnchecked(upstream, bootstrap)
|
||||
|
||||
if verify {
|
||||
_, err := r.bootstrap.UpstreamIPs(r)
|
||||
_, err := r.bootstrap.UpstreamIPs(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -234,50 +232,41 @@ func (r UpstreamResolver) String() string {
|
|||
}
|
||||
|
||||
// Resolve calls external resolver
|
||||
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||
ips, err := r.bootstrap.UpstreamIPs(r)
|
||||
func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
ips, err := r.bootstrap.UpstreamIPs(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
rtt time.Duration
|
||||
resp *dns.Msg
|
||||
ip net.IP
|
||||
)
|
||||
|
||||
err = retry.Do(
|
||||
func() error {
|
||||
ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration())
|
||||
defer cancel()
|
||||
|
||||
ip = ips.Current()
|
||||
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
|
||||
|
||||
var err error
|
||||
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol)
|
||||
if err == nil {
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(resp.Answer),
|
||||
"return_code": dns.RcodeToString[resp.Rcode],
|
||||
"upstream": r.upstream.String(),
|
||||
"upstream_ip": ip.String(),
|
||||
"protocol": request.Protocol,
|
||||
"net": r.upstream.Net,
|
||||
"response_time_ms": rtt.Milliseconds(),
|
||||
}).Debugf("received response from upstream")
|
||||
|
||||
return nil
|
||||
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
|
||||
resp = response
|
||||
r.logResponse(request, response, ip, rtt)
|
||||
|
||||
return nil
|
||||
},
|
||||
retry.Context(ctx),
|
||||
retry.Attempts(retryAttempts),
|
||||
retry.DelayType(retry.FixedDelay),
|
||||
retry.Delay(1*time.Millisecond),
|
||||
retry.LastErrorOnly(true),
|
||||
retry.RetryIf(func(err error) bool {
|
||||
var netErr net.Error
|
||||
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}),
|
||||
retry.RetryIf(isTimeout),
|
||||
retry.OnRetry(func(n uint, err error) {
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"upstream": r.upstream.String(),
|
||||
|
@ -289,8 +278,31 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
|||
ips.Next()
|
||||
}))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
// Make the error more user friendly than just "context deadline exceeded"
|
||||
err = fmt.Errorf("timeout (%w)", err)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
|
||||
}
|
||||
|
||||
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
|
||||
r.log().WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(resp.Answer),
|
||||
"return_code": dns.RcodeToString[resp.Rcode],
|
||||
"upstream": r.upstream.String(),
|
||||
"upstream_ip": ip.String(),
|
||||
"protocol": request.Protocol,
|
||||
"net": r.upstream.Net,
|
||||
"response_time_ms": rtt.Milliseconds(),
|
||||
}).Debugf("received response from upstream")
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
var (
|
||||
sut *UpstreamResolver
|
||||
sutConfig config.Upstream
|
||||
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sutConfig = config.Upstream{Host: "localhost"}
|
||||
})
|
||||
|
||||
|
@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
upstream := mockUpstream.Start()
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
upstream := mockUpstream.Start()
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
HaveNoAnswer(),
|
||||
|
@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
upstream := mockUpstream.Start()
|
||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
atomic.StoreInt32(&counter, 0)
|
||||
atomic.StoreInt32(&attemptsWithTimeout, 2)
|
||||
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -146,12 +153,37 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
By("3 attempts with timeout -> should return error", func() {
|
||||
atomic.StoreInt32(&counter, 0)
|
||||
atomic.StoreInt32(&attemptsWithTimeout, 3)
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
When("user request is TCP", func() {
|
||||
When("TCP upstream connection fails", func() {
|
||||
BeforeEach(func() {
|
||||
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||
DeferCleanup(mockUpstream.Close)
|
||||
|
||||
sutConfig = mockUpstream.Start()
|
||||
})
|
||||
|
||||
It("should retry with UDP", func() {
|
||||
req := newRequest("example.com.", A)
|
||||
req.Protocol = RequestProtocolTCP
|
||||
|
||||
Expect(sut.Resolve(ctx, req)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
HaveTTL(BeNumerically("==", 123)),
|
||||
))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Using Dns over HTTP (DOH) upstream", func() {
|
||||
|
@ -185,7 +217,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
})
|
||||
When("Configured DOH resolver can resolve query", func() {
|
||||
It("should return answer from DNS upstream", func() {
|
||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||
|
@ -203,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("should return error", func() {
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
|
||||
})
|
||||
|
@ -215,7 +247,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("should return error", func() {
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(
|
||||
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
|
||||
|
@ -228,7 +260,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("should return error", func() {
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
|
||||
})
|
||||
|
@ -241,7 +273,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
}, systemResolverBootstrap)
|
||||
})
|
||||
It("should return error", func() {
|
||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
Expect(err.Error()).Should(Or(
|
||||
ContainSubstring("no such host"),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
|
@ -64,7 +65,7 @@ func (r *UpstreamTreeResolver) String() string {
|
|||
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
|
||||
|
||||
group := r.upstreamGroupByClient(request)
|
||||
|
@ -72,7 +73,7 @@ func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response,
|
|||
// delegate request to group resolver
|
||||
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
|
||||
|
||||
return r.branches[group].Resolve(request)
|
||||
return r.branches[group].Resolve(ctx, request)
|
||||
}
|
||||
|
||||
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
. "github.com/0xERR0R/blocky/helpertest"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
|
@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
})
|
||||
|
||||
When("client specific resolvers are defined", func() {
|
||||
var (
|
||||
ctx context.Context
|
||||
cancelFn context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx, cancelFn = context.WithCancel(context.Background())
|
||||
DeferCleanup(cancelFn)
|
||||
|
||||
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
||||
upstreamDefaultCfgName: {config.Upstream{}},
|
||||
"laptop": {config.Upstream{}},
|
||||
|
@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use default if client name or IP don't match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "default"),
|
||||
|
@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client name matches exact", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "laptop"),
|
||||
|
@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client name matches with wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "client-*-m"),
|
||||
|
@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client name matches with range wildcard", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "client[0-9]"),
|
||||
|
@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client IP matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
|
@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client name (containing IP) matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
|
@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
|
||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
|
||||
|
@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use exact IP match before client name match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||
|
@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
It("Should use client name match before CIDR match", func() {
|
||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
BeDNSRecord("example.com.", A, "laptop"),
|
||||
|
@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
|||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
|
||||
request.Log = logger
|
||||
|
||||
Expect(sut.Resolve(request)).
|
||||
Expect(sut.Resolve(ctx, request)).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
SatisfyAny(
|
||||
|
|
|
@ -389,7 +389,7 @@ func createQueryResolver(
|
|||
bootstrap *resolver.Bootstrap,
|
||||
redisClient *redis.Client,
|
||||
) (r resolver.ChainedResolver, err error) {
|
||||
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap)
|
||||
upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap)
|
||||
if uErr != nil {
|
||||
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
|
||||
}
|
||||
|
@ -398,7 +398,9 @@ func createQueryResolver(
|
|||
|
||||
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
|
||||
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(
|
||||
ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream,
|
||||
)
|
||||
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
|
||||
|
||||
err = multierror.Append(
|
||||
|
@ -433,6 +435,7 @@ func createQueryResolver(
|
|||
}
|
||||
|
||||
func createUpstreamBranches(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
bootstrap *resolver.Bootstrap,
|
||||
) (map[string]resolver.Resolver, error) {
|
||||
|
@ -453,11 +456,11 @@ func createUpstreamBranches(
|
|||
|
||||
switch cfg.Upstreams.Strategy {
|
||||
case config.UpstreamStrategyParallelBest:
|
||||
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
case config.UpstreamStrategyStrict:
|
||||
upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
case config.UpstreamStrategyRandom:
|
||||
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||
}
|
||||
|
||||
upstreamBranches[group] = upstream
|
||||
|
@ -643,7 +646,7 @@ func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
|
|||
|
||||
r := createResolverRequest(w, request)
|
||||
|
||||
response, err := s.queryResolver.Resolve(r)
|
||||
response, err := s.queryResolver.Resolve(context.Background(), r)
|
||||
|
||||
if err != nil {
|
||||
logger().Error("error on processing request:", err)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -149,7 +150,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
|||
|
||||
r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
|
||||
|
||||
resResponse, err := s.queryResolver.Resolve(r)
|
||||
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||
|
||||
|
@ -192,11 +193,11 @@ func extractIP(r *http.Request) string {
|
|||
return hostPort
|
||||
}
|
||||
|
||||
func (s *Server) Query(question string, qType dns.Type) (*model.Response, error) {
|
||||
func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
|
||||
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
||||
r := createResolverRequest(nil, dnsRequest)
|
||||
|
||||
return s.queryResolver.Resolve(r)
|
||||
return s.queryResolver.Resolve(ctx, r)
|
||||
}
|
||||
|
||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
||||
|
|
|
@ -76,10 +76,9 @@ var _ = BeforeSuite(func() {
|
|||
|
||||
clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||
var clientName string
|
||||
client := mockClientName.Load()
|
||||
|
||||
if client != nil {
|
||||
clientName = mockClientName.Load().(string)
|
||||
if name, ok := mockClientName.Load().(string); ok {
|
||||
clientName = name
|
||||
}
|
||||
|
||||
response, err := util.NewMsgWithAnswer(
|
||||
|
@ -118,8 +117,7 @@ var _ = BeforeSuite(func() {
|
|||
youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com")
|
||||
Expect(youtubeFile.Error).Should(Succeed())
|
||||
|
||||
// create server
|
||||
sut, err = NewServer(ctx, &config.Config{
|
||||
cfg := &config.Config{
|
||||
CustomDNS: config.CustomDNSConfig{
|
||||
CustomTTL: config.Duration(3600 * time.Second),
|
||||
Mapping: config.CustomDNSMapping{
|
||||
|
@ -160,7 +158,8 @@ var _ = BeforeSuite(func() {
|
|||
BlockTTL: config.Duration(6 * time.Hour),
|
||||
},
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
|
||||
Timeout: config.Duration(250 * time.Millisecond),
|
||||
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
|
||||
},
|
||||
ClientLookup: config.ClientLookupConfig{
|
||||
Upstream: upstreamClient,
|
||||
|
@ -178,16 +177,19 @@ var _ = BeforeSuite(func() {
|
|||
Enable: true,
|
||||
Path: "/metrics",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Hacky but needed to update the global since we still have code that reads it
|
||||
*config.GetConfig() = *cfg
|
||||
|
||||
// create server
|
||||
sut, err = NewServer(ctx, cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
errChan := make(chan error, 10)
|
||||
|
||||
// start server
|
||||
go func() {
|
||||
sut.Start(ctx, errChan)
|
||||
}()
|
||||
go sut.Start(ctx, errChan)
|
||||
DeferCleanup(sut.Stop)
|
||||
|
||||
Consistently(errChan, "1s").ShouldNot(Receive())
|
||||
|
@ -697,7 +699,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
|
||||
Describe("NewServer with strict upstream strategy", func() {
|
||||
It("successfully returns upstream branches", func() {
|
||||
branches, err := createUpstreamBranches(&config.Config{
|
||||
branches, err := createUpstreamBranches(context.Background(), &config.Config{
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Strategy: config.UpstreamStrategyStrict,
|
||||
Groups: config.UpstreamGroups{
|
||||
|
@ -715,7 +717,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
|
||||
Describe("NewServer with random upstream strategy", func() {
|
||||
It("successfully returns upstream branches", func() {
|
||||
branches, err := createUpstreamBranches(&config.Config{
|
||||
branches, err := createUpstreamBranches(context.Background(), &config.Config{
|
||||
Upstreams: config.UpstreamsConfig{
|
||||
Strategy: config.UpstreamStrategyRandom,
|
||||
Groups: config.UpstreamGroups{
|
||||
|
|
Loading…
Reference in New Issue