#202: WhitelistOnly Fix for multiple entries (#199)

* Update blocking_resolver.go

Adjusted WhitelistOnly

* added test

* fixed golint issues

Co-authored-by: Dimitri Herzog <dimitri.herzog@gmail.com>
This commit is contained in:
invist 2021-05-05 22:07:14 +02:00 committed by GitHub
parent 1d511a3cd8
commit dd69a3e664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 10 deletions

View File

@ -8,7 +8,6 @@ import (
"blocky/util"
"fmt"
"net"
"reflect"
"sort"
"strings"
"time"
@ -65,7 +64,7 @@ type BlockingResolver struct {
whitelistMatcher *lists.ListCache
cfg config.BlockingConfig
blockHandler blockHandler
whitelistOnlyGroups []string
whitelistOnlyGroups map[string]bool
status *status
}
@ -177,17 +176,17 @@ func (r *BlockingResolver) BlockingStatus() api.BlockingStatus {
}
// returns groups, which have only whitelist entries
func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result []string) {
func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result map[string]bool) {
result = make(map[string]bool)
for g, links := range cfg.WhiteLists {
if len(links) > 0 {
if _, found := cfg.BlackLists[g]; !found {
result = append(result, g)
result[g] = true
}
}
}
sort.Strings(result)
return
}
@ -230,10 +229,20 @@ func (r *BlockingResolver) Configuration() (result []string) {
return
}
func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool {
for _, group := range groupsToCheck {
if _, found := r.whitelistOnlyGroups[group]; found {
return true
}
}
return false
}
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
request *Request, logger *logrus.Entry) (*Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
whitelistOnlyAllowed := reflect.DeepEqual(groupsToCheck, r.whitelistOnlyGroups)
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)
for _, question := range request.Req.Question {
domain := util.ExtractDomain(question)

View File

@ -359,13 +359,19 @@ badcnamedomain.com`)
When("Only whitelist is defined", func() {
BeforeEach(func() {
sutConfig = config.BlockingConfig{
WhiteLists: map[string][]string{"gr1": {group1File.Name()}},
WhiteLists: map[string][]string{
"gr1": {group1File.Name()},
"gr2": {group2File.Name()},
},
ClientGroupsBlock: map[string][]string{
"default": {"gr1"},
"default": {"gr1"},
"one-client": {"gr1"},
"two-client": {"gr2"},
"all-client": {"gr1", "gr2"},
},
}
})
It("should block everything else except domains on the white list", func() {
It("should block everything else except domains on the white list with default group", func() {
By("querying domain on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "unknown"))
@ -379,6 +385,35 @@ badcnamedomain.com`)
Expect(resp.Reason).Should(Equal("BLOCKED (WHITELIST ONLY)"))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "one-client"))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
By("querying another domain, which is not on the whitelist", func() {
resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "1.2.1.2", "one-client"))
Expect(m.Calls).Should(HaveLen(1))
Expect(resp.Reason).Should(Equal("BLOCKED (WHITELIST ONLY)"))
})
})
It("should block everything else except domains on the white list "+
"if multiple white list only groups are defined", func() {
By("querying domain on the whitelist group 1", func() {
resp, err = sut.Resolve(newRequestWithClient("domain1.com.", dns.TypeA, "1.2.1.2", "all-client"))
// was delegated to next resolver
m.AssertExpectations(GinkgoT())
})
By("querying another domain, which is in the whitelist group 1", func() {
resp, err = sut.Resolve(newRequestWithClient("blocked2.com.", dns.TypeA, "1.2.1.2", "all-client"))
Expect(m.Calls).Should(HaveLen(2))
})
})
})
When("IP address is on black and white list", func() {