mirror of https://github.com/0xERR0R/blocky.git
refactor: make use of contexts in more places
- `CacheControl.FlushCaches` - `Querier.Query` - `Resolver.Resolve` Besides all the API churn, this leads to `ParallelBestResolver`, `StrictResolver` and `UpstreamResolver` simplification: timeouts only need to be setup in one place, `UpstreamResolver`. We also benefit from using HTTP request contexts, so if the client closes the connection we stop processing on our side.
This commit is contained in:
parent
e4ebc16ccc
commit
eae99ec550
|
@ -40,11 +40,11 @@ type ListRefresher interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Querier interface {
|
type Querier interface {
|
||||||
Query(question string, qType dns.Type) (*model.Response, error)
|
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CacheControl interface {
|
type CacheControl interface {
|
||||||
FlushCaches()
|
FlushCaches(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
||||||
|
@ -137,13 +137,13 @@ func (i *OpenAPIInterfaceImpl) ListRefresh(_ context.Context,
|
||||||
return ListRefresh200Response{}, nil
|
return ListRefresh200Response{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObject) (QueryResponseObject, error) {
|
func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestObject) (QueryResponseObject, error) {
|
||||||
qType := dns.Type(dns.StringToType[request.Body.Type])
|
qType := dns.Type(dns.StringToType[request.Body.Type])
|
||||||
if qType == dns.Type(dns.TypeNone) {
|
if qType == dns.Type(dns.TypeNone) {
|
||||||
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
|
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := i.querier.Query(dns.Fqdn(request.Body.Query), qType)
|
resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -156,10 +156,10 @@ func (i *OpenAPIInterfaceImpl) Query(_ context.Context, request QueryRequestObje
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *OpenAPIInterfaceImpl) CacheFlush(_ context.Context,
|
func (i *OpenAPIInterfaceImpl) CacheFlush(ctx context.Context,
|
||||||
_ CacheFlushRequestObject,
|
_ CacheFlushRequestObject,
|
||||||
) (CacheFlushResponseObject, error) {
|
) (CacheFlushResponseObject, error) {
|
||||||
i.cacheControl.FlushCaches()
|
i.cacheControl.FlushCaches(ctx)
|
||||||
|
|
||||||
return CacheFlush200Response{}, nil
|
return CacheFlush200Response{}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
// . "github.com/0xERR0R/blocky/helpertest"
|
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
"github.com/0xERR0R/blocky/util"
|
"github.com/0xERR0R/blocky/util"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
|
||||||
return args.Get(0).(BlockingStatus)
|
return args.Get(0).(BlockingStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
|
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
|
||||||
args := m.Called(question, qType)
|
args := m.Called(ctx, question, qType)
|
||||||
|
|
||||||
return args.Get(0).(*model.Response), args.Error(1)
|
return args.Get(0).(*model.Response), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CacheControlMock) FlushCaches() {
|
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
|
||||||
_ = m.Called()
|
_ = m.Called(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = Describe("API implementation tests", func() {
|
var _ = Describe("API implementation tests", func() {
|
||||||
|
@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
|
||||||
listRefreshMock *ListRefreshMock
|
listRefreshMock *ListRefreshMock
|
||||||
cacheControlMock *CacheControlMock
|
cacheControlMock *CacheControlMock
|
||||||
sut *OpenAPIInterfaceImpl
|
sut *OpenAPIInterfaceImpl
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
blockingControlMock = &BlockingControlMock{}
|
blockingControlMock = &BlockingControlMock{}
|
||||||
querierMock = &QuerierMock{}
|
querierMock = &QuerierMock{}
|
||||||
listRefreshMock = &ListRefreshMock{}
|
listRefreshMock = &ListRefreshMock{}
|
||||||
|
@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
|
||||||
)
|
)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
querierMock.On("Query", "google.com.", A).Return(&model.Response{
|
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
|
||||||
Res: queryResponse,
|
Res: queryResponse,
|
||||||
Reason: "reason",
|
Reason: "reason",
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
resp, err := sut.Query(context.Background(), QueryRequestObject{
|
resp, err := sut.Query(ctx, QueryRequestObject{
|
||||||
Body: &ApiQueryRequest{
|
Body: &ApiQueryRequest{
|
||||||
Query: "google.com", Type: "A",
|
Query: "google.com", Type: "A",
|
||||||
},
|
},
|
||||||
|
@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return 400 on wrong parameter", func() {
|
It("should return 400 on wrong parameter", func() {
|
||||||
resp, err := sut.Query(context.Background(), QueryRequestObject{
|
resp, err := sut.Query(ctx, QueryRequestObject{
|
||||||
Body: &ApiQueryRequest{
|
Body: &ApiQueryRequest{
|
||||||
Query: "google.com",
|
Query: "google.com",
|
||||||
Type: "WRONGTYPE",
|
Type: "WRONGTYPE",
|
||||||
|
@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
It("should return 200 on success", func() {
|
It("should return 200 on success", func() {
|
||||||
listRefreshMock.On("RefreshLists").Return(nil)
|
listRefreshMock.On("RefreshLists").Return(nil)
|
||||||
|
|
||||||
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
|
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp200 ListRefresh200Response
|
var resp200 ListRefresh200Response
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||||
|
@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
It("should return 500 on failure", func() {
|
It("should return 500 on failure", func() {
|
||||||
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))
|
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))
|
||||||
|
|
||||||
resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
|
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp500 ListRefresh500TextResponse
|
var resp500 ListRefresh500TextResponse
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp500))
|
Expect(resp).Should(BeAssignableToTypeOf(resp500))
|
||||||
|
@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
duration := "3s"
|
duration := "3s"
|
||||||
grroups := "gr1,gr2"
|
grroups := "gr1,gr2"
|
||||||
|
|
||||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
|
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
|
||||||
Params: DisableBlockingParams{
|
Params: DisableBlockingParams{
|
||||||
Duration: &duration,
|
Duration: &duration,
|
||||||
Groups: &grroups,
|
Groups: &grroups,
|
||||||
|
@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
|
|
||||||
It("should return 400 on failure", func() {
|
It("should return 400 on failure", func() {
|
||||||
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
|
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
|
||||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
|
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp400 DisableBlocking400TextResponse
|
var resp400 DisableBlocking400TextResponse
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp400))
|
Expect(resp).Should(BeAssignableToTypeOf(resp400))
|
||||||
|
@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
|
|
||||||
It("should return 400 on wrong duration parameter", func() {
|
It("should return 400 on wrong duration parameter", func() {
|
||||||
wrongDuration := "4sds"
|
wrongDuration := "4sds"
|
||||||
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
|
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
|
||||||
Params: DisableBlockingParams{
|
Params: DisableBlockingParams{
|
||||||
Duration: &wrongDuration,
|
Duration: &wrongDuration,
|
||||||
},
|
},
|
||||||
|
@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
It("should return 200 on success", func() {
|
It("should return 200 on success", func() {
|
||||||
blockingControlMock.On("EnableBlocking").Return()
|
blockingControlMock.On("EnableBlocking").Return()
|
||||||
|
|
||||||
resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
|
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp200 EnableBlocking200Response
|
var resp200 EnableBlocking200Response
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||||
|
@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
AutoEnableInSec: 47,
|
AutoEnableInSec: 47,
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
|
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp200 BlockingStatus200JSONResponse
|
var resp200 BlockingStatus200JSONResponse
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||||
|
@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
|
||||||
Describe("Cache API", func() {
|
Describe("Cache API", func() {
|
||||||
When("Cache flush is called", func() {
|
When("Cache flush is called", func() {
|
||||||
It("should return 200 on success", func() {
|
It("should return 200 on success", func() {
|
||||||
cacheControlMock.On("FlushCaches").Return()
|
cacheControlMock.On("FlushCaches", ctx).Return()
|
||||||
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
|
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
var resp200 CacheFlush200Response
|
var resp200 CacheFlush200Response
|
||||||
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
Expect(resp).Should(BeAssignableToTypeOf(resp200))
|
||||||
|
|
|
@ -37,7 +37,7 @@ type Options struct {
|
||||||
// OnExpirationCallback will be called just before an element gets expired and will
|
// OnExpirationCallback will be called just before an element gets expired and will
|
||||||
// be removed from cache. This function can return new value and TTL to leave the
|
// be removed from cache. This function can return new value and TTL to leave the
|
||||||
// element in the cache or nil to remove it
|
// element in the cache or nil to remove it
|
||||||
type OnExpirationCallback[T any] func(key string) (val *T, ttl time.Duration)
|
type OnExpirationCallback[T any] func(ctx context.Context, key string) (val *T, ttl time.Duration)
|
||||||
|
|
||||||
// OnCacheHitCallback will be called on cache get if entry was found
|
// OnCacheHitCallback will be called on cache get if entry was found
|
||||||
type OnCacheHitCallback func(key string)
|
type OnCacheHitCallback func(key string)
|
||||||
|
@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
|
||||||
l, _ := lru.New(defaultSize)
|
l, _ := lru.New(defaultSize)
|
||||||
c := &ExpiringLRUCache[T]{
|
c := &ExpiringLRUCache[T]{
|
||||||
cleanUpInterval: defaultCleanUpInterval,
|
cleanUpInterval: defaultCleanUpInterval,
|
||||||
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
|
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
},
|
},
|
||||||
onCacheHit: func(key string) {},
|
onCacheHit: func(key string) {},
|
||||||
|
@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() {
|
||||||
var keysToDelete []string
|
var keysToDelete []string
|
||||||
|
|
||||||
for _, key := range expiredKeys {
|
for _, key := range expiredKeys {
|
||||||
newVal, newTTL := e.preExpirationFn(key)
|
newVal, newTTL := e.preExpirationFn(context.Background(), key)
|
||||||
if newVal != nil {
|
if newVal != nil {
|
||||||
e.Put(key, newVal, newTTL)
|
e.Put(key, newVal, newTTL)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
Describe("preExpiration function", func() {
|
Describe("preExpiration function", func() {
|
||||||
When("function is defined", func() {
|
When("function is defined", func() {
|
||||||
It("should update the value and TTL if function returns values", func() {
|
It("should update the value and TTL if function returns values", func() {
|
||||||
fn := func(key string) (val *string, ttl time.Duration) {
|
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||||
v2 := "v2"
|
v2 := "v2"
|
||||||
|
|
||||||
return &v2, time.Second
|
return &v2, time.Second
|
||||||
|
@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
|
It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
|
||||||
fn := func(key string) (val *string, ttl time.Duration) {
|
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||||
v2 := "val2"
|
v2 := "val2"
|
||||||
|
|
||||||
return &v2, time.Second
|
return &v2, time.Second
|
||||||
|
@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should delete the key if function returns nil", func() {
|
It("should delete the key if function returns nil", func() {
|
||||||
fn := func(key string) (val *string, ttl time.Duration) {
|
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
|
||||||
|
|
|
@ -25,11 +25,11 @@ type cacheValue[T any] struct {
|
||||||
type OnEntryReloadedCallback func(key string)
|
type OnEntryReloadedCallback func(key string)
|
||||||
|
|
||||||
// ReloadEntryFn reloads a prefetched entry by key
|
// ReloadEntryFn reloads a prefetched entry by key
|
||||||
type ReloadEntryFn[T any] func(key string) (*T, time.Duration)
|
type ReloadEntryFn[T any] func(ctx context.Context, key string) (*T, time.Duration)
|
||||||
|
|
||||||
type PrefetchingOptions[T any] struct {
|
type PrefetchingOptions[T any] struct {
|
||||||
Options
|
Options
|
||||||
ReloadFn func(cacheKey string) (*T, time.Duration)
|
ReloadFn ReloadEntryFn[T]
|
||||||
PrefetchThreshold int
|
PrefetchThreshold int
|
||||||
PrefetchExpires time.Duration
|
PrefetchExpires time.Duration
|
||||||
PrefetchMaxItemsCount int
|
PrefetchMaxItemsCount int
|
||||||
|
@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
|
||||||
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
|
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
|
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
|
||||||
|
ctx context.Context, cacheKey string,
|
||||||
|
) (val *cacheValue[T], ttl time.Duration) {
|
||||||
if e.shouldPrefetch(cacheKey) {
|
if e.shouldPrefetch(cacheKey) {
|
||||||
loadedVal, ttl := e.reloadFn(cacheKey)
|
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
|
||||||
if loadedVal != nil {
|
if loadedVal != nil {
|
||||||
if e.onPrefetchEntryReloaded != nil {
|
if e.onPrefetchEntryReloaded != nil {
|
||||||
e.onPrefetchEntryReloaded(cacheKey)
|
e.onPrefetchEntryReloaded(cacheKey)
|
||||||
|
|
|
@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
||||||
},
|
},
|
||||||
PrefetchThreshold: 2,
|
PrefetchThreshold: 2,
|
||||||
PrefetchExpires: 100 * time.Millisecond,
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||||
v := "v2"
|
v := "v2"
|
||||||
|
|
||||||
return &v, 50 * time.Millisecond
|
return &v, 50 * time.Millisecond
|
||||||
|
@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
||||||
},
|
},
|
||||||
PrefetchThreshold: 2,
|
PrefetchThreshold: 2,
|
||||||
PrefetchExpires: 100 * time.Millisecond,
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||||
v := "v2"
|
v := "v2"
|
||||||
|
|
||||||
return &v, 50 * time.Millisecond
|
return &v, 50 * time.Millisecond
|
||||||
|
@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
||||||
Options: Options{
|
Options: Options{
|
||||||
CleanupInterval: 100 * time.Millisecond,
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
},
|
},
|
||||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||||
v := "v2"
|
v := "v2"
|
||||||
|
|
||||||
return &v, 50 * time.Millisecond
|
return &v, 50 * time.Millisecond
|
||||||
|
@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
|
||||||
},
|
},
|
||||||
PrefetchThreshold: 2,
|
PrefetchThreshold: 2,
|
||||||
PrefetchExpires: 100 * time.Millisecond,
|
PrefetchExpires: 100 * time.Millisecond,
|
||||||
ReloadFn: func(cacheKey string) (*string, time.Duration) {
|
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
|
||||||
v := "v2"
|
v := "v2"
|
||||||
|
|
||||||
return &v, 50 * time.Millisecond
|
return &v, 50 * time.Millisecond
|
||||||
|
|
|
@ -154,16 +154,16 @@ func NewBlockingResolver(ctx context.Context,
|
||||||
|
|
||||||
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
|
res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
|
||||||
CleanupInterval: defaultBlockingCleanUpInterval,
|
CleanupInterval: defaultBlockingCleanUpInterval,
|
||||||
}, func(key string) (val *[]net.IP, ttl time.Duration) {
|
}, func(ctx context.Context, key string) (val *[]net.IP, ttl time.Duration) {
|
||||||
return res.queryForFQIdentifierIPs(key)
|
return res.queryForFQIdentifierIPs(ctx, key)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res.redisClient != nil {
|
if res.redisClient != nil {
|
||||||
setupRedisEnabledSubscriber(ctx, res)
|
go res.redisSubscriber(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
|
err = evt.Bus().SubscribeOnce(evt.ApplicationStarted, func(_ ...string) {
|
||||||
go res.initFQDNIPCache()
|
go res.initFQDNIPCache(ctx)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -173,29 +173,27 @@ func NewBlockingResolver(ctx context.Context,
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRedisEnabledSubscriber(ctx context.Context, c *BlockingResolver) {
|
func (r *BlockingResolver) redisSubscriber(ctx context.Context) {
|
||||||
go func() {
|
for {
|
||||||
for {
|
select {
|
||||||
select {
|
case em := <-r.redisClient.EnabledChannel:
|
||||||
case em := <-c.redisClient.EnabledChannel:
|
if em != nil {
|
||||||
if em != nil {
|
r.log().Debug("Received state from redis: ", em)
|
||||||
c.log().Debug("Received state from redis: ", em)
|
|
||||||
|
|
||||||
if em.State {
|
if em.State {
|
||||||
c.internalEnableBlocking()
|
r.internalEnableBlocking()
|
||||||
} else {
|
} else {
|
||||||
err := c.internalDisableBlocking(em.Duration, em.Groups)
|
err := r.internalDisableBlocking(em.Duration, em.Groups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log().Warn("Blocking couldn't be disabled:", err)
|
r.log().Warn("Blocking couldn't be disabled:", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshLists triggers the refresh of all black and white lists in the cache
|
// RefreshLists triggers the refresh of all black and white lists in the cache
|
||||||
|
@ -358,7 +356,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []string,
|
||||||
request *model.Request, logger *logrus.Entry,
|
request *model.Request, logger *logrus.Entry,
|
||||||
) (bool, *model.Response, error) {
|
) (bool, *model.Response, error) {
|
||||||
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
||||||
|
@ -371,7 +369,7 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||||
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
|
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
|
||||||
logger.WithField("groups", groups).Debugf("domain is whitelisted")
|
logger.WithField("groups", groups).Debugf("domain is whitelisted")
|
||||||
|
|
||||||
resp, err := r.next.Resolve(request)
|
resp, err := r.next.Resolve(ctx, request)
|
||||||
|
|
||||||
return true, resp, err
|
return true, resp, err
|
||||||
}
|
}
|
||||||
|
@ -393,18 +391,18 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
||||||
func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, "blacklist_resolver")
|
logger := log.WithPrefix(request.Log, "blacklist_resolver")
|
||||||
groupsToCheck := r.groupsToCheckForClient(request)
|
groupsToCheck := r.groupsToCheckForClient(request)
|
||||||
|
|
||||||
if len(groupsToCheck) > 0 {
|
if len(groupsToCheck) > 0 {
|
||||||
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
handled, resp, err := r.handleBlacklist(ctx, groupsToCheck, request, logger)
|
||||||
if handled {
|
if handled {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
respFromNext, err := r.next.Resolve(request)
|
respFromNext, err := r.next.Resolve(ctx, request)
|
||||||
|
|
||||||
if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
|
if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
|
||||||
for _, rr := range respFromNext.Res.Answer {
|
for _, rr := range respFromNext.Res.Answer {
|
||||||
|
@ -574,7 +572,7 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP, time.Duration) {
|
func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
|
||||||
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
|
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")
|
||||||
|
|
||||||
var result []net.IP
|
var result []net.IP
|
||||||
|
@ -582,7 +580,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
|
||||||
var ttl time.Duration
|
var ttl time.Duration
|
||||||
|
|
||||||
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||||
resp, err := r.next.Resolve(&model.Request{
|
resp, err := r.next.Resolve(ctx, &model.Request{
|
||||||
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
|
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
|
||||||
Log: prefixedLog,
|
Log: prefixedLog,
|
||||||
})
|
})
|
||||||
|
@ -606,12 +604,12 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
|
||||||
return &result, ttl
|
return &result, ttl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BlockingResolver) initFQDNIPCache() {
|
func (r *BlockingResolver) initFQDNIPCache(ctx context.Context) {
|
||||||
identifiers := maps.Keys(r.clientGroupsBlock)
|
identifiers := maps.Keys(r.clientGroupsBlock)
|
||||||
|
|
||||||
for _, identifier := range identifiers {
|
for _, identifier := range identifiers {
|
||||||
if isFQDN(identifier) {
|
if isFQDN(identifier) {
|
||||||
iPs, ttl := r.queryForFQIdentifierIPs(identifier)
|
iPs, ttl := r.queryForFQIdentifierIPs(ctx, identifier)
|
||||||
r.fqdnIPCache.Put(identifier, iPs, ttl)
|
r.fqdnIPCache.Put(identifier, iPs, ttl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
}
|
}
|
||||||
Bus().Publish(ApplicationStarted, "")
|
Bus().Publish(ApplicationStarted, "")
|
||||||
Eventually(func(g Gomega) {
|
Eventually(func(g Gomega) {
|
||||||
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
|
g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
|
||||||
Should(And(
|
Should(And(
|
||||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||||
HaveTTL(BeNumerically("==", 60)),
|
HaveTTL(BeNumerically("==", 60)),
|
||||||
|
@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
When("Domain is on the black list", func() {
|
When("Domain is on the black list", func() {
|
||||||
It("should block request", func() {
|
It("should block request", func() {
|
||||||
Eventually(sut.Resolve).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
|
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
|
@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
When("client name is defined in client groups block", func() {
|
When("client name is defined in client groups block", func() {
|
||||||
It("should block the A query if domain is on the black list (single)", func() {
|
It("should block the A query if domain is on the black list (single)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the A query if domain is on the black list (multipart 1)", func() {
|
It("should block the A query if domain is on the black list (multipart 1)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the A query if domain is on the black list (multipart 2)", func() {
|
It("should block the A query if domain is on the black list (multipart 2)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the A query if domain is on the black list (merged)", func() {
|
It("should block the A query if domain is on the black list (merged)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||||
|
@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the AAAA query if domain is on the black list", func() {
|
It("should block the AAAA query if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", AAAA, "::"),
|
BeDNSRecord("domain1.com.", AAAA, "::"),
|
||||||
|
@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the HTTPS query if domain is on the black list", func() {
|
It("should block the HTTPS query if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))).
|
||||||
Should(HaveReturnCode(dns.RcodeNameError))
|
Should(HaveReturnCode(dns.RcodeNameError))
|
||||||
})
|
})
|
||||||
It("should block the MX query if domain is on the black list", func() {
|
It("should block the MX query if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))).
|
||||||
Should(HaveReturnCode(dns.RcodeNameError))
|
Should(HaveReturnCode(dns.RcodeNameError))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
When("Client ip is defined in client groups block", func() {
|
When("Client ip is defined in client groups block", func() {
|
||||||
It("should block the query if domain is on the black list", func() {
|
It("should block the query if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
||||||
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
|
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
})
|
})
|
||||||
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
|
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
||||||
It("should block the query for 10.43.8.64 if domain is on the black list", func() {
|
It("should block the query for 10.43.8.64 if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block the query for 10.43.8.79 if domain is on the black list", func() {
|
It("should block the query for 10.43.8.79 if domain is on the black list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
When("Client has multiple names and for each name a client group block definition exists", func() {
|
When("Client has multiple names and for each name a client group block definition exists", func() {
|
||||||
It("should block query if domain is in one group", func() {
|
It("should block query if domain is in one group", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("should block query if domain is in another group too", func() {
|
It("should block query if domain is in another group too", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||||
|
@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("Client name matches wildcard", func() {
|
When("Client name matches wildcard", func() {
|
||||||
It("should block query if domain is in one group", func() {
|
It("should block query if domain is in one group", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
When("Default group is defined", func() {
|
When("Default group is defined", func() {
|
||||||
It("should block domains from default group for each client", func() {
|
It("should block domains from default group for each client", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return NXDOMAIN if query is blocked", func() {
|
It("should return NXDOMAIN if query is blocked", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return answer with specified TTL", func() {
|
It("should return answer with specified TTL", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return custom IP with specified TTL", func() {
|
It("should return custom IP with specified TTL", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
||||||
|
@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return ipv4 address for A query if query is blocked", func() {
|
It("should return ipv4 address for A query if query is blocked", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
BeDNSRecord("blocked3.com.", A, "12.12.12.12"),
|
||||||
|
@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return ipv6 address for AAAA query if query is blocked", func() {
|
It("should return ipv6 address for AAAA query if query is blocked", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||||
|
@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should use fallback for ipv6 and return zero ip", func() {
|
It("should use fallback for ipv6 and return zero ip", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", AAAA, "::"),
|
BeDNSRecord("blocked3.com.", AAAA, "::"),
|
||||||
|
@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||||
})
|
})
|
||||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
||||||
|
@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", AAAA, "::"),
|
BeDNSRecord("example.com.", AAAA, "::"),
|
||||||
|
@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
|
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
|
||||||
})
|
})
|
||||||
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
|
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
BeDNSRecord("example.com.", A, "0.0.0.0"),
|
||||||
|
@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("Should not be blocked", func() {
|
It("Should not be blocked", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should block everything else except domains on the white list with default group", func() {
|
It("should block everything else except domains on the white list with default group", func() {
|
||||||
By("querying domain on the whitelist", func() {
|
By("querying domain on the whitelist", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("querying another domain, which is not on the whitelist", func() {
|
By("querying another domain, which is not on the whitelist", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("google.com.", A, "0.0.0.0"),
|
BeDNSRecord("google.com.", A, "0.0.0.0"),
|
||||||
|
@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
It("should block everything else except domains on the white list "+
|
It("should block everything else except domains on the white list "+
|
||||||
"if multiple white list only groups are defined", func() {
|
"if multiple white list only groups are defined", func() {
|
||||||
By("querying domain on the whitelist", func() {
|
By("querying domain on the whitelist", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("querying another domain, which is not on the whitelist", func() {
|
By("querying another domain, which is not on the whitelist", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked2.com.", A, "0.0.0.0"),
|
||||||
|
@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
It("should block everything else except domains on the white list "+
|
It("should block everything else except domains on the white list "+
|
||||||
"if multiple white list only groups are defined", func() {
|
"if multiple white list only groups are defined", func() {
|
||||||
By("querying domain on the whitelist group 1", func() {
|
By("querying domain on the whitelist group 1", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("querying another domain, which is in the whitelist group 1", func() {
|
By("querying another domain, which is in the whitelist group 1", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||||
})
|
})
|
||||||
It("should not block if DNS answer contains IP from the white list", func() {
|
It("should not block if DNS answer contains IP from the white list", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
||||||
|
@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
When("domain is not on the black list", func() {
|
When("domain is not on the black list", func() {
|
||||||
It("should delegate to next resolver", func() {
|
It("should delegate to next resolver", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("should delegate to next resolver", func() {
|
It("should delegate to next resolver", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
When("Disable blocking is called", func() {
|
When("Disable blocking is called", func() {
|
||||||
It("no query should be blocked", func() {
|
It("no query should be blocked", func() {
|
||||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
By("perform the same query again (defaultGroup)", func() {
|
By("perform the same query again (defaultGroup)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
By("perform the same query again (group1)", func() {
|
By("perform the same query again (group1)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
By("perform the same query again (defaultGroup)", func() {
|
By("perform the same query again (defaultGroup)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
When("Disable blocking for all groups is called with a duration parameter", func() {
|
When("Disable blocking for all groups is called with a duration parameter", func() {
|
||||||
It("No query should be blocked only for passed amount of time", func() {
|
It("No query should be blocked only for passed amount of time", func() {
|
||||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
// wait 1 sec
|
// wait 1 sec
|
||||||
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
))
|
))
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
When("Disable blocking for one group is called with a duration parameter", func() {
|
When("Disable blocking for one group is called with a duration parameter", func() {
|
||||||
It("No query should be blocked only for passed amount of time", func() {
|
It("No query should be blocked only for passed amount of time", func() {
|
||||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
|
|
||||||
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
})
|
})
|
||||||
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
||||||
// now is blocking disabled, query the url again
|
// now is blocking disabled, query the url again
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
// wait 1 sec
|
// wait 1 sec
|
||||||
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
Eventually(enabled, "1s").Should(Receive(BeTrue()))
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
BeDNSRecord("blocked3.com.", A, "0.0.0.0"),
|
||||||
|
@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
))
|
))
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
BeDNSRecord("domain1.com.", A, "0.0.0.0"),
|
||||||
|
|
|
@ -106,14 +106,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
|
||||||
hostname := r.upstream.Host
|
hostname := r.upstream.Host
|
||||||
|
|
||||||
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
|
if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
|
||||||
return newIPSet([]net.IP{ip}), nil
|
return newIPSet([]net.IP{ip}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, err := b.resolveUpstream(r, hostname)
|
ips, err := b.resolveUpstream(ctx, r, hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -121,10 +121,10 @@ func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
|
||||||
return newIPSet(ips), nil
|
return newIPSet(ips), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
|
||||||
// Use system resolver if no bootstrap is configured
|
// Use system resolver if no bootstrap is configured
|
||||||
if b.resolver == nil {
|
if b.resolver == nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), b.timeout)
|
ctx, cancel := context.WithTimeout(ctx, b.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
|
return b.systemResolver.LookupIP(ctx, config.GetConfig().ConnectIPVersion.Net(), host)
|
||||||
|
@ -135,7 +135,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.resolve(host, b.connectIPVersion.QTypes())
|
return b.resolve(ctx, host, b.connectIPVersion.QTypes())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
|
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
|
||||||
|
@ -175,7 +175,7 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the host with the bootstrap DNS
|
// Resolve the host with the bootstrap DNS
|
||||||
ips, err := b.resolve(host, qTypes)
|
ips, err := b.resolve(ctx, host, qTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("resolve error: %s", err)
|
logger.Errorf("resolve error: %s", err)
|
||||||
|
|
||||||
|
@ -192,11 +192,11 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
|
||||||
return b.dialer.DialContext(ctx, network, addrWithIP)
|
return b.dialer.DialContext(ctx, network, addrWithIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
func (b *Bootstrap) resolve(ctx context.Context, hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||||
ips = make([]net.IP, 0, len(qTypes))
|
ips = make([]net.IP, 0, len(qTypes))
|
||||||
|
|
||||||
for _, qType := range qTypes {
|
for _, qType := range qTypes {
|
||||||
qIPs, qErr := b.resolveType(hostname, qType)
|
qIPs, qErr := b.resolveType(ctx, hostname, qType)
|
||||||
if qErr != nil {
|
if qErr != nil {
|
||||||
err = multierror.Append(err, qErr)
|
err = multierror.Append(err, qErr)
|
||||||
|
|
||||||
|
@ -213,7 +213,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
|
func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.Type) (ips []net.IP, err error) {
|
||||||
if ip := net.ParseIP(hostname); ip != nil {
|
if ip := net.ParseIP(hostname); ip != nil {
|
||||||
return []net.IP{ip}, nil
|
return []net.IP{ip}, nil
|
||||||
}
|
}
|
||||||
|
@ -223,7 +223,7 @@ func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP,
|
||||||
Log: b.log,
|
Log: b.log,
|
||||||
}
|
}
|
||||||
|
|
||||||
rsp, err := b.resolver.Resolve(&req)
|
rsp, err := b.resolver.Resolve(ctx, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := sut.resolveUpstream(nil, "example.com")
|
_, err := sut.resolveUpstream(ctx, nil, "example.com")
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
Expect(usedSystemResolver).Should(Receive(BeTrue()))
|
Expect(usedSystemResolver).Should(Receive(BeTrue()))
|
||||||
})
|
})
|
||||||
|
@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
When("called from bootstrap.upstream", func() {
|
When("called from bootstrap.upstream", func() {
|
||||||
It("uses hardcoded IPs", func() {
|
It("uses hardcoded IPs", func() {
|
||||||
ips, err := sut.resolveUpstream(bootstrapUpstream, "host")
|
ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host")
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs))
|
Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs))
|
||||||
|
@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
When("hostname is an IP", func() {
|
When("hostname is an IP", func() {
|
||||||
It("returns immediately", func() {
|
It("returns immediately", func() {
|
||||||
ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes())
|
ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes())
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(ips).Should(ContainElement(net.IPv4zero))
|
Expect(ips).Should(ContainElement(net.IPv4zero))
|
||||||
|
@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
||||||
|
|
||||||
ips, err := sut.resolve("localhost", []dns.Type{AAAA})
|
ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA})
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(ips).Should(HaveLen(1))
|
Expect(ips).Should(HaveLen(1))
|
||||||
|
@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
|
bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr)
|
||||||
|
|
||||||
ips, err := sut.resolve("localhost", []dns.Type{A})
|
ips, err := sut.resolve(ctx, "localhost", []dns.Type{A})
|
||||||
|
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
||||||
|
@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil)
|
||||||
|
|
||||||
ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A})
|
ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A})
|
||||||
|
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
||||||
|
@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
r := newUpstreamResolverUnchecked(upstream, sut)
|
r := newUpstreamResolverUnchecked(upstream, sut)
|
||||||
|
|
||||||
rsp, err := r.Resolve(mainReq)
|
rsp, err := r.Resolve(ctx, mainReq)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1))
|
Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1))
|
||||||
Expect(rsp.Res.Question[0].Name).Should(Equal("example.com."))
|
Expect(rsp.Res.Question[0].Name).Should(Equal("example.com."))
|
||||||
|
@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
// implicit expectation of 0 bootstrapUpstream.Resolve calls
|
// implicit expectation of 0 bootstrapUpstream.Resolve calls
|
||||||
|
|
||||||
_, err = t.DialContext(context.Background(), "ip", "!bad-addr!")
|
_, err = t.DialContext(ctx, "ip", "!bad-addr!")
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
t := sut.NewHTTPTransport()
|
t := sut.NewHTTPTransport()
|
||||||
|
|
||||||
_, err = t.DialContext(context.Background(), "ip", "abc:123")
|
_, err = t.DialContext(ctx, "ip", "abc:123")
|
||||||
|
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
Expect(err.Error()).Should(ContainSubstring(resolveErr.Error()))
|
||||||
|
@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
t := sut.NewHTTPTransport()
|
t := sut.NewHTTPTransport()
|
||||||
|
|
||||||
_, err = t.DialContext(context.Background(), "ip", "abc:123")
|
_, err = t.DialContext(ctx, "ip", "abc:123")
|
||||||
|
|
||||||
Expect(err).ShouldNot(Succeed())
|
Expect(err).ShouldNot(Succeed())
|
||||||
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
Expect(err.Error()).Should(ContainSubstring("no such host"))
|
||||||
|
@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
|
|
||||||
Describe("resolve", func() {
|
Describe("resolve", func() {
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
_, err := sut.resolveUpstream(nil, "example.com")
|
_, err := sut.resolveUpstream(ctx, nil, "example.com")
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
t := sut.NewHTTPTransport()
|
t := sut.NewHTTPTransport()
|
||||||
|
|
||||||
conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0")
|
conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0")
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(conn).Should(Equal(aMockConn))
|
Expect(conn).Should(Equal(aMockConn))
|
||||||
|
|
||||||
|
@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("uses both", func() {
|
It("uses both", func() {
|
||||||
_, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)})
|
_, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)})
|
||||||
|
|
||||||
Expect(err).To(Succeed())
|
Expect(err).To(Succeed())
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ func newCachingResolver(ctx context.Context,
|
||||||
configureCaches(ctx, c, &cfg)
|
configureCaches(ctx, c, &cfg)
|
||||||
|
|
||||||
if c.redisClient != nil {
|
if c.redisClient != nil {
|
||||||
setupRedisCacheSubscriber(ctx, c)
|
go c.redisSubscriber(ctx)
|
||||||
c.redisClient.GetRedisCache()
|
c.redisClient.GetRedisCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,14 +105,14 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Duration) {
|
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
|
||||||
qType, domainName := util.ExtractCacheKey(cacheKey)
|
qType, domainName := util.ExtractCacheKey(cacheKey)
|
||||||
logger := r.log()
|
logger := r.log()
|
||||||
|
|
||||||
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)
|
||||||
|
|
||||||
req := newRequest(dns.Fqdn(domainName), qType, logger)
|
req := newRequest(dns.Fqdn(domainName), qType, logger)
|
||||||
response, err := r.next.Resolve(req)
|
response, err := r.next.Resolve(ctx, req)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if response.Res.Rcode == dns.RcodeSuccess {
|
if response.Res.Rcode == dns.RcodeSuccess {
|
||||||
|
@ -132,22 +132,20 @@ func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Durat
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRedisCacheSubscriber(ctx context.Context, c *CachingResolver) {
|
func (r *CachingResolver) redisSubscriber(ctx context.Context) {
|
||||||
go func() {
|
for {
|
||||||
for {
|
select {
|
||||||
select {
|
case rc := <-r.redisClient.CacheChannel:
|
||||||
case rc := <-c.redisClient.CacheChannel:
|
if rc != nil {
|
||||||
if rc != nil {
|
r.log().Debug("Received key from redis: ", rc.Key)
|
||||||
c.log().Debug("Received key from redis: ", rc.Key)
|
ttl := r.adjustTTLs(rc.Response.Res.Answer)
|
||||||
ttl := c.adjustTTLs(rc.Response.Res.Answer)
|
r.putInCache(rc.Key, rc.Response, ttl, false)
|
||||||
c.putInCache(rc.Key, rc.Response, ttl, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogConfig implements `config.Configurable`.
|
// LogConfig implements `config.Configurable`.
|
||||||
|
@ -159,13 +157,13 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
|
||||||
|
|
||||||
// Resolve checks if the current query should use the cache and if the result is already in
|
// Resolve checks if the current query should use the cache and if the result is already in
|
||||||
// the cache and returns it or delegates to the next resolver
|
// the cache and returns it or delegates to the next resolver
|
||||||
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
|
||||||
logger := log.WithPrefix(request.Log, "caching_resolver")
|
logger := log.WithPrefix(request.Log, "caching_resolver")
|
||||||
|
|
||||||
if !r.IsEnabled() || !isRequestCacheable(request) {
|
if !r.IsEnabled() || !isRequestCacheable(request) {
|
||||||
logger.Debug("skip cache")
|
logger.Debug("skip cache")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, question := range request.Req.Question {
|
for _, question := range request.Req.Question {
|
||||||
|
@ -191,7 +189,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
|
||||||
response, err = r.next.Resolve(request)
|
response, err = r.next.Resolve(ctx, request)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
cacheTTL := r.adjustTTLs(response.Res.Answer)
|
||||||
|
@ -319,7 +317,7 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *CachingResolver) FlushCaches() {
|
func (r *CachingResolver) FlushCaches(context.Context) {
|
||||||
r.log().Debug("flush caches")
|
r.log().Debug("flush caches")
|
||||||
r.resultCache.Clear()
|
r.resultCache.Clear()
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,7 +114,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})).Should(Succeed())
|
})).Should(Succeed())
|
||||||
|
|
||||||
// first request
|
// first request
|
||||||
_, _ = sut.Resolve(newRequest("example.com.", A))
|
_, _ = sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
|
|
||||||
// Domain is not prefetched
|
// Domain is not prefetched
|
||||||
Expect(domainPrefetched).ShouldNot(Receive())
|
Expect(domainPrefetched).ShouldNot(Receive())
|
||||||
|
@ -124,7 +124,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
// now query again > threshold
|
// now query again > threshold
|
||||||
for i := 0; i < prefetchThreshold+1; i++ {
|
for i := 0; i < prefetchThreshold+1; i++ {
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
|
Eventually(domainPrefetched, "10s").Should(Receive(Equal(true)))
|
||||||
|
|
||||||
// and it should hit from prefetch cache
|
// and it should hit from prefetch cache
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -156,7 +156,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should cache response and use response's TTL for multiple records", func() {
|
It("should cache response and use response's TTL for multiple records", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
result, err := sut.Resolve(newRequest("example.com.", A))
|
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(result).
|
Expect(result).
|
||||||
Should(
|
Should(
|
||||||
|
@ -176,7 +176,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(func(g Gomega) {
|
Eventually(func(g Gomega) {
|
||||||
result, err := sut.Resolve(newRequest("example.com.", A))
|
result, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
g.Expect(err).Should(Succeed())
|
g.Expect(err).Should(Succeed())
|
||||||
g.Expect(result).
|
g.Expect(result).
|
||||||
Should(
|
Should(
|
||||||
|
@ -218,7 +218,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
|
_ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) {
|
||||||
totalCacheCount <- d
|
totalCacheCount <- d
|
||||||
})
|
})
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -239,7 +239,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
domain <- true
|
domain <- true
|
||||||
})
|
})
|
||||||
|
|
||||||
g.Expect(sut.Resolve(newRequest("example.com.", A))).
|
g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -264,7 +264,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("should cache response and use min caching time as TTL", func() {
|
It("should cache response and use min caching time as TTL", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -276,7 +276,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", A)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -299,7 +301,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("should cache response and use min caching time as TTL", func() {
|
It("should cache response and use min caching time as TTL", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -310,7 +312,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -344,7 +348,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("Shouldn't cache any responses", func() {
|
It("Shouldn't cache any responses", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -355,7 +359,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -378,7 +384,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
|
It("should cache response and use max caching time as TTL if response TTL is bigger", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -389,7 +395,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -417,7 +425,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should cache response and return 0 TTL if entry is expired", func() {
|
It("should cache response and return 0 TTL if entry is expired", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -430,7 +438,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)).
|
Eventually(sut.Resolve, "2s").
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", A)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
@ -461,7 +471,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeNameError),
|
HaveReturnCode(dns.RcodeNameError),
|
||||||
|
@ -472,7 +482,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
HaveReason("CACHED NEGATIVE"),
|
HaveReason("CACHED NEGATIVE"),
|
||||||
|
@ -495,7 +507,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("response shouldn't be cached", func() {
|
It("response shouldn't be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeNameError),
|
HaveReturnCode(dns.RcodeNameError),
|
||||||
|
@ -506,7 +518,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReason(""),
|
HaveReason(""),
|
||||||
|
@ -529,7 +543,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
It("response should be cached", func() {
|
It("response should be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -540,7 +554,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("example.com.", AAAA)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
HaveReason("CACHED"),
|
HaveReason("CACHED"),
|
||||||
|
@ -563,7 +579,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("Should be cached", func() {
|
It("Should be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", MX))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", MX))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -575,7 +591,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("google.de.", MX)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
HaveReason("CACHED"),
|
HaveReason("CACHED"),
|
||||||
|
@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("Should not be cached", func() {
|
It("Should not be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -611,7 +629,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("Should not be cached", func() {
|
It("Should not be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -645,7 +663,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -671,7 +689,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
It("Should not be cached", func() {
|
It("Should not be cached", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.de.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.de.", A))).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
HaveReturnCode(dns.RcodeSuccess),
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
@ -688,7 +706,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(newRequest("google.de.", A)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
HaveReason("CACHED"),
|
HaveReason("CACHED"),
|
||||||
|
@ -750,7 +770,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("put in redis", func() {
|
It("put in redis", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(HaveResponseType(ResponseTypeRESOLVED))
|
Should(HaveResponseType(ResponseTypeRESOLVED))
|
||||||
|
|
||||||
Eventually(func() []string {
|
Eventually(func() []string {
|
||||||
|
@ -772,7 +792,9 @@ var _ = Describe("CachingResolver", func() {
|
||||||
}
|
}
|
||||||
redisClient.CacheChannel <- redisMockMsg
|
redisClient.CacheChannel <- redisMockMsg
|
||||||
|
|
||||||
Eventually(sut.Resolve).WithArguments(request).
|
Eventually(sut.Resolve).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithArguments(request).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeCACHED),
|
HaveResponseType(ResponseTypeCACHED),
|
||||||
|
|
|
@ -32,7 +32,7 @@ func NewClientNamesResolver(ctx context.Context,
|
||||||
) (cr *ClientNamesResolver, err error) {
|
) (cr *ClientNamesResolver, err error) {
|
||||||
var r Resolver
|
var r Resolver
|
||||||
if !cfg.Upstream.IsDefault() {
|
if !cfg.Upstream.IsDefault() {
|
||||||
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
|
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -59,17 +59,17 @@ func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve tries to resolve the client name from the ip address
|
// Resolve tries to resolve the client name from the ip address
|
||||||
func (r *ClientNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
clientNames := r.getClientNames(request)
|
clientNames := r.getClientNames(ctx, request)
|
||||||
|
|
||||||
request.ClientNames = clientNames
|
request.ClientNames = clientNames
|
||||||
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
|
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns names of client
|
// returns names of client
|
||||||
func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model.Request) []string {
|
||||||
if request.RequestClientID != "" {
|
if request.RequestClientID != "" {
|
||||||
return []string{request.RequestClientID}
|
return []string{request.RequestClientID}
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
||||||
return cpy
|
return cpy
|
||||||
}
|
}
|
||||||
|
|
||||||
names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
|
names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))
|
||||||
|
|
||||||
r.cache.Put(ip.String(), &names, time.Hour)
|
r.cache.Put(ip.String(), &names, time.Hour)
|
||||||
|
|
||||||
|
@ -111,7 +111,9 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
|
||||||
}
|
}
|
||||||
|
|
||||||
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
||||||
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
func (r *ClientNamesResolver) resolveClientNames(
|
||||||
|
ctx context.Context, ip net.IP, logger *logrus.Entry,
|
||||||
|
) (result []string) {
|
||||||
// try client mapping first
|
// try client mapping first
|
||||||
result = r.getNameFromIPMapping(ip, result)
|
result = r.getNameFromIPMapping(ip, result)
|
||||||
if len(result) > 0 {
|
if len(result) > 0 {
|
||||||
|
@ -124,7 +126,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
||||||
|
|
||||||
reverse, _ := dns.ReverseAddr(ip.String())
|
reverse, _ := dns.ReverseAddr(ip.String())
|
||||||
|
|
||||||
resp, err := r.externalResolver.Resolve(&model.Request{
|
resp, err := r.externalResolver.Resolve(ctx, &model.Request{
|
||||||
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
||||||
Log: logger,
|
Log: logger,
|
||||||
})
|
})
|
||||||
|
|
|
@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
sut *ClientNamesResolver
|
sut *ClientNamesResolver
|
||||||
sutConfig config.ClientLookupConfig
|
sutConfig config.ClientLookupConfig
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
ctx context.Context
|
|
||||||
cancelFn context.CancelFunc
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
JustBeforeEach(func() {
|
JustBeforeEach(func() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
ctx, cancelFn = context.WithCancel(context.Background())
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
DeferCleanup(cancelFn)
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
|
@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should use clientID if set", func() {
|
It("should use clientID if set", func() {
|
||||||
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
|
request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should use IP as fallback if clientID not set", func() {
|
It("should use IP as fallback if clientID not set", func() {
|
||||||
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
|
request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should resolve defined name with ipv4 address", func() {
|
It("should resolve defined name with ipv4 address", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should resolve defined name with ipv6 address", func() {
|
It("should resolve defined name with ipv6 address", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should resolve multiple names defined names", func() {
|
It("should resolve multiple names defined names", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
It("should resolve client name", func() {
|
It("should resolve client name", func() {
|
||||||
By("first request", func() {
|
By("first request", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
By("second request", func() {
|
By("second request", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
By("third request", func() {
|
By("third request", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should resolve all client names", func() {
|
It("should resolve all client names", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should resolve client name", func() {
|
It("should resolve client name", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should resolve the client name depending to defined order", func() {
|
It("should resolve the client name depending to defined order", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
|
|
||||||
It("should use fallback for client name", func() {
|
It("should use fallback for client name", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should use fallback for client name", func() {
|
It("should use fallback for client name", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should resolve no names", func() {
|
It("should resolve no names", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should use fallback for client name", func() {
|
It("should use fallback for client name", func() {
|
||||||
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25")
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -23,14 +24,14 @@ type ConditionalUpstreamResolver struct {
|
||||||
|
|
||||||
// NewConditionalUpstreamResolver returns new resolver instance
|
// NewConditionalUpstreamResolver returns new resolver instance
|
||||||
func NewConditionalUpstreamResolver(
|
func NewConditionalUpstreamResolver(
|
||||||
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||||
) (*ConditionalUpstreamResolver, error) {
|
) (*ConditionalUpstreamResolver, error) {
|
||||||
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))
|
||||||
|
|
||||||
for domain, upstreams := range cfg.Mapping.Upstreams {
|
for domain, upstreams := range cfg.Mapping.Upstreams {
|
||||||
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
|
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
|
||||||
|
|
||||||
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
|
r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -48,7 +49,9 @@ func NewConditionalUpstreamResolver(
|
||||||
return &r, nil
|
return &r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
|
func (r *ConditionalUpstreamResolver) processRequest(
|
||||||
|
ctx context.Context, request *model.Request,
|
||||||
|
) (bool, *model.Response, error) {
|
||||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||||
domain := domainFromQuestion
|
domain := domainFromQuestion
|
||||||
|
|
||||||
|
@ -56,7 +59,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
||||||
// try with domain with and without sub-domains
|
// try with domain with and without sub-domains
|
||||||
for len(domain) > 0 {
|
for len(domain) > 0 {
|
||||||
if resolver, found := r.mapping[domain]; found {
|
if resolver, found := r.mapping[domain]; found {
|
||||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
|
||||||
|
|
||||||
return true, resp, err
|
return true, resp, err
|
||||||
}
|
}
|
||||||
|
@ -68,7 +71,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if resolver, found := r.mapping["."]; found {
|
} else if resolver, found := r.mapping["."]; found {
|
||||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)
|
||||||
|
|
||||||
return true, resp, err
|
return true, resp, err
|
||||||
}
|
}
|
||||||
|
@ -77,11 +80,11 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve uses the conditional resolver to resolve the query
|
// Resolve uses the conditional resolver to resolve the query
|
||||||
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, "conditional_resolver")
|
logger := log.WithPrefix(request.Log, "conditional_resolver")
|
||||||
|
|
||||||
if len(r.mapping) > 0 {
|
if len(r.mapping) > 0 {
|
||||||
resolved, resp, err := r.processRequest(request)
|
resolved, resp, err := r.processRequest(ctx, request)
|
||||||
if resolved {
|
if resolved {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
@ -89,17 +92,17 @@ func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Re
|
||||||
|
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ConditionalUpstreamResolver) internalResolve(reso Resolver, doFQ, do string,
|
func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso Resolver, doFQ, do string,
|
||||||
req *model.Request,
|
req *model.Request,
|
||||||
) (*model.Response, error) {
|
) (*model.Response, error) {
|
||||||
// internal request resolution
|
// internal request resolution
|
||||||
logger := log.WithPrefix(req.Log, "conditional_resolver")
|
logger := log.WithPrefix(req.Log, "conditional_resolver")
|
||||||
|
|
||||||
req.Req.Question[0].Name = dns.Fqdn(doFQ)
|
req.Req.Question[0].Name = dns.Fqdn(doFQ)
|
||||||
response, err := reso.Resolve(req)
|
response, err := reso.Resolve(ctx, req)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
response.Reason = "CONDITIONAL"
|
response.Reason = "CONDITIONAL"
|
||||||
|
|
|
@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
var (
|
var (
|
||||||
sut *ConditionalUpstreamResolver
|
sut *ConditionalUpstreamResolver
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||||
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
|
response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122")
|
||||||
|
|
||||||
|
@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
})
|
})
|
||||||
DeferCleanup(refuseTestUpstream.Close)
|
DeferCleanup(refuseTestUpstream.Close)
|
||||||
|
|
||||||
sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
|
sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
|
||||||
Mapping: config.ConditionalUpstreamMapping{
|
Mapping: config.ConditionalUpstreamMapping{
|
||||||
Upstreams: map[string][]config.Upstream{
|
Upstreams: map[string][]config.Upstream{
|
||||||
"fritz.box": {fbTestUpstream.Start()},
|
"fritz.box": {fbTestUpstream.Start()},
|
||||||
|
@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
Describe("Resolve conditional DNS queries via defined DNS server", func() {
|
Describe("Resolve conditional DNS queries via defined DNS server", func() {
|
||||||
When("conditional resolver returns error code", func() {
|
When("conditional resolver returns error code", func() {
|
||||||
It("Should be returned without changes", func() {
|
It("Should be returned without changes", func() {
|
||||||
Expect(sut.Resolve(newRequest("refused.domain.", A))).
|
Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
When("Query is exact equal defined condition in mapping", func() {
|
When("Query is exact equal defined condition in mapping", func() {
|
||||||
Context("first mapping entry", func() {
|
Context("first mapping entry", func() {
|
||||||
It("Should resolve the IP of conditional DNS", func() {
|
It("Should resolve the IP of conditional DNS", func() {
|
||||||
Expect(sut.Resolve(newRequest("fritz.box.", A))).
|
Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("fritz.box.", A, "123.124.122.122"),
|
BeDNSRecord("fritz.box.", A, "123.124.122.122"),
|
||||||
|
@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
})
|
})
|
||||||
Context("last mapping entry", func() {
|
Context("last mapping entry", func() {
|
||||||
It("Should resolve the IP of conditional DNS", func() {
|
It("Should resolve the IP of conditional DNS", func() {
|
||||||
Expect(sut.Resolve(newRequest("other.box.", A))).
|
Expect(sut.Resolve(ctx, newRequest("other.box.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("other.box.", A, "192.192.192.192"),
|
BeDNSRecord("other.box.", A, "192.192.192.192"),
|
||||||
|
@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
})
|
})
|
||||||
When("Query is a subdomain of defined condition in mapping", func() {
|
When("Query is a subdomain of defined condition in mapping", func() {
|
||||||
It("Should resolve the IP of subdomain", func() {
|
It("Should resolve the IP of subdomain", func() {
|
||||||
Expect(sut.Resolve(newRequest("test.fritz.box.", A))).
|
Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
|
BeDNSRecord("test.fritz.box.", A, "123.124.122.122"),
|
||||||
|
@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
})
|
})
|
||||||
When("Query is not fqdn and . condition is defined in mapping", func() {
|
When("Query is not fqdn and . condition is defined in mapping", func() {
|
||||||
It("Should resolve the IP of .", func() {
|
It("Should resolve the IP of .", func() {
|
||||||
Expect(sut.Resolve(newRequest("test.", A))).
|
Expect(sut.Resolve(ctx, newRequest("test.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("test.", A, "168.168.168.168"),
|
BeDNSRecord("test.", A, "168.168.168.168"),
|
||||||
|
@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
Describe("Delegation to next resolver", func() {
|
Describe("Delegation to next resolver", func() {
|
||||||
When("Query doesn't match defined mapping", func() {
|
When("Query doesn't match defined mapping", func() {
|
||||||
It("should delegate to next resolver", func() {
|
It("should delegate to next resolver", func() {
|
||||||
Expect(sut.Resolve(newRequest("google.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("google.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu
|
||||||
|
|
||||||
When("upstream is invalid", func() {
|
When("upstream is invalid", func() {
|
||||||
It("errors during construction", func() {
|
It("errors during construction", func() {
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
|
||||||
DeferCleanup(cancelFn)
|
|
||||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||||
|
|
||||||
r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{
|
r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{
|
||||||
Mapping: config.ConditionalUpstreamMapping{
|
Mapping: config.ConditionalUpstreamMapping{
|
||||||
Upstreams: map[string][]config.Upstream{
|
Upstreams: map[string][]config.Upstream{
|
||||||
".": {config.Upstream{Host: "example.com"}},
|
".": {config.Upstream{Host: "example.com"}},
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -123,7 +124,7 @@ func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Respon
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve uses internal mapping to resolve the query
|
// Resolve uses internal mapping to resolve the query
|
||||||
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *CustomDNSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
|
logger := log.WithPrefix(request.Log, "custom_dns_resolver")
|
||||||
|
|
||||||
reverseResp := r.handleReverseDNS(request)
|
reverseResp := r.handleReverseDNS(request)
|
||||||
|
@ -140,5 +141,5 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
|
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
sut *CustomDNSResolver
|
sut *CustomDNSResolver
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
cfg config.CustomDNSConfig
|
cfg config.CustomDNSConfig
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
cfg = config.CustomDNSConfig{
|
cfg = config.CustomDNSConfig{
|
||||||
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
|
Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{
|
||||||
"custom.domain": {net.ParseIP("192.168.143.123")},
|
"custom.domain": {net.ParseIP("192.168.143.123")},
|
||||||
|
@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
Context("filterUnmappedTypes is true", func() {
|
Context("filterUnmappedTypes is true", func() {
|
||||||
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
|
BeforeEach(func() { cfg.FilterUnmappedTypes = true })
|
||||||
It("defined ip4 query should be resolved", func() {
|
It("defined ip4 query should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", A))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
||||||
|
@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||||
})
|
})
|
||||||
It("TXT query for defined mapping should return NOERROR and empty result", func() {
|
It("TXT query for defined mapping should return NOERROR and empty result", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||||
})
|
})
|
||||||
It("ip6 query should return NOERROR and empty result", func() {
|
It("ip6 query should return NOERROR and empty result", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
Context("filterUnmappedTypes is false", func() {
|
Context("filterUnmappedTypes is false", func() {
|
||||||
BeforeEach(func() { cfg.FilterUnmappedTypes = false })
|
BeforeEach(func() { cfg.FilterUnmappedTypes = false })
|
||||||
It("defined ip4 query should be resolved", func() {
|
It("defined ip4 query should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", A))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
BeDNSRecord("custom.domain.", A, "192.168.143.123"),
|
||||||
|
@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything)
|
||||||
})
|
})
|
||||||
It("TXT query for defined mapping should be delegated to next resolver", func() {
|
It("TXT query for defined mapping should be delegated to next resolver", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", TXT))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
})
|
})
|
||||||
It("ip6 query should return NOERROR and empty result", func() {
|
It("ip6 query should return NOERROR and empty result", func() {
|
||||||
Expect(sut.Resolve(newRequest("custom.domain.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
})
|
})
|
||||||
When("Ip 6 mapping is defined for custom domain ", func() {
|
When("Ip 6 mapping is defined for custom domain ", func() {
|
||||||
It("ip6 query should be resolved", func() {
|
It("ip6 query should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||||
|
@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
When("Multiple IPs are defined for custom domain ", func() {
|
When("Multiple IPs are defined for custom domain ", func() {
|
||||||
It("all IPs for the current type should be returned", func() {
|
It("all IPs for the current type should be returned", func() {
|
||||||
By("IPv6 query", func() {
|
By("IPv6 query", func() {
|
||||||
Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"),
|
||||||
|
@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("IPv4 query", func() {
|
By("IPv4 query", func() {
|
||||||
Expect(sut.Resolve(newRequest("multiple.ips.", A))).
|
Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
WithTransform(ToAnswer, SatisfyAll(
|
WithTransform(ToAnswer, SatisfyAll(
|
||||||
|
@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
When("Reverse DNS request is received", func() {
|
When("Reverse DNS request is received", func() {
|
||||||
It("should resolve the defined domain name", func() {
|
It("should resolve the defined domain name", func() {
|
||||||
By("ipv4", func() {
|
By("ipv4", func() {
|
||||||
Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))).
|
Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
WithTransform(ToAnswer, SatisfyAll(
|
WithTransform(ToAnswer, SatisfyAll(
|
||||||
|
@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
By("ipv6", func() {
|
By("ipv6", func() {
|
||||||
Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
|
Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||||
PTR))).
|
PTR))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
|
@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
})
|
})
|
||||||
When("Domain mapping is defined", func() {
|
When("Domain mapping is defined", func() {
|
||||||
It("subdomain must also match", func() {
|
It("subdomain must also match", func() {
|
||||||
Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))).
|
Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
|
BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"),
|
||||||
|
@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() {
|
||||||
Describe("Delegating to next resolver", func() {
|
Describe("Delegating to next resolver", func() {
|
||||||
When("no mapping for domain exist", func() {
|
When("no mapping for domain exist", func() {
|
||||||
It("should delegate to next resolver", func() {
|
It("should delegate to next resolver", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ func NewECSResolver(cfg config.ECS) ChainedResolver {
|
||||||
|
|
||||||
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver
|
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver
|
||||||
// and sets the client IP from the EDNS0 option to the request if this option is enabled
|
// and sets the client IP from the EDNS0 option to the request if this option is enabled
|
||||||
func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *ECSResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
if r.cfg.IsEnabled() {
|
if r.cfg.IsEnabled() {
|
||||||
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
|
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
|
||||||
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
|
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
|
||||||
|
@ -67,7 +68,7 @@ func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setSubnet appends the subnet information to the request as EDNS0 option
|
// setSubnet appends the subnet information to the request as EDNS0 option
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -25,6 +26,9 @@ var _ = Describe("EcsResolver", func() {
|
||||||
err error
|
err error
|
||||||
origIP net.IP
|
origIP net.IP
|
||||||
ecsIP net.IP
|
ecsIP net.IP
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -34,6 +38,9 @@ var _ = Describe("EcsResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
err = defaults.Set(&sutConfig)
|
err = defaults.Set(&sutConfig)
|
||||||
Expect(err).ShouldNot(HaveOccurred())
|
Expect(err).ShouldNot(HaveOccurred())
|
||||||
|
|
||||||
|
@ -86,13 +93,13 @@ var _ = Describe("EcsResolver", func() {
|
||||||
|
|
||||||
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
|
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
|
||||||
|
|
||||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||||
Expect(req.ClientIP).Should(Equal(ecsIP))
|
Expect(req.ClientIP).Should(Equal(ecsIP))
|
||||||
|
|
||||||
return respondWith(mockAnswer), nil
|
return respondWith(mockAnswer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -107,13 +114,13 @@ var _ = Describe("EcsResolver", func() {
|
||||||
|
|
||||||
addEcsOption(request.Req, ecsIP, 24)
|
addEcsOption(request.Req, ecsIP, 24)
|
||||||
|
|
||||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||||
Expect(req.ClientIP).Should(Equal(origIP))
|
Expect(req.ClientIP).Should(Equal(origIP))
|
||||||
|
|
||||||
return respondWith(mockAnswer), nil
|
return respondWith(mockAnswer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -134,14 +141,14 @@ var _ = Describe("EcsResolver", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
request.ClientIP = origIP
|
request.ClientIP = origIP
|
||||||
|
|
||||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||||
Expect(req.ClientIP).Should(Equal(origIP))
|
Expect(req.ClientIP).Should(Equal(origIP))
|
||||||
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
||||||
|
|
||||||
return respondWith(mockAnswer), nil
|
return respondWith(mockAnswer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -154,13 +161,13 @@ var _ = Describe("EcsResolver", func() {
|
||||||
request := newRequest("example.com.", AAAA)
|
request := newRequest("example.com.", AAAA)
|
||||||
request.ClientIP = net.ParseIP("2001:db8::68")
|
request.ClientIP = net.ParseIP("2001:db8::68")
|
||||||
|
|
||||||
m.ResolveFn = func(req *Request) (*Response, error) {
|
m.ResolveFn = func(ctx context.Context, req *Request) (*Response, error) {
|
||||||
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
|
||||||
|
|
||||||
return respondWith(mockAnswer), nil
|
return respondWith(mockAnswer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
"github.com/0xERR0R/blocky/util"
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
@ -25,12 +27,12 @@ func NewEDEResolver(cfg config.EDE) *EDEResolver {
|
||||||
|
|
||||||
// Resolve adds the reason as EDNS0 option to the response of the next resolver
|
// Resolve adds the reason as EDNS0 option to the response of the next resolver
|
||||||
// if it is enabled in the configuration
|
// if it is enabled in the configuration
|
||||||
func (r *EDEResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *EDEResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
if !r.cfg.Enable {
|
if !r.cfg.Enable {
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := r.next.Resolve(request)
|
resp, err := r.next.Resolve(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
@ -24,6 +25,9 @@ var _ = Describe("EdeResolver", func() {
|
||||||
sutConfig config.EDE
|
sutConfig config.EDE
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
mockAnswer *dns.Msg
|
mockAnswer *dns.Msg
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -33,6 +37,9 @@ var _ = Describe("EdeResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
mockAnswer = new(dns.Msg)
|
mockAnswer = new(dns.Msg)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -57,7 +64,7 @@ var _ = Describe("EdeResolver", func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("shouldn't add EDE information", func() {
|
It("shouldn't add EDE information", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -89,7 +96,7 @@ var _ = Describe("EdeResolver", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
It("should add EDE information", func() {
|
It("should add EDE information", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -115,7 +122,7 @@ var _ = Describe("EdeResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("shouldn't add EDE information", func() {
|
It("shouldn't add EDE information", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -137,7 +144,7 @@ var _ = Describe("EdeResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should return it", func() {
|
It("should return it", func() {
|
||||||
resp, err := sut.Resolve(newRequest("example.com", A))
|
resp, err := sut.Resolve(ctx, newRequest("example.com", A))
|
||||||
Expect(resp).To(BeNil())
|
Expect(resp).To(BeNil())
|
||||||
Expect(err).To(Equal(resolveErr))
|
Expect(err).To(Equal(resolveErr))
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -21,7 +23,7 @@ func NewFilteringResolver(cfg config.FilteringConfig) *FilteringResolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *FilteringResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
qType := request.Req.Question[0].Qtype
|
qType := request.Req.Question[0].Qtype
|
||||||
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
|
if r.cfg.QueryTypes.Contains(dns.Type(qType)) {
|
||||||
response := new(dns.Msg)
|
response := new(dns.Msg)
|
||||||
|
@ -30,5 +32,5 @@ func (r *FilteringResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
|
return &model.Response{Res: response, RType: model.ResponseTypeFILTERED}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
. "github.com/0xERR0R/blocky/helpertest"
|
. "github.com/0xERR0R/blocky/helpertest"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() {
|
||||||
sutConfig config.FilteringConfig
|
sutConfig config.FilteringConfig
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
mockAnswer *dns.Msg
|
mockAnswer *dns.Msg
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
mockAnswer = new(dns.Msg)
|
mockAnswer = new(dns.Msg)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("Should delegate to next resolver if request query has other type", func() {
|
It("Should delegate to next resolver if request query has other type", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() {
|
||||||
Expect(m.Calls).Should(HaveLen(1))
|
Expect(m.Calls).Should(HaveLen(1))
|
||||||
})
|
})
|
||||||
It("Should return empty answer for defined query type", func() {
|
It("Should return empty answer for defined query type", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() {
|
||||||
sutConfig = config.FilteringConfig{}
|
sutConfig = config.FilteringConfig{}
|
||||||
})
|
})
|
||||||
It("Should return empty answer without error", func() {
|
It("Should return empty answer without error", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -22,7 +23,7 @@ func NewFQDNOnlyResolver(cfg config.FQDNOnly) *FQDNOnlyResolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *FQDNOnlyResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
if r.IsEnabled() {
|
if r.IsEnabled() {
|
||||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||||
if !strings.Contains(domainFromQuestion, ".") {
|
if !strings.Contains(domainFromQuestion, ".") {
|
||||||
|
@ -33,5 +34,5 @@ func (r *FQDNOnlyResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
. "github.com/0xERR0R/blocky/helpertest"
|
. "github.com/0xERR0R/blocky/helpertest"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
sutConfig config.FQDNOnly
|
sutConfig config.FQDNOnly
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
mockAnswer *dns.Msg
|
mockAnswer *dns.Msg
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
mockAnswer = new(dns.Msg)
|
mockAnswer = new(dns.Msg)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
sutConfig = config.FQDNOnly{Enable: true}
|
sutConfig = config.FQDNOnly{Enable: true}
|
||||||
})
|
})
|
||||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
Expect(m.Calls).Should(HaveLen(1))
|
Expect(m.Calls).Should(HaveLen(1))
|
||||||
})
|
})
|
||||||
It("Should return NXDOMAIN if request query is not fqdn", func() {
|
It("Should return NXDOMAIN if request query is not fqdn", func() {
|
||||||
Expect(sut.Resolve(newRequest("example", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
sutConfig = config.FQDNOnly{Enable: false}
|
sutConfig = config.FQDNOnly{Enable: false}
|
||||||
})
|
})
|
||||||
It("Should delegate to next resolver if request query is fqdn", func() {
|
It("Should delegate to next resolver if request query is fqdn", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() {
|
||||||
Expect(m.Calls).Should(HaveLen(1))
|
Expect(m.Calls).Should(HaveLen(1))
|
||||||
})
|
})
|
||||||
It("Should delegate to next resolver if request query is not fqdn", func() {
|
It("Should delegate to next resolver if request query is not fqdn", func() {
|
||||||
Expect(sut.Resolve(newRequest("example", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("example", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
|
|
@ -109,9 +109,9 @@ func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Resp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *HostsFileResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
if !r.IsEnabled() {
|
if !r.IsEnabled() {
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
reverseResp := r.handleReverseDNS(request)
|
reverseResp := r.handleReverseDNS(request)
|
||||||
|
@ -134,7 +134,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
|
|
||||||
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
r.log().WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {
|
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {
|
||||||
|
|
|
@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
tmpFile *TmpFile
|
tmpFile *TmpFile
|
||||||
err error
|
err error
|
||||||
resp *Response
|
resp *Response
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
tmpDir = NewTmpFolder("HostsFileResolver")
|
tmpDir = NewTmpFolder("HostsFileResolver")
|
||||||
Expect(tmpDir.Error).Should(Succeed())
|
Expect(tmpDir.Error).Should(Succeed())
|
||||||
DeferCleanup(tmpDir.Clean)
|
DeferCleanup(tmpDir.Clean)
|
||||||
|
@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
JustBeforeEach(func() {
|
JustBeforeEach(func() {
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
|
||||||
DeferCleanup(cancelFn)
|
|
||||||
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
|
sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
|
@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
Expect(sut.hosts.isEmpty()).Should(BeTrue())
|
Expect(sut.hosts.isEmpty()).Should(BeTrue())
|
||||||
})
|
})
|
||||||
It("should go to next resolver on query", func() {
|
It("should go to next resolver on query", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
sutConfig.Sources = make([]config.BytesSource, 0)
|
sutConfig.Sources = make([]config.BytesSource, 0)
|
||||||
})
|
})
|
||||||
JustBeforeEach(func() {
|
JustBeforeEach(func() {
|
||||||
err = sut.loadSources(context.Background())
|
err = sut.loadSources(ctx)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
It("should go to next resolver on query", func() {
|
It("should go to next resolver on query", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("IPv4 mapping is defined for a host", func() {
|
When("IPv4 mapping is defined for a host", func() {
|
||||||
It("defined ipv4 query should be resolved", func() {
|
It("defined ipv4 query should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("ipv4host.", A))).
|
Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("defined ipv4 query for alias should be resolved", func() {
|
It("defined ipv4 query for alias should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("router2.", A))).
|
Expect(sut.Resolve(ctx, newRequest("router2.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("ipv4 query should return NOERROR and empty result", func() {
|
It("ipv4 query should return NOERROR and empty result", func() {
|
||||||
Expect(sut.Resolve(newRequest("does.not.exist.", A))).
|
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("IPv6 mapping is defined for a host", func() {
|
When("IPv6 mapping is defined for a host", func() {
|
||||||
It("defined ipv6 query should be resolved", func() {
|
It("defined ipv6 query should be resolved", func() {
|
||||||
Expect(sut.Resolve(newRequest("ipv6host.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
It("ipv6 query should return NOERROR and empty result", func() {
|
It("ipv6 query should return NOERROR and empty result", func() {
|
||||||
Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))).
|
Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("the domain is not known", func() {
|
When("the domain is not known", func() {
|
||||||
It("calls the next resolver", func() {
|
It("calls the next resolver", func() {
|
||||||
resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A))
|
resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("the question type is not handled", func() {
|
When("the question type is not handled", func() {
|
||||||
It("calls the next resolver", func() {
|
It("calls the next resolver", func() {
|
||||||
resp, err = sut.Resolve(newRequest("localhost.", MX))
|
resp, err = sut.Resolve(ctx, newRequest("localhost.", MX))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
When("Reverse DNS request is received", func() {
|
When("Reverse DNS request is received", func() {
|
||||||
It("should resolve the defined domain name", func() {
|
It("should resolve the defined domain name", func() {
|
||||||
By("ipv4 with one hostname", func() {
|
By("ipv4 with one hostname", func() {
|
||||||
Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))).
|
Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("ipv4 with aliases", func() {
|
By("ipv4 with aliases", func() {
|
||||||
Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))).
|
Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("ipv6", func() {
|
By("ipv6", func() {
|
||||||
Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))).
|
Expect(sut.Resolve(ctx,
|
||||||
|
newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)),
|
||||||
|
).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should ignore invalid PTR", func() {
|
It("should ignore invalid PTR", func() {
|
||||||
resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
|
resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("filterLoopback is true", func() {
|
When("filterLoopback is true", func() {
|
||||||
It("calls the next resolver", func() {
|
It("calls the next resolver", func() {
|
||||||
resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR))
|
resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
|
|
||||||
When("the IP is not known", func() {
|
When("the IP is not known", func() {
|
||||||
It("calls the next resolver", func() {
|
It("calls the next resolver", func() {
|
||||||
resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR))
|
resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("resolve the defined domain name", func() {
|
It("resolve the defined domain name", func() {
|
||||||
Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))).
|
Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeHOSTSFILE),
|
HaveResponseType(ResponseTypeHOSTSFILE),
|
||||||
|
@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() {
|
||||||
Describe("Delegating to next resolver", func() {
|
Describe("Delegating to next resolver", func() {
|
||||||
When("no hosts file is provided", func() {
|
When("no hosts file is provided", func() {
|
||||||
It("should delegate to next resolver", func() {
|
It("should delegate to next resolver", func() {
|
||||||
_, err = sut.Resolve(newRequest("example.com.", A))
|
_, err = sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
// delegate was executed
|
// delegate was executed
|
||||||
m.AssertExpectations(GinkgoT())
|
m.AssertExpectations(GinkgoT())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -25,8 +26,8 @@ type MetricsResolver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve resolves the passed request
|
// Resolve resolves the passed request
|
||||||
func (r *MetricsResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *MetricsResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
response, err := r.next.Resolve(request)
|
response, err := r.next.Resolve(ctx, request)
|
||||||
|
|
||||||
if r.cfg.Enable {
|
if r.cfg.Enable {
|
||||||
r.totalQueries.With(prometheus.Labels{
|
r.totalQueries.With(prometheus.Labels{
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() {
|
||||||
var (
|
var (
|
||||||
sut *MetricsResolver
|
sut *MetricsResolver
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
|
sut = NewMetricsResolver(config.MetricsConfig{Enable: true})
|
||||||
m = &mockResolver{}
|
m = &mockResolver{}
|
||||||
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
|
||||||
|
@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() {
|
||||||
Context("Recording request metrics", func() {
|
Context("Recording request metrics", func() {
|
||||||
When("Request will be performed", func() {
|
When("Request will be performed", func() {
|
||||||
It("Should record metrics", func() {
|
It("Should record metrics", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() {
|
||||||
sut.Next(m)
|
sut.Next(m)
|
||||||
})
|
})
|
||||||
It("Error should be recorded", func() {
|
It("Error should be recorded", func() {
|
||||||
_, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))
|
_, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))
|
||||||
|
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/util"
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
@ -62,6 +63,21 @@ func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *MockUDPUpstreamServer) WithDelay(delay time.Duration) *MockUDPUpstreamServer {
|
||||||
|
answerFn := t.answerFn
|
||||||
|
if answerFn == nil {
|
||||||
|
panic("WithDelay must be called after a WithAnswer function")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.answerFn = func(request *dns.Msg) *dns.Msg {
|
||||||
|
time.Sleep(delay)
|
||||||
|
|
||||||
|
return answerFn(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
func (t *MockUDPUpstreamServer) GetCallCount() int {
|
func (t *MockUDPUpstreamServer) GetCallCount() int {
|
||||||
return int(atomic.LoadInt32(&t.callCount))
|
return int(atomic.LoadInt32(&t.callCount))
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ type mockResolver struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
NextResolver
|
NextResolver
|
||||||
|
|
||||||
ResolveFn func(req *model.Request) (*model.Response, error)
|
ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error)
|
||||||
ResponseFn func(req *dns.Msg) *dns.Msg
|
ResponseFn func(req *dns.Msg) *dns.Msg
|
||||||
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
|
AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) {
|
||||||
r.Called()
|
r.Called()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) {
|
func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
args := r.Called(req)
|
args := r.Called(req)
|
||||||
|
|
||||||
if r.ResolveFn != nil {
|
if r.ResolveFn != nil {
|
||||||
return r.ResolveFn(req)
|
return r.ResolveFn(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.ResponseFn != nil {
|
if r.ResponseFn != nil {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
@ -28,6 +30,6 @@ func (NoOpResolver) IsEnabled() bool {
|
||||||
func (NoOpResolver) LogConfig(*logrus.Entry) {
|
func (NoOpResolver) LogConfig(*logrus.Entry) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (NoOpResolver) Resolve(*model.Request) (*model.Response, error) {
|
func (NoOpResolver) Resolve(context.Context, *model.Request) (*model.Response, error) {
|
||||||
return NoResponse, nil
|
return NoResponse, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
. "github.com/0xERR0R/blocky/helpertest"
|
. "github.com/0xERR0R/blocky/helpertest"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
@ -8,7 +10,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("NoOpResolver", func() {
|
var _ = Describe("NoOpResolver", func() {
|
||||||
var sut *NoOpResolver
|
var (
|
||||||
|
sut *NoOpResolver
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
It("follows conventions", func() {
|
It("follows conventions", func() {
|
||||||
|
@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
sut = NewNoOpResolver()
|
sut = NewNoOpResolver()
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Resolving", func() {
|
Describe("Resolving", func() {
|
||||||
It("returns no response", func() {
|
It("returns no response", func() {
|
||||||
resp, err := sut.Resolve(newRequest("test.tld", A))
|
resp, err := sut.Resolve(ctx, newRequest("test.tld", A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).Should(Equal(NoResponse))
|
Expect(resp).Should(Equal(NoResponse))
|
||||||
})
|
})
|
||||||
|
|
|
@ -53,17 +53,23 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus {
|
||||||
return status
|
return status
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *upstreamResolverStatus) resolve(req *model.Request, ch chan<- requestResponse) {
|
func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
resp, err := r.resolver.Resolve(req)
|
resp, err := r.resolver.Resolve(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Ignore `Canceled`: resolver lost the race, not an error
|
// Ignore `Canceled`: resolver lost the race, not an error
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
r.lastErrorTime.Store(time.Now())
|
r.lastErrorTime.Store(time.Now())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fmt.Errorf("%s: %w", r.resolver, err)
|
return nil, fmt.Errorf("%s: %w", r.resolver, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *upstreamResolverStatus) resolveToChan(ctx context.Context, req *model.Request, ch chan<- requestResponse) {
|
||||||
|
resp, err := r.resolve(ctx, req)
|
||||||
|
|
||||||
ch <- requestResponse{
|
ch <- requestResponse{
|
||||||
resolver: &r.resolver,
|
resolver: &r.resolver,
|
||||||
response: resp,
|
response: resp,
|
||||||
|
@ -78,10 +84,10 @@ type requestResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// testResolver sends a test query to verify the resolver is reachable and working
|
// testResolver sends a test query to verify the resolver is reachable and working
|
||||||
func testResolver(r *UpstreamResolver) error {
|
func testResolver(ctx context.Context, r *UpstreamResolver) error {
|
||||||
request := newRequest("github.com.", dns.Type(dns.TypeA))
|
request := newRequest("github.com.", dns.Type(dns.TypeA))
|
||||||
|
|
||||||
resp, err := r.Resolve(request)
|
resp, err := r.Resolve(ctx, request)
|
||||||
if err != nil || resp.RType != model.ResponseTypeRESOLVED {
|
if err != nil || resp.RType != model.ResponseTypeRESOLVED {
|
||||||
return fmt.Errorf("test resolve of upstream server failed: %w", err)
|
return fmt.Errorf("test resolve of upstream server failed: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -91,11 +97,11 @@ func testResolver(r *UpstreamResolver) error {
|
||||||
|
|
||||||
// NewParallelBestResolver creates new resolver instance
|
// NewParallelBestResolver creates new resolver instance
|
||||||
func NewParallelBestResolver(
|
func NewParallelBestResolver(
|
||||||
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||||
) (*ParallelBestResolver, error) {
|
) (*ParallelBestResolver, error) {
|
||||||
logger := log.PrefixedLog(parallelResolverType)
|
logger := log.PrefixedLog(parallelResolverType)
|
||||||
|
|
||||||
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
|
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -150,26 +156,18 @@ func (r *ParallelBestResolver) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
// Resolve sends the query request to multiple upstream resolvers and returns the fastest result
|
||||||
func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *ParallelBestResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, parallelResolverType)
|
logger := log.WithPrefix(request.Log, parallelResolverType)
|
||||||
|
|
||||||
if len(r.resolvers) == 1 {
|
if len(r.resolvers) == 1 {
|
||||||
logger.WithField("resolver", r.resolvers[0].resolver).Debug("delegating to resolver")
|
resolver := r.resolvers[0]
|
||||||
|
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
|
||||||
|
|
||||||
return r.resolvers[0].resolver.Resolve(request)
|
return resolver.resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel() // abort requests to resolvers that lost the race
|
||||||
// using context with timeout for random upstream strategy
|
|
||||||
if r.resolverCount == 1 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
|
|
||||||
timeout := config.GetConfig().Upstreams.Timeout
|
|
||||||
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout))
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvers := pickRandom(r.resolvers, r.resolverCount)
|
resolvers := pickRandom(r.resolvers, r.resolverCount)
|
||||||
ch := make(chan requestResponse, len(resolvers))
|
ch := make(chan requestResponse, len(resolvers))
|
||||||
|
@ -177,10 +175,10 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
for _, resolver := range resolvers {
|
for _, resolver := range resolvers {
|
||||||
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
|
logger.WithField("resolver", resolver.resolver).Debug("delegating to resolver")
|
||||||
|
|
||||||
go resolver.resolve(request, ch)
|
go resolver.resolveToChan(ctx, request, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, collectedErrors := evaluateResponses(ctx, logger, ch, resolvers)
|
response, collectedErrors := evaluateResponses(logger, ch, resolvers)
|
||||||
if response != nil {
|
if response != nil {
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
@ -189,63 +187,51 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
|
return nil, fmt.Errorf("resolution failed: %w", errors.Join(collectedErrors...))
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.retryWithDifferent(logger, request, resolvers)
|
return r.retryWithDifferent(ctx, logger, request, resolvers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func evaluateResponses(
|
func evaluateResponses(
|
||||||
ctx context.Context, logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
|
logger *logrus.Entry, ch chan requestResponse, resolvers []*upstreamResolverStatus,
|
||||||
) (*model.Response, []error) {
|
) (*model.Response, []error) {
|
||||||
collectedErrors := make([]error, 0, len(resolvers))
|
collectedErrors := make([]error, 0, len(resolvers))
|
||||||
|
|
||||||
for len(collectedErrors) < len(resolvers) {
|
for len(collectedErrors) < len(resolvers) {
|
||||||
select {
|
result := <-ch
|
||||||
case <-ctx.Done():
|
logger := logger.WithField("resolver", *result.resolver)
|
||||||
// this context currently only has a deadline when resolverCount == 1
|
|
||||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
|
||||||
logger.WithField("resolver", resolvers[0].resolver).
|
|
||||||
Debug("upstream exceeded timeout, trying other upstream")
|
|
||||||
resolvers[0].lastErrorTime.Store(time.Now())
|
|
||||||
}
|
|
||||||
case result := <-ch:
|
|
||||||
if result.err != nil {
|
|
||||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
|
||||||
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
|
|
||||||
} else {
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"resolver": *result.resolver,
|
|
||||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
|
||||||
}).Debug("using response from resolver")
|
|
||||||
|
|
||||||
return result.response, nil
|
if result.err != nil {
|
||||||
}
|
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
||||||
|
collectedErrors = append(collectedErrors, fmt.Errorf("resolver: %q error: %w", *result.resolver, result.err))
|
||||||
|
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.WithField("answer", util.AnswerToString(result.response.Res.Answer)).Debug("using response from resolver")
|
||||||
|
|
||||||
|
return result.response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, collectedErrors
|
return nil, collectedErrors
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ParallelBestResolver) retryWithDifferent(
|
func (r *ParallelBestResolver) retryWithDifferent(
|
||||||
logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
|
ctx context.Context, logger *logrus.Entry, request *model.Request, resolvers []*upstreamResolverStatus,
|
||||||
) (*model.Response, error) {
|
) (*model.Response, error) {
|
||||||
// second try (if retryWithDifferentResolver == true)
|
// second try (if retryWithDifferentResolver == true)
|
||||||
resolver := weightedRandom(r.resolvers, resolvers)
|
resolver := weightedRandom(r.resolvers, resolvers)
|
||||||
logger.Debugf("using %s as second resolver", resolver.resolver)
|
logger.Debugf("using %s as second resolver", resolver.resolver)
|
||||||
|
|
||||||
ch := make(chan requestResponse, 1)
|
resp, err := resolver.resolve(ctx, request)
|
||||||
|
if err != nil {
|
||||||
resolver.resolve(request, ch)
|
return nil, fmt.Errorf("resolution retry failed: %w", err)
|
||||||
|
|
||||||
result := <-ch
|
|
||||||
if result.err != nil {
|
|
||||||
return nil, fmt.Errorf("resolution retry failed: %w", result.err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.WithFields(logrus.Fields{
|
logger.WithFields(logrus.Fields{
|
||||||
"resolver": *result.resolver,
|
"resolver": *resolver,
|
||||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
"answer": util.AnswerToString(resp.Res.Answer),
|
||||||
}).Debug("using response from resolver")
|
}).Debug("using response from resolver")
|
||||||
|
|
||||||
return result.response, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
|
// pickRandom picks n (resolverCount) different random resolvers from the given resolver pool
|
||||||
|
|
|
@ -19,6 +19,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
const (
|
const (
|
||||||
verifyUpstreams = true
|
verifyUpstreams = true
|
||||||
noVerifyUpstreams = false
|
noVerifyUpstreams = false
|
||||||
|
|
||||||
|
timeout = 50 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -40,6 +42,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
old := config.GetConfig().Upstreams.Timeout
|
||||||
|
DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old })
|
||||||
|
|
||||||
|
config.GetConfig().Upstreams.Timeout = config.Duration(timeout)
|
||||||
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
|
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
|
||||||
|
|
||||||
ctx, cancelFn = context.WithCancel(context.Background())
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
@ -58,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
Upstreams: upstreams,
|
Upstreams: upstreams,
|
||||||
}
|
}
|
||||||
|
|
||||||
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
|
sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify)
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("IsEnabled", func() {
|
Describe("IsEnabled", func() {
|
||||||
|
@ -98,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
mockUpstream.Start(),
|
mockUpstream.Start(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := NewParallelBestResolver(config.UpstreamGroup{
|
_, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||||
Name: upstreamDefaultCfgName,
|
Name: upstreamDefaultCfgName,
|
||||||
Upstreams: upstreams,
|
Upstreams: upstreams,
|
||||||
},
|
},
|
||||||
|
@ -133,6 +139,43 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
When("upstream is too slow", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
timeoutUpstream := NewMockUDPUpstreamServer().
|
||||||
|
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||||
|
WithDelay(2 * timeout)
|
||||||
|
DeferCleanup(timeoutUpstream.Close)
|
||||||
|
|
||||||
|
upstreams = []config.Upstream{timeoutUpstream.Start()}
|
||||||
|
})
|
||||||
|
|
||||||
|
When("strict checking is enabled", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sutVerify = verifyUpstreams
|
||||||
|
})
|
||||||
|
It("should fail to start", func() {
|
||||||
|
Expect(err).Should(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
When("strict checking is disabled", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
sutVerify = noVerifyUpstreams
|
||||||
|
})
|
||||||
|
It("should start", func() {
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
})
|
||||||
|
It("should not resolve", func() {
|
||||||
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
|
request := newRequest("example.com.", A)
|
||||||
|
_, err := sut.Resolve(ctx, request)
|
||||||
|
Expect(err).Should(HaveOccurred())
|
||||||
|
Expect(isTimeout(err)).Should(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
Describe("Resolving result from fastest upstream resolver", func() {
|
Describe("Resolving result from fastest upstream resolver", func() {
|
||||||
When("2 Upstream resolvers are defined", func() {
|
When("2 Upstream resolvers are defined", func() {
|
||||||
When("one resolver is fast and another is slow", func() {
|
When("one resolver is fast and another is slow", func() {
|
||||||
|
@ -140,21 +183,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
fastTestUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(fastTestUpstream.Close)
|
DeferCleanup(fastTestUpstream.Close)
|
||||||
|
|
||||||
slowTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
slowTestUpstream := NewMockUDPUpstreamServer().
|
||||||
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
|
WithAnswerRR("example.com 123 IN A 123.124.122.123").
|
||||||
time.Sleep(50 * time.Millisecond)
|
WithDelay(timeout / 2)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
return response
|
|
||||||
})
|
|
||||||
DeferCleanup(slowTestUpstream.Close)
|
DeferCleanup(slowTestUpstream.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
|
upstreams = []config.Upstream{fastTestUpstream.Start(), slowTestUpstream.Start()}
|
||||||
})
|
})
|
||||||
It("Should use result from fastest one", func() {
|
It("Should use result from fastest one", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -167,21 +205,16 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
When("one resolver is slow, but another returns an error", func() {
|
When("one resolver is slow, but another returns an error", func() {
|
||||||
var slowTestUpstream *MockUDPUpstreamServer
|
var slowTestUpstream *MockUDPUpstreamServer
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
slowTestUpstream = NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
slowTestUpstream = NewMockUDPUpstreamServer().
|
||||||
response, err := util.NewMsgWithAnswer("example.com.", 123, A, "123.124.122.123")
|
WithAnswerRR("example.com 123 IN A 123.124.122.123").
|
||||||
time.Sleep(50 * time.Millisecond)
|
WithDelay(timeout / 2)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
|
||||||
|
|
||||||
return response
|
|
||||||
})
|
|
||||||
DeferCleanup(slowTestUpstream.Close)
|
DeferCleanup(slowTestUpstream.Close)
|
||||||
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
|
upstreams = []config.Upstream{{Host: "wrong"}, slowTestUpstream.Start()}
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
It("Should use result from successful resolver", func() {
|
It("Should use result from successful resolver", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||||
|
@ -202,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
It("Should return error", func() {
|
It("Should return error", func() {
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
_, err = sut.Resolve(request)
|
_, err = sut.Resolve(ctx, request)
|
||||||
|
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
@ -218,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
It("Should use result from defined resolver", func() {
|
It("Should use result from defined resolver", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -242,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(mockUpstream2.Close)
|
DeferCleanup(mockUpstream2.Close)
|
||||||
|
|
||||||
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
|
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||||
Name: upstreamDefaultCfgName,
|
Name: upstreamDefaultCfgName,
|
||||||
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||||
},
|
},
|
||||||
|
@ -268,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
By("perform 100 request, error upstream's weight will be reduced", func() {
|
By("perform 100 request, error upstream's weight will be reduced", func() {
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
_, _ = sut.Resolve(request)
|
_, _ = sut.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -302,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
It("errors during construction", func() {
|
It("errors during construction", func() {
|
||||||
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})
|
||||||
|
|
||||||
r, err := NewParallelBestResolver(config.UpstreamGroup{
|
r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Upstreams: []config.Upstream{{Host: "example.com"}},
|
Upstreams: []config.Upstream{{Host: "example.com"}},
|
||||||
}, b, verifyUpstreams)
|
}, b, verifyUpstreams)
|
||||||
|
@ -313,11 +346,8 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("random resolver strategy", func() {
|
Describe("random resolver strategy", func() {
|
||||||
const timeout = config.Duration(time.Second)
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
|
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom
|
||||||
config.GetConfig().Upstreams.Timeout = timeout
|
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Name", func() {
|
Describe("Name", func() {
|
||||||
|
@ -342,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
It("Should return result from either one", func() {
|
It("Should return result from either one", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(SatisfyAll(
|
Should(SatisfyAll(
|
||||||
HaveTTL(BeNumerically("==", 123)),
|
HaveTTL(BeNumerically("==", 123)),
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -356,24 +386,19 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
When("one upstream exceeds timeout", func() {
|
When("one upstream exceeds timeout", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
timeoutUpstream := NewMockUDPUpstreamServer().
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||||
time.Sleep(time.Duration(timeout) + 2*time.Second)
|
WithDelay(2 * timeout)
|
||||||
|
DeferCleanup(timeoutUpstream.Close)
|
||||||
Expect(err).To(Succeed())
|
|
||||||
|
|
||||||
return response
|
|
||||||
})
|
|
||||||
DeferCleanup(testUpstream1.Close)
|
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.2")
|
||||||
DeferCleanup(testUpstream2.Close)
|
DeferCleanup(testUpstream2.Close)
|
||||||
|
|
||||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
upstreams = []config.Upstream{timeoutUpstream.Start(), testUpstream2.Start()}
|
||||||
})
|
})
|
||||||
It("should ask a other random upstream and return its response", func() {
|
It("should ask another random upstream and return its response", func() {
|
||||||
request := newRequest("example.com", A)
|
request := newRequest("example.com", A)
|
||||||
Expect(sut.Resolve(request)).Should(
|
Expect(sut.Resolve(ctx, request)).Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
||||||
HaveTTL(BeNumerically("==", 123)),
|
HaveTTL(BeNumerically("==", 123)),
|
||||||
|
@ -382,57 +407,25 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
When("two upstreams exceed timeout", func() {
|
When("all upstreams exceed timeout", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
testUpstream1 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
testUpstream1 := NewMockUDPUpstreamServer().
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.1")
|
WithAnswerRR("example.com 123 IN A 123.124.122.1").
|
||||||
time.Sleep(timeout.ToDuration() + 2*time.Second)
|
WithDelay(2 * timeout)
|
||||||
|
|
||||||
Expect(err).To(Succeed())
|
|
||||||
|
|
||||||
return response
|
|
||||||
})
|
|
||||||
DeferCleanup(testUpstream1.Close)
|
DeferCleanup(testUpstream1.Close)
|
||||||
|
|
||||||
testUpstream2 := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
testUpstream2 := NewMockUDPUpstreamServer().
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.2")
|
WithAnswerRR("example.com 123 IN A 123.124.122.2").
|
||||||
time.Sleep(timeout.ToDuration() + 2*time.Second)
|
WithDelay(2 * timeout)
|
||||||
|
|
||||||
Expect(err).To(Succeed())
|
|
||||||
|
|
||||||
return response
|
|
||||||
})
|
|
||||||
DeferCleanup(testUpstream2.Close)
|
DeferCleanup(testUpstream2.Close)
|
||||||
|
|
||||||
testUpstream3 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.3")
|
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start()}
|
||||||
DeferCleanup(testUpstream3.Close)
|
|
||||||
|
|
||||||
upstreams = []config.Upstream{testUpstream1.Start(), testUpstream2.Start(), testUpstream3.Start()}
|
|
||||||
})
|
})
|
||||||
// These two tests are flaky -_- (maybe recreate the RandomResolver )
|
It("Should return error", func() {
|
||||||
It("should not return error (due to random selection the request could to through)", func() {
|
request := newRequest("example.com.", A)
|
||||||
Eventually(func() error {
|
_, err := sut.Resolve(ctx, request)
|
||||||
request := newRequest("example.com", A)
|
Expect(err).Should(HaveOccurred())
|
||||||
_, err := sut.Resolve(request)
|
Expect(isTimeout(err)).Should(BeTrue())
|
||||||
|
|
||||||
return err
|
|
||||||
}).WithTimeout(30 * time.Second).
|
|
||||||
Should(Not(HaveOccurred()))
|
|
||||||
})
|
|
||||||
It("should return error (because it can be possible that the two broken upstreams are chosen)", func() {
|
|
||||||
Eventually(func() error {
|
|
||||||
sutConfig := config.UpstreamGroup{
|
|
||||||
Name: upstreamDefaultCfgName,
|
|
||||||
Upstreams: upstreams,
|
|
||||||
}
|
|
||||||
sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify)
|
|
||||||
|
|
||||||
request := newRequest("example.com", A)
|
|
||||||
_, err := sut.Resolve(request)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}).WithTimeout(30 * time.Second).
|
|
||||||
Should(HaveOccurred())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -446,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
})
|
})
|
||||||
It("Should return error", func() {
|
It("Should return error", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
_, err := sut.Resolve(request)
|
_, err := sut.Resolve(ctx, request)
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -461,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
It("Should use result from defined resolver", func() {
|
It("Should use result from defined resolver", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -485,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
DeferCleanup(mockUpstream2.Close)
|
DeferCleanup(mockUpstream2.Close)
|
||||||
|
|
||||||
sut, _ = NewParallelBestResolver(config.UpstreamGroup{
|
sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{
|
||||||
Name: upstreamDefaultCfgName,
|
Name: upstreamDefaultCfgName,
|
||||||
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2},
|
||||||
},
|
},
|
||||||
|
@ -506,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
By("perform 100 request, error upstream's weight will be reduced", func() {
|
By("perform 100 request, error upstream's weight will be reduced", func() {
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
_, _ = sut.Resolve(request)
|
_, _ = sut.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -111,12 +111,12 @@ func (r *QueryLoggingResolver) doCleanUp() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve logs the query, duration and the result
|
// Resolve logs the query, duration and the result
|
||||||
func (r *QueryLoggingResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *QueryLoggingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
|
logger := log.WithPrefix(request.Log, queryLoggingResolverType)
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
resp, err := r.next.Resolve(request)
|
resp, err := r.next.Resolve(ctx, request)
|
||||||
|
|
||||||
duration := time.Since(start).Milliseconds()
|
duration := time.Since(start).Milliseconds()
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
tmpDir *TmpFolder
|
tmpDir *TmpFolder
|
||||||
mockAnswer *dns.Msg
|
mockAnswer *dns.Msg
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
mockAnswer = new(dns.Msg)
|
mockAnswer = new(dns.Msg)
|
||||||
tmpDir = NewTmpFolder("queryLoggingResolver")
|
tmpDir = NewTmpFolder("queryLoggingResolver")
|
||||||
Expect(tmpDir.Error).Should(Succeed())
|
Expect(tmpDir.Error).Should(Succeed())
|
||||||
|
@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
sutConfig.SetDefaults() // not called when using a struct literal
|
sutConfig.SetDefaults() // not called when using a struct literal
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
|
||||||
DeferCleanup(cancelFn)
|
|
||||||
|
|
||||||
sut = NewQueryLoggingResolver(ctx, sutConfig)
|
sut = NewQueryLoggingResolver(ctx, sutConfig)
|
||||||
m = &mockResolver{}
|
m = &mockResolver{}
|
||||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
|
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil)
|
||||||
|
@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("should process request without query logging", func() {
|
It("should process request without query logging", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should create a log file per client", func() {
|
It("should create a log file per client", func() {
|
||||||
By("request from client 1", func() {
|
By("request from client 1", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("request from client 2, has name with special chars, should be escaped", func() {
|
By("request from client 2, has name with special chars, should be escaped", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient(
|
Expect(sut.Resolve(ctx, newRequestWithClient(
|
||||||
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
|
"example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
|
@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should create one log file for all clients", func() {
|
It("should create one log file for all clients", func() {
|
||||||
By("request from client 1", func() {
|
By("request from client 1", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
By("request from client 2, has name with special chars, should be escaped", func() {
|
By("request from client 2, has name with special chars, should be escaped", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
})
|
})
|
||||||
It("should create one log file", func() {
|
It("should create one log file", func() {
|
||||||
By("request from client 1", func() {
|
By("request from client 1", func() {
|
||||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeRESOLVED),
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() {
|
||||||
sut.writer = mockWriter
|
sut.writer = mockWriter
|
||||||
|
|
||||||
Eventually(func() int {
|
Eventually(func() int {
|
||||||
_, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
|
_, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))
|
||||||
Expect(ierr).Should(Succeed())
|
Expect(ierr).Should(Succeed())
|
||||||
|
|
||||||
return len(sut.logChan)
|
return len(sut.logChan)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
@ -74,7 +75,7 @@ type Resolver interface {
|
||||||
Type() string
|
Type() string
|
||||||
|
|
||||||
// Resolve performs resolution of a DNS request
|
// Resolve performs resolution of a DNS request
|
||||||
Resolve(req *model.Request) (*model.Response, error)
|
Resolve(ctx context.Context, req *model.Request) (*model.Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChainedResolver represents a resolver, which can delegate result to the next one
|
// ChainedResolver represents a resolver, which can delegate result to the next one
|
||||||
|
@ -216,13 +217,14 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func createResolvers(
|
func createResolvers(
|
||||||
logger *logrus.Entry, cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
|
ctx context.Context, logger *logrus.Entry,
|
||||||
|
cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool,
|
||||||
) ([]Resolver, error) {
|
) ([]Resolver, error) {
|
||||||
resolvers := make([]Resolver, 0, len(cfg.Upstreams))
|
resolvers := make([]Resolver, 0, len(cfg.Upstreams))
|
||||||
hasValidResolvers := false
|
hasValidResolvers := false
|
||||||
|
|
||||||
for _, u := range cfg.Upstreams {
|
for _, u := range cfg.Upstreams {
|
||||||
resolver, err := NewUpstreamResolver(u, bootstrap, shoudVerifyUpstreams)
|
resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warnf("upstream group %s: %v", cfg.Name, err)
|
logger.Warnf("upstream group %s: %v", cfg.Name, err)
|
||||||
|
|
||||||
|
@ -230,7 +232,7 @@ func createResolvers(
|
||||||
}
|
}
|
||||||
|
|
||||||
if shoudVerifyUpstreams {
|
if shoudVerifyUpstreams {
|
||||||
err = testResolver(resolver)
|
err = testResolver(ctx, resolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ func (r *RewriterResolver) LogConfig(logger *logrus.Entry) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve uses the inner resolver to resolve the rewritten query
|
// Resolve uses the inner resolver to resolve the rewritten query
|
||||||
func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, "rewriter_resolver")
|
logger := log.WithPrefix(request.Log, "rewriter_resolver")
|
||||||
|
|
||||||
original := request.Req
|
original := request.Req
|
||||||
|
@ -69,7 +70,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
|
|
||||||
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
|
logger.WithField("resolver", Name(r.inner)).Trace("go to inner resolver")
|
||||||
|
|
||||||
response, err := r.inner.Resolve(request)
|
response, err := r.inner.Resolve(ctx, request)
|
||||||
// Test for error after checking for fallbackUpstream
|
// Test for error after checking for fallbackUpstream
|
||||||
|
|
||||||
// Revert the request: must be done before calling r.next
|
// Revert the request: must be done before calling r.next
|
||||||
|
@ -80,7 +81,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
|
// Inner resolver had no answer, configuration requests fallback, continue with the normal chain
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("fallback to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -91,7 +92,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
// Inner resolver had no response, continue with the normal chain
|
// Inner resolver had no response, continue with the normal chain
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Revert the rewrite in r.inner's response
|
// Revert the rewrite in r.inner's response
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
|
@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := sut.Resolve(request)
|
resp, err := sut.Resolve(context.Background(), request)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
if resp != mNextResponse {
|
if resp != mNextResponse {
|
||||||
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
|
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||||
|
@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
expectNilAnswer = true
|
expectNilAnswer = true
|
||||||
|
|
||||||
// Make inner call the NoOpResolver
|
// Make inner call the NoOpResolver
|
||||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
Expect(req).Should(Equal(request))
|
Expect(req).Should(Equal(request))
|
||||||
|
|
||||||
// Inner should see fqdnRewritten
|
// Inner should see fqdnRewritten
|
||||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
|
Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten))
|
||||||
|
|
||||||
return mInner.next.Resolve(req)
|
return mInner.next.Resolve(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolver after RewriterResolver should see `fqdnOriginal`
|
// Resolver after RewriterResolver should see `fqdnOriginal`
|
||||||
mNext.On("Resolve", mock.Anything)
|
mNext.On("Resolve", mock.Anything)
|
||||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||||
|
|
||||||
return mNextResponse, nil
|
return mNextResponse, nil
|
||||||
|
@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
expectNilAnswer = true
|
expectNilAnswer = true
|
||||||
|
|
||||||
// Make inner return a nil Answer but not an empty Response
|
// Make inner return a nil Answer but not an empty Response
|
||||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
Expect(req).Should(Equal(request))
|
Expect(req).Should(Equal(request))
|
||||||
|
|
||||||
// Inner should see fqdnRewritten
|
// Inner should see fqdnRewritten
|
||||||
|
@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
fqdnRewritten = sampleRewritten
|
fqdnRewritten = sampleRewritten
|
||||||
|
|
||||||
// Make inner return a nil Answer but not an empty Response
|
// Make inner return a nil Answer but not an empty Response
|
||||||
mInner.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
Expect(req).Should(Equal(request))
|
Expect(req).Should(Equal(request))
|
||||||
|
|
||||||
// Inner should see fqdnRewritten
|
// Inner should see fqdnRewritten
|
||||||
|
@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
|
|
||||||
// Resolver after RewriterResolver should see `fqdnOriginal`
|
// Resolver after RewriterResolver should see `fqdnOriginal`
|
||||||
mNext.On("Resolve", mock.Anything)
|
mNext.On("Resolve", mock.Anything)
|
||||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) {
|
||||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||||
|
|
||||||
return mNextResponse, nil
|
return mNextResponse, nil
|
||||||
|
|
|
@ -30,11 +30,11 @@ type StrictResolver struct {
|
||||||
|
|
||||||
// NewStrictResolver creates a new strict resolver instance
|
// NewStrictResolver creates a new strict resolver instance
|
||||||
func NewStrictResolver(
|
func NewStrictResolver(
|
||||||
cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
|
||||||
) (*StrictResolver, error) {
|
) (*StrictResolver, error) {
|
||||||
logger := log.PrefixedLog(strictResolverType)
|
logger := log.PrefixedLog(strictResolverType)
|
||||||
|
|
||||||
resolvers, err := createResolvers(logger, cfg, bootstrap, shouldVerifyUpstreams)
|
resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -76,44 +76,27 @@ func (r *StrictResolver) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve sends the query request in a strict order to the upstream resolvers
|
// Resolve sends the query request in a strict order to the upstream resolvers
|
||||||
func (r *StrictResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *StrictResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, strictResolverType)
|
logger := log.WithPrefix(request.Log, strictResolverType)
|
||||||
|
|
||||||
// start with first resolver
|
// start with first resolver
|
||||||
for i := range r.resolvers {
|
for _, resolver := range r.resolvers {
|
||||||
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
resolver := r.resolvers[i]
|
|
||||||
logger.Debugf("using %s as resolver", resolver.resolver)
|
logger.Debugf("using %s as resolver", resolver.resolver)
|
||||||
|
|
||||||
ch := make(chan requestResponse, 1)
|
resp, err := resolver.resolve(ctx, request)
|
||||||
|
if err != nil {
|
||||||
go resolver.resolve(request, ch)
|
// log error and try next upstream
|
||||||
|
logger.WithField("resolver", resolver.resolver).Debug("resolution failed from resolver, cause: ", err)
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// log debug/info that timeout exceeded, call `continue` to try next upstream
|
|
||||||
logger.WithField("resolver", r.resolvers[i].resolver).Debug("upstream exceeded timeout, trying next upstream")
|
|
||||||
|
|
||||||
continue
|
continue
|
||||||
case result := <-ch:
|
|
||||||
if result.err != nil {
|
|
||||||
// log error & call `continue` to try next upstream
|
|
||||||
logger.Debug("resolution failed from resolver, cause: ", result.err)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"resolver": *result.resolver,
|
|
||||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
|
||||||
}).Debug("using response from resolver")
|
|
||||||
|
|
||||||
return result.response, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"resolver": *resolver,
|
||||||
|
"answer": util.AnswerToString(resp.Res.Answer),
|
||||||
|
}).Debug("using response from resolver")
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
|
return nil, errors.New("resolution was not successful, no resolver returned an answer in time")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
err error
|
err error
|
||||||
|
|
||||||
bootstrap *Bootstrap
|
bootstrap *Bootstrap
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
upstreams = []config.Upstream{
|
upstreams = []config.Upstream{
|
||||||
{Host: "wrong"},
|
{Host: "wrong"},
|
||||||
{Host: "127.0.0.2"},
|
{Host: "127.0.0.2"},
|
||||||
|
@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
Name: upstreamDefaultCfgName,
|
Name: upstreamDefaultCfgName,
|
||||||
Upstreams: upstreams,
|
Upstreams: upstreams,
|
||||||
}
|
}
|
||||||
sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify)
|
sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify)
|
||||||
})
|
})
|
||||||
|
|
||||||
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
|
config.GetConfig().Upstreams.Timeout = config.Duration(time.Second)
|
||||||
|
@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
mockUpstream.Start(),
|
mockUpstream.Start(),
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := NewStrictResolver(config.UpstreamGroup{
|
_, err := NewStrictResolver(ctx, config.UpstreamGroup{
|
||||||
Name: upstreamDefaultCfgName,
|
Name: upstreamDefaultCfgName,
|
||||||
Upstreams: upstreams,
|
Upstreams: upstreams,
|
||||||
},
|
},
|
||||||
|
@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
It("Should use result from first one", func() {
|
It("Should use result from first one", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should return response from next upstream", func() {
|
It("should return response from next upstream", func() {
|
||||||
request := newRequest("example.com", A)
|
request := newRequest("example.com", A)
|
||||||
Expect(sut.Resolve(request)).Should(
|
Expect(sut.Resolve(ctx, request)).Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
BeDNSRecord("example.com.", A, "123.124.122.2"),
|
||||||
HaveTTL(BeNumerically("==", 123)),
|
HaveTTL(BeNumerically("==", 123)),
|
||||||
|
@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
It("should return error", func() {
|
It("should return error", func() {
|
||||||
request := newRequest("example.com", A)
|
request := newRequest("example.com", A)
|
||||||
_, err := sut.Resolve(request)
|
_, err := sut.Resolve(ctx, request)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
It("Should use result from second one", func() {
|
It("Should use result from second one", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
BeDNSRecord("example.com.", A, "123.124.122.123"),
|
||||||
|
@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
})
|
})
|
||||||
It("Should return error", func() {
|
It("Should return error", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
_, err = sut.Resolve(request)
|
_, err = sut.Resolve(ctx, request)
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() {
|
||||||
It("Should use result from defined resolver", func() {
|
It("Should use result from defined resolver", func() {
|
||||||
request := newRequest("example.com.", A)
|
request := newRequest("example.com.", A)
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -99,7 +100,7 @@ func NewSpecialUseDomainNamesResolver(cfg config.SUDN) *SpecialUseDomainNamesRes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *SpecialUseDomainNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
handler := r.handler(request)
|
handler := r.handler(request)
|
||||||
if handler != nil {
|
if handler != nil {
|
||||||
resp := handler(request, r.cfg)
|
resp := handler(request, r.cfg)
|
||||||
|
@ -108,7 +109,7 @@ func (r *SpecialUseDomainNamesResolver) Resolve(request *model.Request) (*model.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler {
|
func (r *SpecialUseDomainNamesResolver) handler(request *model.Request) sudnHandler {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
sut *SpecialUseDomainNamesResolver
|
sut *SpecialUseDomainNamesResolver
|
||||||
sutConfig config.SUDN
|
sutConfig config.SUDN
|
||||||
m *mockResolver
|
m *mockResolver
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
Describe("Type", func() {
|
Describe("Type", func() {
|
||||||
|
@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
sutConfig, err = config.WithDefaults[config.SUDN]()
|
sutConfig, err = config.WithDefaults[config.SUDN]()
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
})
|
})
|
||||||
|
@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
Describe("handlers", func() {
|
Describe("handlers", func() {
|
||||||
It("should have correct response type", func() {
|
It("should have correct response type", func() {
|
||||||
for domain, handler := range sudnHandlers {
|
for domain, handler := range sudnHandlers {
|
||||||
resp, err := sut.Resolve(newRequest(domain, A))
|
resp, err := sut.Resolve(ctx, newRequest(domain, A))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
|
@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
|
|
||||||
DescribeTable("handled domains",
|
DescribeTable("handled domains",
|
||||||
func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) {
|
func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) {
|
||||||
resp, err := sut.Resolve(newRequest(qName, qType))
|
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).Should(SatisfyAll(
|
Expect(resp).Should(SatisfyAll(
|
||||||
HaveResponseType(ResponseTypeSPECIAL),
|
HaveResponseType(ResponseTypeSPECIAL),
|
||||||
|
@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
|
|
||||||
DescribeTable("",
|
DescribeTable("",
|
||||||
func(qType dns.Type, qName string, expectedRCode int) {
|
func(qType dns.Type, qName string, expectedRCode int) {
|
||||||
resp, err := sut.Resolve(newRequest(qName, qType))
|
resp, err := sut.Resolve(ctx, newRequest(qName, qType))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).Should(HaveReturnCode(expectedRCode))
|
Expect(resp).Should(HaveReturnCode(expectedRCode))
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
||||||
|
@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should forward example.com", func() {
|
It("should forward example.com", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
BeDNSRecord("example.com.", A, "123.145.123.145"),
|
||||||
|
@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should forward home.arpa. IN DS", func() {
|
It("should forward home.arpa. IN DS", func() {
|
||||||
Expect(sut.Resolve(newRequest("something.home.arpa.", DS))).
|
Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
// setup code doesn't care about the question
|
// setup code doesn't care about the question
|
||||||
|
@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should forward non special use domains", func() {
|
It("should forward non special use domains", func() {
|
||||||
resp, err := sut.Resolve(newRequest("something.not-special.", AAAA))
|
resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL))
|
||||||
})
|
})
|
||||||
|
|
|
@ -2,6 +2,7 @@ package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -39,8 +40,9 @@ type UpstreamResolver struct {
|
||||||
|
|
||||||
type upstreamClient interface {
|
type upstreamClient interface {
|
||||||
fmtURL(ip net.IP, port uint16, path string) string
|
fmtURL(ip net.IP, port uint16, path string) string
|
||||||
callExternal(msg *dns.Msg, upstreamURL string,
|
callExternal(
|
||||||
protocol model.RequestProtocol) (response *dns.Msg, rtt time.Duration, err error)
|
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
|
||||||
|
) (response *dns.Msg, rtt time.Duration, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type dnsUpstreamClient struct {
|
type dnsUpstreamClient struct {
|
||||||
|
@ -53,8 +55,6 @@ type httpUpstreamClient struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||||
timeout := config.GetConfig().Upstreams.Timeout.ToDuration()
|
|
||||||
|
|
||||||
tlsConfig := tls.Config{
|
tlsConfig := tls.Config{
|
||||||
ServerName: cfg.Host,
|
ServerName: cfg.Host,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
|
@ -73,7 +73,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||||
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
||||||
ForceAttemptHTTP2: true,
|
ForceAttemptHTTP2: true,
|
||||||
},
|
},
|
||||||
Timeout: timeout,
|
|
||||||
},
|
},
|
||||||
host: cfg.Host,
|
host: cfg.Host,
|
||||||
}
|
}
|
||||||
|
@ -83,7 +82,6 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||||
tcpClient: &dns.Client{
|
tcpClient: &dns.Client{
|
||||||
TLSConfig: &tlsConfig,
|
TLSConfig: &tlsConfig,
|
||||||
Net: cfg.Net.String(),
|
Net: cfg.Net.String(),
|
||||||
Timeout: timeout,
|
|
||||||
SingleInflight: true,
|
SingleInflight: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -92,12 +90,10 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||||
return &dnsUpstreamClient{
|
return &dnsUpstreamClient{
|
||||||
tcpClient: &dns.Client{
|
tcpClient: &dns.Client{
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
Timeout: timeout,
|
|
||||||
SingleInflight: true,
|
SingleInflight: true,
|
||||||
},
|
},
|
||||||
udpClient: &dns.Client{
|
udpClient: &dns.Client{
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Timeout: timeout,
|
|
||||||
SingleInflight: true,
|
SingleInflight: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -112,8 +108,8 @@ func (r *httpUpstreamClient) fmtURL(ip net.IP, port uint16, path string) string
|
||||||
return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path)
|
return fmt.Sprintf("https://%s%s", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))), path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
|
func (r *httpUpstreamClient) callExternal(
|
||||||
upstreamURL string, _ model.RequestProtocol,
|
ctx context.Context, msg *dns.Msg, upstreamURL string, _ model.RequestProtocol,
|
||||||
) (*dns.Msg, time.Duration, error) {
|
) (*dns.Msg, time.Duration, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
|
@ -122,7 +118,7 @@ func (r *httpUpstreamClient) callExternal(msg *dns.Msg,
|
||||||
return nil, 0, fmt.Errorf("can't pack message: %w", err)
|
return nil, 0, fmt.Errorf("can't pack message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(rawDNSMessage))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("can't create the new request %w", err)
|
return nil, 0, fmt.Errorf("can't create the new request %w", err)
|
||||||
}
|
}
|
||||||
|
@ -169,16 +165,16 @@ func (r *dnsUpstreamClient) fmtURL(ip net.IP, port uint16, _ string) string {
|
||||||
return net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
|
return net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
|
func (r *dnsUpstreamClient) callExternal(
|
||||||
upstreamURL string, protocol model.RequestProtocol,
|
ctx context.Context, msg *dns.Msg, upstreamURL string, protocol model.RequestProtocol,
|
||||||
) (response *dns.Msg, rtt time.Duration, err error) {
|
) (response *dns.Msg, rtt time.Duration, err error) {
|
||||||
if protocol == model.RequestProtocolTCP {
|
if protocol == model.RequestProtocolTCP {
|
||||||
response, rtt, err = r.tcpClient.Exchange(msg, upstreamURL)
|
response, rtt, err = r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||||
if err != nil {
|
if err != nil && r.udpClient != nil {
|
||||||
// try UDP as fallback
|
// try UDP as fallback
|
||||||
var opErr *net.OpError
|
var opErr *net.OpError
|
||||||
if errors.As(err, &opErr) && opErr.Op == "dial" && r.udpClient != nil {
|
if errors.As(err, &opErr) && opErr.Op == "dial" {
|
||||||
return r.udpClient.Exchange(msg, upstreamURL)
|
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,18 +182,20 @@ func (r *dnsUpstreamClient) callExternal(msg *dns.Msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.udpClient != nil {
|
if r.udpClient != nil {
|
||||||
return r.udpClient.Exchange(msg, upstreamURL)
|
return r.udpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.tcpClient.Exchange(msg, upstreamURL)
|
return r.tcpClient.ExchangeContext(ctx, msg, upstreamURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUpstreamResolver creates new resolver instance
|
// NewUpstreamResolver creates new resolver instance
|
||||||
func NewUpstreamResolver(upstream config.Upstream, bootstrap *Bootstrap, verify bool) (*UpstreamResolver, error) {
|
func NewUpstreamResolver(
|
||||||
|
ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool,
|
||||||
|
) (*UpstreamResolver, error) {
|
||||||
r := newUpstreamResolverUnchecked(upstream, bootstrap)
|
r := newUpstreamResolverUnchecked(upstream, bootstrap)
|
||||||
|
|
||||||
if verify {
|
if verify {
|
||||||
_, err := r.bootstrap.UpstreamIPs(r)
|
_, err := r.bootstrap.UpstreamIPs(ctx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -234,50 +232,41 @@ func (r UpstreamResolver) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve calls external resolver
|
// Resolve calls external resolver
|
||||||
func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
ips, err := r.bootstrap.UpstreamIPs(r)
|
ips, err := r.bootstrap.UpstreamIPs(ctx, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
rtt time.Duration
|
|
||||||
resp *dns.Msg
|
resp *dns.Msg
|
||||||
ip net.IP
|
ip net.IP
|
||||||
)
|
)
|
||||||
|
|
||||||
err = retry.Do(
|
err = retry.Do(
|
||||||
func() error {
|
func() error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
ip = ips.Current()
|
ip = ips.Current()
|
||||||
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
|
upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path)
|
||||||
|
|
||||||
var err error
|
response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol)
|
||||||
resp, rtt, err = r.upstreamClient.callExternal(request.Req, upstreamURL, request.Protocol)
|
if err != nil {
|
||||||
if err == nil {
|
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
|
||||||
r.log().WithFields(logrus.Fields{
|
|
||||||
"answer": util.AnswerToString(resp.Answer),
|
|
||||||
"return_code": dns.RcodeToString[resp.Rcode],
|
|
||||||
"upstream": r.upstream.String(),
|
|
||||||
"upstream_ip": ip.String(),
|
|
||||||
"protocol": request.Protocol,
|
|
||||||
"net": r.upstream.Net,
|
|
||||||
"response_time_ms": rtt.Milliseconds(),
|
|
||||||
}).Debugf("received response from upstream")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err)
|
resp = response
|
||||||
|
r.logResponse(request, response, ip, rtt)
|
||||||
|
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
|
retry.Context(ctx),
|
||||||
retry.Attempts(retryAttempts),
|
retry.Attempts(retryAttempts),
|
||||||
retry.DelayType(retry.FixedDelay),
|
retry.DelayType(retry.FixedDelay),
|
||||||
retry.Delay(1*time.Millisecond),
|
retry.Delay(1*time.Millisecond),
|
||||||
retry.LastErrorOnly(true),
|
retry.LastErrorOnly(true),
|
||||||
retry.RetryIf(func(err error) bool {
|
retry.RetryIf(isTimeout),
|
||||||
var netErr net.Error
|
|
||||||
|
|
||||||
return errors.As(err, &netErr) && netErr.Timeout()
|
|
||||||
}),
|
|
||||||
retry.OnRetry(func(n uint, err error) {
|
retry.OnRetry(func(n uint, err error) {
|
||||||
r.log().WithFields(logrus.Fields{
|
r.log().WithFields(logrus.Fields{
|
||||||
"upstream": r.upstream.String(),
|
"upstream": r.upstream.String(),
|
||||||
|
@ -289,8 +278,31 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
||||||
ips.Next()
|
ips.Next()
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
// Make the error more user friendly than just "context deadline exceeded"
|
||||||
|
err = fmt.Errorf("timeout (%w)", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
|
return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) {
|
||||||
|
r.log().WithFields(logrus.Fields{
|
||||||
|
"answer": util.AnswerToString(resp.Answer),
|
||||||
|
"return_code": dns.RcodeToString[resp.Rcode],
|
||||||
|
"upstream": r.upstream.String(),
|
||||||
|
"upstream_ip": ip.String(),
|
||||||
|
"protocol": request.Protocol,
|
||||||
|
"net": r.upstream.Net,
|
||||||
|
"response_time_ms": rtt.Milliseconds(),
|
||||||
|
}).Debugf("received response from upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
|
||||||
|
return errors.As(err, &netErr) && netErr.Timeout()
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
var (
|
var (
|
||||||
sut *UpstreamResolver
|
sut *UpstreamResolver
|
||||||
sutConfig config.Upstream
|
sutConfig config.Upstream
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
sutConfig = config.Upstream{Host: "localhost"}
|
sutConfig = config.Upstream{Host: "localhost"}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
upstream := mockUpstream.Start()
|
upstream := mockUpstream.Start()
|
||||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
upstream := mockUpstream.Start()
|
upstream := mockUpstream.Start()
|
||||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
HaveNoAnswer(),
|
HaveNoAnswer(),
|
||||||
|
@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
upstream := mockUpstream.Start()
|
upstream := mockUpstream.Start()
|
||||||
sut := newUpstreamResolverUnchecked(upstream, nil)
|
sut := newUpstreamResolverUnchecked(upstream, nil)
|
||||||
|
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
atomic.StoreInt32(&counter, 0)
|
atomic.StoreInt32(&counter, 0)
|
||||||
atomic.StoreInt32(&attemptsWithTimeout, 2)
|
atomic.StoreInt32(&attemptsWithTimeout, 2)
|
||||||
|
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -146,12 +153,37 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
By("3 attempts with timeout -> should return error", func() {
|
By("3 attempts with timeout -> should return error", func() {
|
||||||
atomic.StoreInt32(&counter, 0)
|
atomic.StoreInt32(&counter, 0)
|
||||||
atomic.StoreInt32(&attemptsWithTimeout, 3)
|
atomic.StoreInt32(&attemptsWithTimeout, 3)
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
|
Expect(err.Error()).Should(ContainSubstring("i/o timeout"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
When("user request is TCP", func() {
|
||||||
|
When("TCP upstream connection fails", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122")
|
||||||
|
DeferCleanup(mockUpstream.Close)
|
||||||
|
|
||||||
|
sutConfig = mockUpstream.Start()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should retry with UDP", func() {
|
||||||
|
req := newRequest("example.com.", A)
|
||||||
|
req.Protocol = RequestProtocolTCP
|
||||||
|
|
||||||
|
Expect(sut.Resolve(ctx, req)).
|
||||||
|
Should(
|
||||||
|
SatisfyAll(
|
||||||
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
HaveResponseType(ResponseTypeRESOLVED),
|
||||||
|
HaveReturnCode(dns.RcodeSuccess),
|
||||||
|
HaveTTL(BeNumerically("==", 123)),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Describe("Using Dns over HTTP (DOH) upstream", func() {
|
Describe("Using Dns over HTTP (DOH) upstream", func() {
|
||||||
|
@ -185,7 +217,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
})
|
})
|
||||||
When("Configured DOH resolver can resolve query", func() {
|
When("Configured DOH resolver can resolve query", func() {
|
||||||
It("should return answer from DNS upstream", func() {
|
It("should return answer from DNS upstream", func() {
|
||||||
Expect(sut.Resolve(newRequest("example.com.", A))).
|
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
BeDNSRecord("example.com.", A, "123.124.122.122"),
|
||||||
|
@ -203,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("should return error", func() {
|
It("should return error", func() {
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
|
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
|
||||||
})
|
})
|
||||||
|
@ -215,7 +247,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("should return error", func() {
|
It("should return error", func() {
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
Expect(err.Error()).Should(
|
Expect(err.Error()).Should(
|
||||||
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
|
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
|
||||||
|
@ -228,7 +260,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
It("should return error", func() {
|
It("should return error", func() {
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
|
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
|
||||||
})
|
})
|
||||||
|
@ -241,7 +273,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
}, systemResolverBootstrap)
|
}, systemResolverBootstrap)
|
||||||
})
|
})
|
||||||
It("should return error", func() {
|
It("should return error", func() {
|
||||||
_, err := sut.Resolve(newRequest("example.com.", A))
|
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
Expect(err.Error()).Should(Or(
|
Expect(err.Error()).Should(Or(
|
||||||
ContainSubstring("no such host"),
|
ContainSubstring("no such host"),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ func (r *UpstreamTreeResolver) String() string {
|
||||||
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
|
return fmt.Sprintf("%s upstreams %q", upstreamTreeResolverType, strings.Join(result, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
|
logger := log.WithPrefix(request.Log, upstreamTreeResolverType)
|
||||||
|
|
||||||
group := r.upstreamGroupByClient(request)
|
group := r.upstreamGroupByClient(request)
|
||||||
|
@ -72,7 +73,7 @@ func (r *UpstreamTreeResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
// delegate request to group resolver
|
// delegate request to group resolver
|
||||||
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
|
logger.WithField("resolver", fmt.Sprintf("%s (%s)", group, r.branches[group].Type())).Debug("delegating to resolver")
|
||||||
|
|
||||||
return r.branches[group].Resolve(request)
|
return r.branches[group].Resolve(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
|
func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package resolver
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
. "github.com/0xERR0R/blocky/helpertest"
|
. "github.com/0xERR0R/blocky/helpertest"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
When("client specific resolvers are defined", func() {
|
When("client specific resolvers are defined", func() {
|
||||||
|
var (
|
||||||
|
ctx context.Context
|
||||||
|
cancelFn context.CancelFunc
|
||||||
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
|
ctx, cancelFn = context.WithCancel(context.Background())
|
||||||
|
DeferCleanup(cancelFn)
|
||||||
|
|
||||||
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{
|
||||||
upstreamDefaultCfgName: {config.Upstream{}},
|
upstreamDefaultCfgName: {config.Upstream{}},
|
||||||
"laptop": {config.Upstream{}},
|
"laptop": {config.Upstream{}},
|
||||||
|
@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use default if client name or IP don't match", func() {
|
It("Should use default if client name or IP don't match", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
request := newRequestWithClient("example.com.", A, "192.168.178.55", "test")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "default"),
|
BeDNSRecord("example.com.", A, "default"),
|
||||||
|
@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client name matches exact", func() {
|
It("Should use client specific resolver if client name matches exact", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
|
request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "laptop"),
|
BeDNSRecord("example.com.", A, "laptop"),
|
||||||
|
@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client name matches with wildcard", func() {
|
It("Should use client specific resolver if client name matches with wildcard", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
|
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "client-*-m"),
|
BeDNSRecord("example.com.", A, "client-*-m"),
|
||||||
|
@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client name matches with range wildcard", func() {
|
It("Should use client specific resolver if client name matches with range wildcard", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
|
request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "client[0-9]"),
|
BeDNSRecord("example.com.", A, "client[0-9]"),
|
||||||
|
@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client IP matches", func() {
|
It("Should use client specific resolver if client IP matches", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
|
request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||||
|
@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client name (containing IP) matches", func() {
|
It("Should use client specific resolver if client name (containing IP) matches", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
|
request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||||
|
@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
|
It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
|
request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
|
BeDNSRecord("example.com.", A, "10.43.8.67/28"),
|
||||||
|
@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use exact IP match before client name match", func() {
|
It("Should use exact IP match before client name match", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
|
request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
BeDNSRecord("example.com.", A, "192.168.178.33"),
|
||||||
|
@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
It("Should use client name match before CIDR match", func() {
|
It("Should use client name match before CIDR match", func() {
|
||||||
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
|
request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop")
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
BeDNSRecord("example.com.", A, "laptop"),
|
BeDNSRecord("example.com.", A, "laptop"),
|
||||||
|
@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() {
|
||||||
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
|
request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1")
|
||||||
request.Log = logger
|
request.Log = logger
|
||||||
|
|
||||||
Expect(sut.Resolve(request)).
|
Expect(sut.Resolve(ctx, request)).
|
||||||
Should(
|
Should(
|
||||||
SatisfyAll(
|
SatisfyAll(
|
||||||
SatisfyAny(
|
SatisfyAny(
|
||||||
|
|
|
@ -389,7 +389,7 @@ func createQueryResolver(
|
||||||
bootstrap *resolver.Bootstrap,
|
bootstrap *resolver.Bootstrap,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
) (r resolver.ChainedResolver, err error) {
|
) (r resolver.ChainedResolver, err error) {
|
||||||
upstreamBranches, uErr := createUpstreamBranches(cfg, bootstrap)
|
upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap)
|
||||||
if uErr != nil {
|
if uErr != nil {
|
||||||
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
|
return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr)
|
||||||
}
|
}
|
||||||
|
@ -398,7 +398,9 @@ func createQueryResolver(
|
||||||
|
|
||||||
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
|
blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap)
|
||||||
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream)
|
||||||
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(cfg.Conditional, bootstrap, cfg.StartVerifyUpstream)
|
condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(
|
||||||
|
ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream,
|
||||||
|
)
|
||||||
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
|
hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap)
|
||||||
|
|
||||||
err = multierror.Append(
|
err = multierror.Append(
|
||||||
|
@ -433,6 +435,7 @@ func createQueryResolver(
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUpstreamBranches(
|
func createUpstreamBranches(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bootstrap *resolver.Bootstrap,
|
bootstrap *resolver.Bootstrap,
|
||||||
) (map[string]resolver.Resolver, error) {
|
) (map[string]resolver.Resolver, error) {
|
||||||
|
@ -453,11 +456,11 @@ func createUpstreamBranches(
|
||||||
|
|
||||||
switch cfg.Upstreams.Strategy {
|
switch cfg.Upstreams.Strategy {
|
||||||
case config.UpstreamStrategyParallelBest:
|
case config.UpstreamStrategyParallelBest:
|
||||||
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||||
case config.UpstreamStrategyStrict:
|
case config.UpstreamStrategyStrict:
|
||||||
upstream, err = resolver.NewStrictResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||||
case config.UpstreamStrategyRandom:
|
case config.UpstreamStrategyRandom:
|
||||||
upstream, err = resolver.NewParallelBestResolver(groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamBranches[group] = upstream
|
upstreamBranches[group] = upstream
|
||||||
|
@ -643,7 +646,7 @@ func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
|
||||||
|
|
||||||
r := createResolverRequest(w, request)
|
r := createResolverRequest(w, request)
|
||||||
|
|
||||||
response, err := s.queryResolver.Resolve(r)
|
response, err := s.queryResolver.Resolve(context.Background(), r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger().Error("error on processing request:", err)
|
logger().Error("error on processing request:", err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
@ -149,7 +150,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
||||||
|
|
||||||
r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
|
r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
|
||||||
|
|
||||||
resResponse, err := s.queryResolver.Resolve(r)
|
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||||
|
|
||||||
|
@ -192,11 +193,11 @@ func extractIP(r *http.Request) string {
|
||||||
return hostPort
|
return hostPort
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Query(question string, qType dns.Type) (*model.Response, error) {
|
func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
|
||||||
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
||||||
r := createResolverRequest(nil, dnsRequest)
|
r := createResolverRequest(nil, dnsRequest)
|
||||||
|
|
||||||
return s.queryResolver.Resolve(r)
|
return s.queryResolver.Resolve(ctx, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
||||||
|
|
|
@ -76,10 +76,9 @@ var _ = BeforeSuite(func() {
|
||||||
|
|
||||||
clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
clientMockUpstream = resolver.NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) {
|
||||||
var clientName string
|
var clientName string
|
||||||
client := mockClientName.Load()
|
|
||||||
|
|
||||||
if client != nil {
|
if name, ok := mockClientName.Load().(string); ok {
|
||||||
clientName = mockClientName.Load().(string)
|
clientName = name
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := util.NewMsgWithAnswer(
|
response, err := util.NewMsgWithAnswer(
|
||||||
|
@ -118,8 +117,7 @@ var _ = BeforeSuite(func() {
|
||||||
youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com")
|
youtubeFile := tmpDir.CreateStringFile("youtube.com.txt", "youtube.com")
|
||||||
Expect(youtubeFile.Error).Should(Succeed())
|
Expect(youtubeFile.Error).Should(Succeed())
|
||||||
|
|
||||||
// create server
|
cfg := &config.Config{
|
||||||
sut, err = NewServer(ctx, &config.Config{
|
|
||||||
CustomDNS: config.CustomDNSConfig{
|
CustomDNS: config.CustomDNSConfig{
|
||||||
CustomTTL: config.Duration(3600 * time.Second),
|
CustomTTL: config.Duration(3600 * time.Second),
|
||||||
Mapping: config.CustomDNSMapping{
|
Mapping: config.CustomDNSMapping{
|
||||||
|
@ -160,7 +158,8 @@ var _ = BeforeSuite(func() {
|
||||||
BlockTTL: config.Duration(6 * time.Hour),
|
BlockTTL: config.Duration(6 * time.Hour),
|
||||||
},
|
},
|
||||||
Upstreams: config.UpstreamsConfig{
|
Upstreams: config.UpstreamsConfig{
|
||||||
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
|
Timeout: config.Duration(250 * time.Millisecond),
|
||||||
|
Groups: map[string][]config.Upstream{"default": {upstreamGoogle}},
|
||||||
},
|
},
|
||||||
ClientLookup: config.ClientLookupConfig{
|
ClientLookup: config.ClientLookupConfig{
|
||||||
Upstream: upstreamClient,
|
Upstream: upstreamClient,
|
||||||
|
@ -178,16 +177,19 @@ var _ = BeforeSuite(func() {
|
||||||
Enable: true,
|
Enable: true,
|
||||||
Path: "/metrics",
|
Path: "/metrics",
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Hacky but needed to update the global since we still have code that reads it
|
||||||
|
*config.GetConfig() = *cfg
|
||||||
|
|
||||||
|
// create server
|
||||||
|
sut, err = NewServer(ctx, cfg)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
errChan := make(chan error, 10)
|
errChan := make(chan error, 10)
|
||||||
|
|
||||||
// start server
|
// start server
|
||||||
go func() {
|
go sut.Start(ctx, errChan)
|
||||||
sut.Start(ctx, errChan)
|
|
||||||
}()
|
|
||||||
DeferCleanup(sut.Stop)
|
DeferCleanup(sut.Stop)
|
||||||
|
|
||||||
Consistently(errChan, "1s").ShouldNot(Receive())
|
Consistently(errChan, "1s").ShouldNot(Receive())
|
||||||
|
@ -697,7 +699,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
|
|
||||||
Describe("NewServer with strict upstream strategy", func() {
|
Describe("NewServer with strict upstream strategy", func() {
|
||||||
It("successfully returns upstream branches", func() {
|
It("successfully returns upstream branches", func() {
|
||||||
branches, err := createUpstreamBranches(&config.Config{
|
branches, err := createUpstreamBranches(context.Background(), &config.Config{
|
||||||
Upstreams: config.UpstreamsConfig{
|
Upstreams: config.UpstreamsConfig{
|
||||||
Strategy: config.UpstreamStrategyStrict,
|
Strategy: config.UpstreamStrategyStrict,
|
||||||
Groups: config.UpstreamGroups{
|
Groups: config.UpstreamGroups{
|
||||||
|
@ -715,7 +717,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
|
|
||||||
Describe("NewServer with random upstream strategy", func() {
|
Describe("NewServer with random upstream strategy", func() {
|
||||||
It("successfully returns upstream branches", func() {
|
It("successfully returns upstream branches", func() {
|
||||||
branches, err := createUpstreamBranches(&config.Config{
|
branches, err := createUpstreamBranches(context.Background(), &config.Config{
|
||||||
Upstreams: config.UpstreamsConfig{
|
Upstreams: config.UpstreamsConfig{
|
||||||
Strategy: config.UpstreamStrategyRandom,
|
Strategy: config.UpstreamStrategyRandom,
|
||||||
Groups: config.UpstreamGroups{
|
Groups: config.UpstreamGroups{
|
||||||
|
|
Loading…
Reference in New Issue