mirror of https://github.com/0xERR0R/blocky.git
add mock
This commit is contained in:
parent
2b97d5550b
commit
1ead31464a
|
@ -63,6 +63,13 @@ func (x IPVersion) String() string {
|
|||
return fmt.Sprintf("IPVersion(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x IPVersion) IsValid() bool {
|
||||
_, ok := _IPVersionMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _IPVersionValue = map[string]IPVersion{
|
||||
_IPVersionName[0:4]: IPVersionDual,
|
||||
_IPVersionName[4:6]: IPVersionV4,
|
||||
|
@ -145,6 +152,13 @@ func (x NetProtocol) String() string {
|
|||
return fmt.Sprintf("NetProtocol(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x NetProtocol) IsValid() bool {
|
||||
_, ok := _NetProtocolMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _NetProtocolValue = map[string]NetProtocol{
|
||||
_NetProtocolName[0:7]: NetProtocolTcpUdp,
|
||||
_NetProtocolName[7:14]: NetProtocolTcpTls,
|
||||
|
@ -225,7 +239,8 @@ func (x QueryLogField) String() string {
|
|||
return string(x)
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x QueryLogField) IsValid() bool {
|
||||
_, err := ParseQueryLogField(string(x))
|
||||
return err == nil
|
||||
|
@ -333,6 +348,13 @@ func (x QueryLogType) String() string {
|
|||
return fmt.Sprintf("QueryLogType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x QueryLogType) IsValid() bool {
|
||||
_, ok := _QueryLogTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _QueryLogTypeValue = map[string]QueryLogType{
|
||||
_QueryLogTypeName[0:7]: QueryLogTypeConsole,
|
||||
_QueryLogTypeName[7:11]: QueryLogTypeNone,
|
||||
|
@ -418,6 +440,13 @@ func (x StartStrategyType) String() string {
|
|||
return fmt.Sprintf("StartStrategyType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x StartStrategyType) IsValid() bool {
|
||||
_, ok := _StartStrategyTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _StartStrategyTypeValue = map[string]StartStrategyType{
|
||||
_StartStrategyTypeName[0:8]: StartStrategyTypeBlocking,
|
||||
_StartStrategyTypeName[8:19]: StartStrategyTypeFailOnError,
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x ListCacheType) String() string {
|
|||
return fmt.Sprintf("ListCacheType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x ListCacheType) IsValid() bool {
|
||||
_, ok := _ListCacheTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _ListCacheTypeValue = map[string]ListCacheType{
|
||||
_ListCacheTypeName[0:9]: ListCacheTypeBlacklist,
|
||||
_ListCacheTypeName[9:18]: ListCacheTypeWhitelist,
|
||||
|
|
|
@ -24,7 +24,6 @@ var _ = Describe("errorFilter", func() {
|
|||
// mockParser.EXPECT().Next(mock.Anything).Return(struct{}{}, errors.New("fail")).Once()
|
||||
// mockParser.EXPECT().Next(mock.Anything).Return(struct{}{}, NewNonResumableError(io.EOF)).Once()
|
||||
// parser = mockParser
|
||||
|
||||
})
|
||||
|
||||
When("0 errors are allowed", func() {
|
||||
|
@ -93,7 +92,6 @@ var _ = Describe("errorFilter", func() {
|
|||
Expect(err).ShouldNot(Succeed())
|
||||
Expect(err).Should(MatchError(io.EOF))
|
||||
Expect(IsNonResumableErr(err)).Should(BeTrue())
|
||||
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x FormatType) String() string {
|
|||
return fmt.Sprintf("FormatType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x FormatType) IsValid() bool {
|
||||
_, ok := _FormatTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _FormatTypeValue = map[string]FormatType{
|
||||
_FormatTypeName[0:4]: FormatTypeText,
|
||||
_FormatTypeName[4:8]: FormatTypeJson,
|
||||
|
@ -130,6 +137,13 @@ func (x Level) String() string {
|
|||
return fmt.Sprintf("Level(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x Level) IsValid() bool {
|
||||
_, ok := _LevelMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _LevelValue = map[string]Level{
|
||||
_LevelName[0:4]: LevelInfo,
|
||||
_LevelName[4:9]: LevelTrace,
|
||||
|
|
|
@ -49,6 +49,13 @@ func (x RequestProtocol) String() string {
|
|||
return fmt.Sprintf("RequestProtocol(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x RequestProtocol) IsValid() bool {
|
||||
_, ok := _RequestProtocolMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _RequestProtocolValue = map[string]RequestProtocol{
|
||||
_RequestProtocolName[0:3]: RequestProtocolTCP,
|
||||
_RequestProtocolName[3:6]: RequestProtocolUDP,
|
||||
|
@ -151,6 +158,13 @@ func (x ResponseType) String() string {
|
|||
return fmt.Sprintf("ResponseType(%d)", x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x ResponseType) IsValid() bool {
|
||||
_, ok := _ResponseTypeMap[x]
|
||||
return ok
|
||||
}
|
||||
|
||||
var _ResponseTypeValue = map[string]ResponseType{
|
||||
_ResponseTypeName[0:8]: ResponseTypeRESOLVED,
|
||||
_ResponseTypeName[8:14]: ResponseTypeCACHED,
|
||||
|
|
|
@ -46,43 +46,56 @@ var _ = BeforeSuite(func() {
|
|||
|
||||
var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||
var (
|
||||
sut *BlockingResolver
|
||||
sutConfig config.BlockingConfig
|
||||
m *mockResolver
|
||||
mockAnswer *dns.Msg
|
||||
sut *BlockingResolver
|
||||
sutConfig config.BlockingConfig
|
||||
m1 *MockResolver
|
||||
mockResponse *Response
|
||||
)
|
||||
|
||||
Describe("Type", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
})
|
||||
|
||||
It("follows conventions", func() {
|
||||
expectValidResolverType(sut)
|
||||
})
|
||||
})
|
||||
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
|
||||
mockAnswer = new(dns.Msg)
|
||||
mockResponse = &Response{Res: new(dns.Msg)}
|
||||
})
|
||||
|
||||
JustBeforeEach(func() {
|
||||
var err error
|
||||
m = &mockResolver{}
|
||||
m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
m1 = NewMockResolver(GinkgoT())
|
||||
sut, err = NewBlockingResolver(sutConfig, nil, systemResolverBootstrap)
|
||||
Expect(err).Should(Succeed())
|
||||
sut.Next(m)
|
||||
sut.Next(m1)
|
||||
})
|
||||
|
||||
Describe("IsEnabled", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
})
|
||||
It("is false", func() {
|
||||
Expect(sut.IsEnabled()).Should(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("LogConfig", func() {
|
||||
BeforeEach(func() {
|
||||
sutConfig = config.BlockingConfig{
|
||||
BlockType: "ZEROIP",
|
||||
BlockTTL: config.Duration(time.Minute),
|
||||
}
|
||||
})
|
||||
It("should log something", func() {
|
||||
logger, hook := log.NewMockEntry()
|
||||
|
||||
|
@ -138,13 +151,16 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Full-qualified group name is used", func() {
|
||||
It("should block request", func() {
|
||||
m.AnswerFn = func(t dns.Type, qName string) (*dns.Msg, error) {
|
||||
if t == dns.Type(dns.TypeA) && qName == "full.qualified.com." {
|
||||
return util.NewMsgWithAnswer(qName, 60*60, A, "192.168.178.39")
|
||||
m1.On("Resolve", mock.Anything).Return(func(req *Request) (*Response, error) {
|
||||
question := req.Req.Question[0]
|
||||
if question.Qtype == dns.TypeA && question.Name == "full.qualified.com." {
|
||||
r, _ := util.NewMsgWithAnswer(question.Name, 60*60, A, "192.168.178.39")
|
||||
|
||||
return &Response{Res: r}, nil
|
||||
}
|
||||
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return mockResponse, nil
|
||||
})
|
||||
Bus().Publish(ApplicationStarted, "")
|
||||
Eventually(func(g Gomega) {
|
||||
g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))).
|
||||
|
@ -176,6 +192,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Domain is on the black list", func() {
|
||||
It("should block request", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Eventually(sut.Resolve).
|
||||
WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")).
|
||||
Should(
|
||||
|
@ -293,6 +311,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() {
|
||||
It("should not block the query for 10.43.8.63 if domain is on the black list", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -300,11 +320,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
It("should not block the query for 10.43.8.80 if domain is on the black list", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -312,9 +331,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -534,11 +550,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Blacklist contains IP", func() {
|
||||
When("IP4", func() {
|
||||
BeforeEach(func() {
|
||||
// return defined IP as response
|
||||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
})
|
||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||
// return defined IP as response
|
||||
mockAnswer, _ := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -551,14 +567,14 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
})
|
||||
When("IP6", func() {
|
||||
BeforeEach(func() {
|
||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||
// return defined IP as response
|
||||
mockAnswer, _ = util.NewMsgWithAnswer(
|
||||
mockAnswer, _ := util.NewMsgWithAnswer(
|
||||
"example.com.", 300,
|
||||
AAAA, "2001:0db8:85a3:08d3::0370:7344",
|
||||
)
|
||||
})
|
||||
It("should block query, if lookup result contains blacklisted IP", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -573,15 +589,15 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
|
||||
When("blacklist contains domain which is CNAME in response", func() {
|
||||
BeforeEach(func() {
|
||||
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
|
||||
// reconfigure mock, to return CNAMEs
|
||||
rr1, _ := dns.NewRR("example.com 300 IN CNAME domain.com")
|
||||
rr2, _ := dns.NewRR("domain.com 300 IN CNAME badcnamedomain.com")
|
||||
rr3, _ := dns.NewRR("badcnamedomain.com 300 IN A 125.125.125.125")
|
||||
mockAnswer = new(dns.Msg)
|
||||
mockAnswer := new(dns.Msg)
|
||||
mockAnswer.Answer = []dns.RR{rr1, rr2, rr3}
|
||||
})
|
||||
It("should block the query, if response contains a CNAME with domain on a blacklist", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -609,6 +625,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("Should not be blocked", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -616,9 +634,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -641,6 +656,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
It("should block everything else except domains on the white list with default group", func() {
|
||||
By("querying domain on the whitelist", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -648,9 +665,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
By("querying another domain, which is not on the whitelist", func() {
|
||||
|
@ -663,13 +677,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
HaveReason("BLOCKED (WHITELIST ONLY)"),
|
||||
))
|
||||
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
})
|
||||
It("should block everything else except domains on the white list "+
|
||||
"if multiple white list only groups are defined", func() {
|
||||
By("querying domain on the whitelist", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -677,9 +691,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
By("querying another domain, which is not on the whitelist", func() {
|
||||
|
@ -692,12 +703,13 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
HaveReason("BLOCKED (WHITELIST ONLY)"),
|
||||
))
|
||||
Expect(m.Calls).Should(HaveLen(1))
|
||||
})
|
||||
})
|
||||
It("should block everything else except domains on the white list "+
|
||||
"if multiple white list only groups are defined", func() {
|
||||
By("querying domain on the whitelist group 1", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Twice()
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -705,9 +717,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
|
||||
By("querying another domain, which is in the whitelist group 1", func() {
|
||||
|
@ -718,7 +727,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
Expect(m.Calls).Should(HaveLen(2))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -734,9 +742,11 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
"default": {"gr1"},
|
||||
},
|
||||
}
|
||||
mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
})
|
||||
It("should not block if DNS answer contains IP from the white list", func() {
|
||||
mockAnswer, _ := util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145")
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(&Response{Res: mockAnswer}, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -744,8 +754,6 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -761,12 +769,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
},
|
||||
}
|
||||
})
|
||||
AfterEach(func() {
|
||||
// was delegated to next resolver
|
||||
m.AssertExpectations(GinkgoT())
|
||||
})
|
||||
When("domain is not on the black list", func() {
|
||||
It("should delegate to next resolver", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -784,6 +790,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
}
|
||||
})
|
||||
It("should delegate to next resolver", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil)
|
||||
|
||||
Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
SatisfyAll(
|
||||
|
@ -810,6 +818,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
})
|
||||
When("Disable blocking is called", func() {
|
||||
It("no query should be blocked", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Times(3)
|
||||
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
|
@ -847,8 +857,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
})
|
||||
|
||||
By("perform the same query again (group1)", func() {
|
||||
|
@ -861,8 +870,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
|
||||
})
|
||||
|
||||
By("Calling Rest API to deactivate only defaultGroup", func() {
|
||||
|
@ -880,8 +888,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 3)
|
||||
m1.AssertExpectations(GinkgoT())
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 3)
|
||||
})
|
||||
|
||||
By("Perform query to ensure that the blocking status is active (group1)", func() {
|
||||
|
@ -899,6 +907,8 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Disable blocking for all groups is called with a duration parameter", func() {
|
||||
It("No query should be blocked only for passed amount of time", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Twice()
|
||||
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
|
@ -941,8 +951,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
})
|
||||
By("perform the same query again to ensure that this query will not be blocked (group1)", func() {
|
||||
// now is blocking disabled, query the url again
|
||||
|
@ -953,9 +962,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveResponseType(ResponseTypeRESOLVED),
|
||||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 2)
|
||||
})
|
||||
|
||||
By("Wait 1 sec and perform the same query again, should be blocked now", func() {
|
||||
|
@ -989,6 +996,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
|
||||
When("Disable blocking for one group is called with a duration parameter", func() {
|
||||
It("No query should be blocked only for passed amount of time", func() {
|
||||
m1.EXPECT().Resolve(mock.Anything).Return(mockResponse, nil).Once()
|
||||
By("Perform query to ensure that the blocking status is active (defaultGroup)", func() {
|
||||
Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))).
|
||||
Should(
|
||||
|
@ -1042,8 +1050,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
HaveReturnCode(dns.RcodeSuccess),
|
||||
))
|
||||
|
||||
m.AssertExpectations(GinkgoT())
|
||||
m.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
m1.AssertNumberOfCalls(GinkgoT(), "Resolve", 1)
|
||||
})
|
||||
|
||||
By("Wait 1 sec and perform the same query again, should be blocked now", func() {
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
// Code generated by mockery v2.23.1. DO NOT EDIT.
|
||||
|
||||
package resolver
|
||||
|
||||
import (
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
model "github.com/0xERR0R/blocky/model"
|
||||
)
|
||||
|
||||
// MockResolver is an autogenerated mock type for the Resolver type
|
||||
type MockResolver struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockResolver_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockResolver) EXPECT() *MockResolver_Expecter {
|
||||
return &MockResolver_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// IsEnabled provides a mock function with given fields:
|
||||
func (_m *MockResolver) IsEnabled() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockResolver_IsEnabled_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsEnabled'
|
||||
type MockResolver_IsEnabled_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// IsEnabled is a helper method to define mock.On call
|
||||
func (_e *MockResolver_Expecter) IsEnabled() *MockResolver_IsEnabled_Call {
|
||||
return &MockResolver_IsEnabled_Call{Call: _e.mock.On("IsEnabled")}
|
||||
}
|
||||
|
||||
func (_c *MockResolver_IsEnabled_Call) Run(run func()) *MockResolver_IsEnabled_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_IsEnabled_Call) Return(_a0 bool) *MockResolver_IsEnabled_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_IsEnabled_Call) RunAndReturn(run func() bool) *MockResolver_IsEnabled_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// LogConfig provides a mock function with given fields: _a0
|
||||
func (_m *MockResolver) LogConfig(_a0 *logrus.Entry) {
|
||||
_m.Called(_a0)
|
||||
}
|
||||
|
||||
// MockResolver_LogConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LogConfig'
|
||||
type MockResolver_LogConfig_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// LogConfig is a helper method to define mock.On call
|
||||
// - _a0 *logrus.Entry
|
||||
func (_e *MockResolver_Expecter) LogConfig(_a0 interface{}) *MockResolver_LogConfig_Call {
|
||||
return &MockResolver_LogConfig_Call{Call: _e.mock.On("LogConfig", _a0)}
|
||||
}
|
||||
|
||||
func (_c *MockResolver_LogConfig_Call) Run(run func(_a0 *logrus.Entry)) *MockResolver_LogConfig_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*logrus.Entry))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_LogConfig_Call) Return() *MockResolver_LogConfig_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_LogConfig_Call) RunAndReturn(run func(*logrus.Entry)) *MockResolver_LogConfig_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Resolve provides a mock function with given fields: req
|
||||
func (_m *MockResolver) Resolve(req *model.Request) (*model.Response, error) {
|
||||
ret := _m.Called(req)
|
||||
|
||||
var r0 *model.Response
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(*model.Request) (*model.Response, error)); ok {
|
||||
return rf(req)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(*model.Request) *model.Response); ok {
|
||||
r0 = rf(req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.Response)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(*model.Request) error); ok {
|
||||
r1 = rf(req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockResolver_Resolve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Resolve'
|
||||
type MockResolver_Resolve_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Resolve is a helper method to define mock.On call
|
||||
// - req *model.Request
|
||||
func (_e *MockResolver_Expecter) Resolve(req interface{}) *MockResolver_Resolve_Call {
|
||||
return &MockResolver_Resolve_Call{Call: _e.mock.On("Resolve", req)}
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Resolve_Call) Run(run func(req *model.Request)) *MockResolver_Resolve_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*model.Request))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Resolve_Call) Return(_a0 *model.Response, _a1 error) *MockResolver_Resolve_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Resolve_Call) RunAndReturn(run func(*model.Request) (*model.Response, error)) *MockResolver_Resolve_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Type provides a mock function with given fields:
|
||||
func (_m *MockResolver) Type() string {
|
||||
ret := _m.Called()
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockResolver_Type_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Type'
|
||||
type MockResolver_Type_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Type is a helper method to define mock.On call
|
||||
func (_e *MockResolver_Expecter) Type() *MockResolver_Type_Call {
|
||||
return &MockResolver_Type_Call{Call: _e.mock.On("Type")}
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Type_Call) Run(run func()) *MockResolver_Type_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Type_Call) Return(_a0 string) *MockResolver_Type_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockResolver_Type_Call) RunAndReturn(run func() string) *MockResolver_Type_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
type mockConstructorTestingTNewMockResolver interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}
|
||||
|
||||
// NewMockResolver creates a new instance of MockResolver. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
func NewMockResolver(t mockConstructorTestingTNewMockResolver) *MockResolver {
|
||||
mock := &MockResolver{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
|
@ -68,6 +68,8 @@ func newRequestWithClientID(question string, rType dns.Type, ip, requestClientID
|
|||
}
|
||||
|
||||
// Resolver generic interface for all resolvers
|
||||
//
|
||||
//go:generate go run github.com/vektra/mockery/v2 --name Resolver
|
||||
type Resolver interface {
|
||||
config.Configurable
|
||||
|
||||
|
|
Loading…
Reference in New Issue