This commit is contained in:
Dimitri Herzog 2023-04-04 11:58:42 +02:00
parent 2b97d5550b
commit 1ead31464a
8 changed files with 353 additions and 75 deletions

View File

@ -63,6 +63,13 @@ func (x IPVersion) String() string {
return fmt.Sprintf("IPVersion(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x IPVersion) IsValid() bool {
_, ok := _IPVersionMap[x]
return ok
}
var _IPVersionValue = map[string]IPVersion{
_IPVersionName[0:4]: IPVersionDual,
_IPVersionName[4:6]: IPVersionV4,
@ -145,6 +152,13 @@ func (x NetProtocol) String() string {
return fmt.Sprintf("NetProtocol(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x NetProtocol) IsValid() bool {
_, ok := _NetProtocolMap[x]
return ok
}
var _NetProtocolValue = map[string]NetProtocol{
_NetProtocolName[0:7]: NetProtocolTcpUdp,
_NetProtocolName[7:14]: NetProtocolTcpTls,
@ -225,7 +239,8 @@ func (x QueryLogField) String() string {
return string(x)
}
// String implements the Stringer interface.
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x QueryLogField) IsValid() bool {
_, err := ParseQueryLogField(string(x))
return err == nil
@ -333,6 +348,13 @@ func (x QueryLogType) String() string {
return fmt.Sprintf("QueryLogType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x QueryLogType) IsValid() bool {
_, ok := _QueryLogTypeMap[x]
return ok
}
var _QueryLogTypeValue = map[string]QueryLogType{
_QueryLogTypeName[0:7]: QueryLogTypeConsole,
_QueryLogTypeName[7:11]: QueryLogTypeNone,
@ -418,6 +440,13 @@ func (x StartStrategyType) String() string {
return fmt.Sprintf("StartStrategyType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x StartStrategyType) IsValid() bool {
_, ok := _StartStrategyTypeMap[x]
return ok
}
var _StartStrategyTypeValue = map[string]StartStrategyType{
_StartStrategyTypeName[0:8]: StartStrategyTypeBlocking,
_StartStrategyTypeName[8:19]: StartStrategyTypeFailOnError,

View File

@ -49,6 +49,13 @@ func (x ListCacheType) String() string {
return fmt.Sprintf("ListCacheType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x ListCacheType) IsValid() bool {
_, ok := _ListCacheTypeMap[x]
return ok
}
var _ListCacheTypeValue = map[string]ListCacheType{
_ListCacheTypeName[0:9]: ListCacheTypeBlacklist,
_ListCacheTypeName[9:18]: ListCacheTypeWhitelist,

View File

@ -24,7 +24,6 @@ var _ = Describe("errorFilter", func() {
// mockParser.EXPECT().Next(mock.Anything).Return(struct{}{}, errors.New("fail")).Once()
// mockParser.EXPECT().Next(mock.Anything).Return(struct{}{}, NewNonResumableError(io.EOF)).Once()
// parser = mockParser
})
When("0 errors are allowed", func() {
@ -93,7 +92,6 @@ var _ = Describe("errorFilter", func() {
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
})
})
})

View File

@ -49,6 +49,13 @@ func (x FormatType) String() string {
return fmt.Sprintf("FormatType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x FormatType) IsValid() bool {
_, ok := _FormatTypeMap[x]
return ok
}
var _FormatTypeValue = map[string]FormatType{
_FormatTypeName[0:4]: FormatTypeText,
_FormatTypeName[4:8]: FormatTypeJson,
@ -130,6 +137,13 @@ func (x Level) String() string {
return fmt.Sprintf("Level(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x Level) IsValid() bool {
_, ok := _LevelMap[x]
return ok
}
var _LevelValue = map[string]Level{
_LevelName[0:4]: LevelInfo,
_LevelName[4:9]: LevelTrace,

View File

@ -49,6 +49,13 @@ func (x RequestProtocol) String() string {
return fmt.Sprintf("RequestProtocol(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x RequestProtocol) IsValid() bool {
_, ok := _RequestProtocolMap[x]
return ok
}
var _RequestProtocolValue = map[string]RequestProtocol{
_RequestProtocolName[0:3]: RequestProtocolTCP,
_RequestProtocolName[3:6]: RequestProtocolUDP,
@ -151,6 +158,13 @@ func (x ResponseType) String() string {
return fmt.Sprintf("ResponseType(%d)", x)
}
// IsValid provides a quick way to determine if the typed value is
// part of the allowed enumerated values
func (x ResponseType) IsValid() bool {
_, ok := _ResponseTypeMap[x]
return ok
}
var _ResponseTypeValue = map[string]ResponseType{
_ResponseTypeName[0:8]: ResponseTypeRESOLVED,
_ResponseTypeName[8:14]: ResponseTypeCACHED,

View File

@ -46,43 +46,56 @@ var _ = BeforeSuite(func() {
var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
var (
sut *BlockingResolver
sutConfig config.BlockingConfig
m *mockResolver
mockAnswer *dns.Msg
sut *BlockingResolver
sutConfig config.BlockingConfig
m1 *MockResolver
mockResponse *Response
)
Describe("Type", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
}
})
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
}
mockAnswer = new(dns.Msg)
mockResponse = &Response{Res: new(dns.Msg)}
})
JustBeforeEach(func() {
var err error
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
m1 = NewMockResolver(GinkgoT())
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
Expect(err).Should(Succeed())
sut.Next(m)
sut.Next(m1)
})
Describe("IsEnabled", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
}
})
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
Describe("LogConfig", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
BlockType: "ZEROIP",
BlockTTL: config.Duration(time.Minute),
}
})
It("should log something", func() {
logger, hook := log.NewMockEntry()
@ -138,13 +151,16 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Full-qualified group name is used", func() {
It("should block request", func() {
m.AnswerFn = func(t dns.Type, qName string) (*dns.Msg, error) {
if t == dns.Type(dns.TypeA) && qName == "full.qualified.com." {
return util.NewMsgWithAnswer(qName, 60*60, A, "192.168.178.39")
m1.On("Resolve", mock.Anything).Return(func(req *Request) (*Response, error) {
question := req.Req.Question[0]
if question.Qtype == dns.TypeA && question.Name == "full.qualified.com." {
r, _ := util.NewMsgWithAnswer(question.Name, 60*60, A, "192.168.178.39")
return &Response{Res: r}, nil
}
return nil, nil //nolint:nilnil
}
return mockResponse, nil
})
Bus().Publish(ApplicationStarted, "")
Eventually(func(g Gomega) {
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
@ -176,6 +192,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Domain is on the black list", func() {
It("should block request", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Eventually(sut.Resolve).
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
Should(
@ -293,6 +311,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
Should(
SatisfyAll(
@ -300,11 +320,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
Should(
SatisfyAll(
@ -312,9 +331,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
})
@ -534,11 +550,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Blacklist contains IP", func() {
When("IP4", func() {
BeforeEach(func() {
// return defined IP as response
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
})
It("should block query, if lookup result contains blacklisted IP", func() {
// return defined IP as response
mockAnswer, _ := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -551,14 +567,14 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
})
When("IP6", func() {
BeforeEach(func() {
It("should block query, if lookup result contains blacklisted IP", func() {
// return defined IP as response
mockAnswer, _ = util.NewMsgWithAnswer(
mockAnswer, _ := util.NewMsgWithAnswer(
"example.com.", 300,
AAAA, "2001:0db8:85a3:08d3::0370:7344",
)
})
It("should block query, if lookup result contains blacklisted IP", func() {
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -573,15 +589,15 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("blacklist contains domain which is CNAME in response", func() {
BeforeEach(func() {
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
// reconfigure mock, to return CNAMEs
rr1, _ := dns.NewRR("example.com 300 IN CNAME domain.com")
rr2, _ := dns.NewRR("domain.com 300 IN CNAME badcnamedomain.com")
rr3, _ := dns.NewRR("badcnamedomain.com 300 IN A 125.125.125.125")
mockAnswer = new(dns.Msg)
mockAnswer := new(dns.Msg)
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
})
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -609,6 +625,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}
})
It("Should not be blocked", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -616,9 +634,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
})
@ -641,6 +656,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
It("should block everything else except domains on the white list with default group", func() {
By("querying domain on the whitelist", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -648,9 +665,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
By("querying another domain, which is not on the whitelist", func() {
@ -663,13 +677,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
HaveReason("BLOCKED (WHITELIST ONLY)"),
))
Expect(m.Calls).Should(HaveLen(1))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
Should(
SatisfyAll(
@ -677,9 +691,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
By("querying another domain, which is not on the whitelist", func() {
@ -692,12 +703,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
HaveReason("BLOCKED (WHITELIST ONLY)"),
))
Expect(m.Calls).Should(HaveLen(1))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist group 1", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Twice()
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
Should(
SatisfyAll(
@ -705,9 +717,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
By("querying another domain, which is in the whitelist group 1", func() {
@ -718,7 +727,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
Expect(m.Calls).Should(HaveLen(2))
})
})
})
@ -734,9 +742,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
"default": {"gr1"},
},
}
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
})
It("should not block if DNS answer contains IP from the white list", func() {
mockAnswer, _ := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -744,8 +754,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
})
})
@ -761,12 +769,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
},
}
})
AfterEach(func() {
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
When("domain is not on the black list", func() {
It("should delegate to next resolver", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -784,6 +790,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
}
})
It("should delegate to next resolver", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
Should(
SatisfyAll(
@ -810,6 +818,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
})
When("Disable blocking is called", func() {
It("no query should be blocked", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Times(3)
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
@ -847,8 +857,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
})
By("perform the same query again (group1)", func() {
@ -861,8 +870,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
})
By("Calling Rest API to deactivate only defaultGroup", func() {
@ -880,8 +888,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 3)
m1.AssertExpectations(GinkgoT())
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 3)
})
By("Perform query to ensure that the blocking status is active (group1)", func() {
@ -899,6 +907,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for all groups is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Twice()
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
@ -941,8 +951,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
})
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
// now is blocking disabled, query the url again
@ -953,9 +962,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
})
By("Wait 1 sec and perform the same query again, should be blocked now", func() {
@ -989,6 +996,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
When("Disable blocking for one group is called with a duration parameter", func() {
It("No query should be blocked only for passed amount of time", func() {
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
Should(
@ -1042,8 +1050,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
HaveReturnCode(dns.RcodeSuccess),
))
m.AssertExpectations(GinkgoT())
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
})
By("Wait 1 sec and perform the same query again, should be blocked now", func() {

View File

@ -0,0 +1,207 @@
// Code generated by mockery v2.23.1. DO NOT EDIT.
package resolver
import (
logrus "github.com/sirupsen/logrus"
mock "github.com/stretchr/testify/mock"
model "github.com/0xERR0R/blocky/model"
)
// MockResolver is an autogenerated mock type for the Resolver type
type MockResolver struct {
mock.Mock
}
type MockResolver_Expecter struct {
mock *mock.Mock
}
func (_m *MockResolver) EXPECT() *MockResolver_Expecter {
return &MockResolver_Expecter{mock: &_m.Mock}
}
// IsEnabled provides a mock function with given fields:
func (_m *MockResolver) IsEnabled() bool {
ret := _m.Called()
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockResolver_IsEnabled_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsEnabled'
type MockResolver_IsEnabled_Call struct {
*mock.Call
}
// IsEnabled is a helper method to define mock.On call
func (_e *MockResolver_Expecter) IsEnabled() *MockResolver_IsEnabled_Call {
return &MockResolver_IsEnabled_Call{Call: _e.mock.On("IsEnabled")}
}
func (_c *MockResolver_IsEnabled_Call) Run(run func()) *MockResolver_IsEnabled_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockResolver_IsEnabled_Call) Return(_a0 bool) *MockResolver_IsEnabled_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockResolver_IsEnabled_Call) RunAndReturn(run func() bool) *MockResolver_IsEnabled_Call {
_c.Call.Return(run)
return _c
}
// LogConfig provides a mock function with given fields: _a0
func (_m *MockResolver) LogConfig(_a0 *logrus.Entry) {
_m.Called(_a0)
}
// MockResolver_LogConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogConfig'
type MockResolver_LogConfig_Call struct {
*mock.Call
}
// LogConfig is a helper method to define mock.On call
// - _a0 *logrus.Entry
func (_e *MockResolver_Expecter) LogConfig(_a0 interface{}) *MockResolver_LogConfig_Call {
return &MockResolver_LogConfig_Call{Call: _e.mock.On("LogConfig", _a0)}
}
func (_c *MockResolver_LogConfig_Call) Run(run func(_a0 *logrus.Entry)) *MockResolver_LogConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*logrus.Entry))
})
return _c
}
func (_c *MockResolver_LogConfig_Call) Return() *MockResolver_LogConfig_Call {
_c.Call.Return()
return _c
}
func (_c *MockResolver_LogConfig_Call) RunAndReturn(run func(*logrus.Entry)) *MockResolver_LogConfig_Call {
_c.Call.Return(run)
return _c
}
// Resolve provides a mock function with given fields: req
func (_m *MockResolver) Resolve(req *model.Request) (*model.Response, error) {
ret := _m.Called(req)
var r0 *model.Response
var r1 error
if rf, ok := ret.Get(0).(func(*model.Request) (*model.Response, error)); ok {
return rf(req)
}
if rf, ok := ret.Get(0).(func(*model.Request) *model.Response); ok {
r0 = rf(req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Response)
}
}
if rf, ok := ret.Get(1).(func(*model.Request) error); ok {
r1 = rf(req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockResolver_Resolve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Resolve'
type MockResolver_Resolve_Call struct {
*mock.Call
}
// Resolve is a helper method to define mock.On call
// - req *model.Request
func (_e *MockResolver_Expecter) Resolve(req interface{}) *MockResolver_Resolve_Call {
return &MockResolver_Resolve_Call{Call: _e.mock.On("Resolve", req)}
}
func (_c *MockResolver_Resolve_Call) Run(run func(req *model.Request)) *MockResolver_Resolve_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*model.Request))
})
return _c
}
func (_c *MockResolver_Resolve_Call) Return(_a0 *model.Response, _a1 error) *MockResolver_Resolve_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockResolver_Resolve_Call) RunAndReturn(run func(*model.Request) (*model.Response, error)) *MockResolver_Resolve_Call {
_c.Call.Return(run)
return _c
}
// Type provides a mock function with given fields:
func (_m *MockResolver) Type() string {
ret := _m.Called()
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// MockResolver_Type_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Type'
type MockResolver_Type_Call struct {
*mock.Call
}
// Type is a helper method to define mock.On call
func (_e *MockResolver_Expecter) Type() *MockResolver_Type_Call {
return &MockResolver_Type_Call{Call: _e.mock.On("Type")}
}
func (_c *MockResolver_Type_Call) Run(run func()) *MockResolver_Type_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockResolver_Type_Call) Return(_a0 string) *MockResolver_Type_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockResolver_Type_Call) RunAndReturn(run func() string) *MockResolver_Type_Call {
_c.Call.Return(run)
return _c
}
type mockConstructorTestingTNewMockResolver interface {
mock.TestingT
Cleanup(func())
}
// NewMockResolver creates a new instance of MockResolver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewMockResolver(t mockConstructorTestingTNewMockResolver) *MockResolver {
mock := &MockResolver{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -68,6 +68,8 @@ func newRequestWithClientID(question string, rType dns.Type, ip, requestClientID
}
// Resolver generic interface for all resolvers
//
//go:generate go run github.com/vektra/mockery/v2 --name Resolver
type Resolver interface {
config.Configurable