update golangci-lint (#510)

* update golangci-lint

* enable gomnd linter

* enable asciicheck linter

* enable bidichk linter

* enable durationcheck linter

* enable errchkjson linter

* enable errorlint linter

* enable exhaustive linter

* enable gomoddirectives linter

* enable gomodguard guard

* enable grouper linter

* enable grouper and ifshort linters

* enable importas linter

* enable makezero linter

* enable nestif linter

* enable nilerr linter

* enable nilnil linter

* enable nlreturn linter

* enable nolintlint linter

* enable predeclared linter

* enable sqlclosecheck linter

* enable tenv linter

* enable wastedassign linter
This commit is contained in:
Dimitri Herzog 2022-05-10 09:09:50 +02:00 committed by GitHub
parent 41febafd41
commit a4b89537db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 411 additions and 234 deletions

View File

@ -16,5 +16,5 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.43.0
version: v1.45.2
args: --timeout 5m0s

View File

@ -1,41 +1,63 @@
linters:
enable:
- govet
- errcheck
- staticcheck
- unused
- gosimple
- structcheck
- varcheck
- ineffassign
- deadcode
- typecheck
- asciicheck
- bidichk
- bodyclose
- revive
- stylecheck
- gosec
- unconvert
- dupl
- goconst
- gocyclo
- gocognit
- gofmt
- goimports
- deadcode
- depguard
- misspell
- lll
- unparam
- dogsled
- dupl
- durationcheck
- errcheck
- errchkjson
- errorlint
- exhaustive
- exportloopref
- funlen
- gochecknoglobals
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godox
- gofmt
- goimports
- gomnd
- gomoddirectives
- gomodguard
- gosec
- gosimple
- govet
- grouper
- ifshort
- importas
- ineffassign
- lll
- makezero
- misspell
- nakedret
- nestif
- nilerr
- nilnil
- nlreturn
- nolintlint
- prealloc
- predeclared
- revive
- sqlclosecheck
- staticcheck
- structcheck
- stylecheck
- tenv
- typecheck
- unconvert
- unparam
- unused
- varcheck
- wastedassign
- whitespace
- wsl
- exportloopref
disable:
- noctx
- scopelint
@ -46,6 +68,11 @@ linters:
- unused
fast: false
linters-settings:
gomnd:
ignored-numbers:
- '0666'
issues:
exclude-rules:
# Exclude some linters from running on tests files.

View File

@ -38,7 +38,7 @@ race: ## run tests with race detector
go test -race -short ./...
lint: build ## run golangcli-lint checks
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.43.0
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.45.2
$(shell go env GOPATH)/bin/golangci-lint run
run: build ## Build and run binary

View File

@ -30,6 +30,7 @@ func (b *BlockingControlMock) EnableBlocking() {
}
func (b *BlockingControlMock) DisableBlocking(_ time.Duration, disableGroups []string) error {
b.enabled = false
return nil
}

View File

@ -139,8 +139,7 @@ func isExpired(el *element) bool {
}
func calculateRemainTTL(expiresEpoch int64) time.Duration {
now := time.Now().UnixMilli()
if now < expiresEpoch {
if now := time.Now().UnixMilli(); now < expiresEpoch {
return time.Duration(expiresEpoch-now) * time.Millisecond
}

View File

@ -99,6 +99,7 @@ func (cache regexCache) Contains(searchString string) bool {
for _, regex := range cache {
if regex.MatchString(searchString) {
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
return true
}
}

View File

@ -40,6 +40,7 @@ func refreshList(_ *cobra.Command, _ []string) error {
if resp.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
}

View File

@ -54,6 +54,7 @@ func query(cmd *cobra.Command, args []string) error {
if resp.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
}

View File

@ -3,7 +3,6 @@ package cmd
import (
"fmt"
"os"
"strconv"
"strings"
"github.com/0xERR0R/blocky/config"
@ -20,6 +19,12 @@ var (
apiPort uint16
)
const (
defaultPort = 4000
defaultHost = "localhost"
defaultConfigPath = "./config.yml"
)
// NewRootCommand creates a new root cli command instance
func NewRootCommand() *cobra.Command {
c := &cobra.Command{
@ -35,9 +40,9 @@ Complete documentation is available at https://github.com/0xERR0R/blocky`,
SilenceUsage: true,
}
c.PersistentFlags().StringVarP(&configPath, "config", "c", "./config.yml", "path to config file")
c.PersistentFlags().StringVar(&apiHost, "apiHost", "localhost", "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", 4000, "port of blocky (API). Default overridden by config and CLI.")
c.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "path to config file")
c.PersistentFlags().StringVar(&apiHost, "apiHost", defaultHost, "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", defaultPort, "port of blocky (API). Default overridden by config and CLI.") // nolint:lll
c.AddCommand(newRefreshCommand(),
NewQueryCommand(),
@ -73,15 +78,14 @@ func initConfig() {
apiHost = strings.Join(split[:lastIdx], ":")
var p uint64
p, err := strconv.ParseUint(strings.TrimSpace(split[lastIdx]), 10, 16)
port, err := config.ConvertPort(split[lastIdx])
if err != nil {
util.FatalOnError("can't convert port to number (1 - 65535)", err)
return
}
apiPort = uint16(p)
apiPort = port
}
}

View File

@ -49,7 +49,8 @@ func startServer(_ *cobra.Command, _ []string) error {
return fmt.Errorf("can't start server: %w", err)
}
errChan := make(chan error, 10)
const errChanSize = 10
errChan := make(chan error, errChanSize)
srv.Start(errChan)

View File

@ -23,6 +23,12 @@ import (
"gopkg.in/yaml.v2"
)
const (
udpPort = 53
tlsPort = 853
httpsPort = 443
)
// NetProtocol resolver protocol ENUM(
// tcp+udp // TCP and UDP protocols
// tcp-tls // TCP-TLS protocol
@ -60,6 +66,7 @@ func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
func (s QTypeSet) Contains(qType dns.Type) bool {
_, found := s[QType(qType)]
return found
}
@ -79,9 +86,9 @@ func (c *Duration) String() string {
// nolint:gochecknoglobals
var netDefaultPort = map[NetProtocol]uint16{
NetProtocolTcpUdp: 53,
NetProtocolTcpTls: 853,
NetProtocolHttps: 443,
NetProtocolTcpUdp: udpPort,
NetProtocolTcpTls: tlsPort,
NetProtocolHttps: httpsPort,
}
// Upstream is the definition of external DNS server
@ -252,12 +259,14 @@ func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
// duration is defined as number without unit
// use minutes to ensure back compatibility
*c = Duration(time.Duration(minutes) * time.Minute)
return nil
}
duration, err := time.ParseDuration(input)
if err == nil {
*c = Duration(duration)
return nil
}
@ -320,15 +329,15 @@ func ParseUpstream(upstream string) (Upstream, error) {
// string contains host:port
if err == nil {
var p uint64
p, err = strconv.ParseUint(strings.TrimSpace(portString), 10, 16)
p, err := ConvertPort(portString)
if err != nil {
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
return Upstream{}, err
}
port = uint16(p)
port = p
} else {
// only host, use default port
host = upstream
@ -340,9 +349,7 @@ func ParseUpstream(upstream string) (Upstream, error) {
}
// validate hostname or ip
ip := net.ParseIP(host)
if ip == nil {
if ip := net.ParseIP(host); ip == nil {
// is not IP
if !validDomain.MatchString(host) {
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
@ -544,6 +551,7 @@ func LoadConfig(path string, mandatory bool) (*Config, error) {
// config file does not exist
// return config with default values
config = &cfg
return config, nil
}
@ -596,3 +604,20 @@ func validateConfig(cfg *Config) (err error) {
func GetConfig() *Config {
return config
}
// ConvertPort converts string representation into a valid port (0 - 65535)
func ConvertPort(in string) (uint16, error) {
const (
base = 10
bitSize = 16
)
var p uint64
p, err := strconv.ParseUint(strings.TrimSpace(in), base, bitSize)
if err != nil {
return 0, err
}
return uint16(p), nil
}

View File

@ -269,6 +269,7 @@ bootstrapDns:
u := &Upstream{}
err := u.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "tcp+udp:1.2.3.4"
return nil
})
@ -292,6 +293,7 @@ bootstrapDns:
l := &ListenConfig{}
err := l.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "55,:56"
return nil
})
Expect(err).Should(Succeed())
@ -311,6 +313,7 @@ bootstrapDns:
d := Duration(0)
err := d.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "1m20s"
return nil
})
Expect(err).Should(Succeed())
@ -321,6 +324,7 @@ bootstrapDns:
d := Duration(0)
err := d.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "wrong"
return nil
})
Expect(err).Should(HaveOccurred())
@ -342,6 +346,7 @@ bootstrapDns:
c := &ConditionalUpstreamMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
@ -364,6 +369,7 @@ bootstrapDns:
c := &CustomDNSMapping{}
err := c.UnmarshalYAML(func(i interface{}) error {
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
return nil
})
Expect(err).Should(Succeed())
@ -385,6 +391,7 @@ bootstrapDns:
t := QType(0)
err := t.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "AAAA"
return nil
})
Expect(err).Should(Succeed())
@ -395,6 +402,7 @@ bootstrapDns:
t := QType(0)
err := t.UnmarshalYAML(func(i interface{}) error {
*i.(*string) = "WRONGTYPE"
return nil
})
Expect(err).Should(HaveOccurred())

View File

@ -103,10 +103,10 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
func() error {
var resp *http.Response
var httpErr error
//nolint:bodyclose
if resp, httpErr = client.Get(link); httpErr == nil {
if resp.StatusCode == http.StatusOK {
body = resp.Body
return nil
}
@ -118,6 +118,7 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
return &TransientError{inner: netErr}
}
return httpErr
},
retry.Attempts(d.downloadAttempts),

View File

@ -163,6 +163,7 @@ Loop:
}
default:
close(c)
break Loop
}
}
@ -307,9 +308,7 @@ func processLine(line string) string {
return ""
}
parts := strings.Fields(line)
if len(parts) > 0 {
if parts := strings.Fields(line); len(parts) > 0 {
host := parts[len(parts)-1]
ip := net.ParseIP(host)

View File

@ -152,6 +152,7 @@ var _ = Describe("ListCache", func() {
By("List couldn't be loaded due to 404 err", func() {
Eventually(func() bool {
found, _ := sut.Match("blocked1.com", []string{"gr1"})
return found
}, "1s").Should(BeFalse())
})
@ -313,5 +314,6 @@ type MockDownloader struct {
func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
fn := <-m.data
return fn()
}

View File

@ -45,6 +45,7 @@ var _ = Describe("DatabaseWriter", func() {
result := writer.db.Find(&logEntry{})
result.Count(&res)
return res
}, "1s").Should(BeNumerically("==", 1))
})
@ -91,6 +92,7 @@ var _ = Describe("DatabaseWriter", func() {
result := writer.db.Find(&logEntry{})
result.Count(&res)
return res
}, "1s").Should(BeNumerically("==", 2))
@ -102,6 +104,7 @@ var _ = Describe("DatabaseWriter", func() {
result := writer.db.Find(&logEntry{})
result.Count(&res)
return res
}, "1s").Should(BeNumerically("==", 1))
})

View File

@ -72,6 +72,8 @@ func (d *FileWriter) Write(entry *LogEntry) {
// CleanUp deletes old log files
func (d *FileWriter) CleanUp() {
const hoursPerDay = 24
logger := log.PrefixedLog(loggerPrefixFileWriter)
logger.Trace("starting clean up")
@ -85,7 +87,7 @@ func (d *FileWriter) CleanUp() {
if strings.HasSuffix(f.Name(), ".log") && len(f.Name()) > 10 {
t, err := time.Parse("2006-01-02", f.Name()[:10])
if err == nil {
differenceDays := uint64(time.Since(t).Hours() / 24)
differenceDays := uint64(time.Since(t).Hours() / hoursPerDay)
if d.logRetentionDays > 0 && differenceDays > d.logRetentionDays {
logger.WithFields(logrus.Fields{
"file": f.Name(),

View File

@ -69,7 +69,7 @@ type Client struct {
func New(cfg *config.RedisConfig) (*Client, error) {
// disable redis if no address is provided
if cfg == nil || len(cfg.Address) == 0 {
return nil, nil
return nil, nil // nolint:nilnil
}
rdb := redis.NewClient(&redis.Options{
@ -164,7 +164,12 @@ func (c *Client) startup() error {
select {
// received message from subscription
case msg := <-ps.Channel():
err = c.processReceivedMessage(msg)
c.l.Debug("Received message: ", msg)
if msg != nil && len(msg.Payload) > 0 {
// message is not empty
err = c.processReceivedMessage(msg)
}
// publish message from buffer
case s := <-c.sendBuffer:
c.publishMessageFromBuffer(s)
@ -201,34 +206,30 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
}
func (c *Client) processReceivedMessage(msg *redis.Message) (err error) {
c.l.Debug("Received message: ", msg)
// message is not empty
if msg != nil && len(msg.Payload) > 0 {
var rm redisMessage
var rm redisMessage
err = json.Unmarshal([]byte(msg.Payload), &rm)
if err == nil {
// message was sent from a different blocky instance
if !bytes.Equal(rm.Client, c.id) {
switch rm.Type {
case messageTypeCache:
var cm *CacheMessage
err = json.Unmarshal([]byte(msg.Payload), &rm)
if err == nil {
// message was sent from a different blocky instance
if !bytes.Equal(rm.Client, c.id) {
switch rm.Type {
case messageTypeCache:
var cm *CacheMessage
cm, err = convertMessage(&rm, 0)
if err == nil {
c.CacheChannel <- cm
}
case messageTypeEnable:
err = c.processEnabledMessage(&rm)
default:
c.l.Warn("Unknown message type: ", rm.Type)
cm, err = convertMessage(&rm, 0)
if err == nil {
c.CacheChannel <- cm
}
case messageTypeEnable:
err = c.processEnabledMessage(&rm)
default:
c.l.Warn("Unknown message type: ", rm.Type)
}
}
}
if err != nil {
c.l.Error("Processing error: ", err)
}
if err != nil {
c.l.Error("Processing error: ", err)
}
return err

View File

@ -25,6 +25,8 @@ import (
"github.com/sirupsen/logrus"
)
const defaultBlockingCleanUpInterval = 5 * time.Second
func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
cfgBlockType := cfg.BlockType
@ -362,7 +364,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
}
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
request *model.Request, logger *logrus.Entry) (*model.Response, error) {
request *model.Request, logger *logrus.Entry) (bool, *model.Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)
@ -372,19 +374,26 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted {
logger.WithField("group", group).Debugf("domain is whitelisted")
return r.next.Resolve(request)
resp, err := r.next.Resolve(request)
return true, resp, err
}
if whitelistOnlyAllowed {
return r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
resp, err := r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
return true, resp, err
}
if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked {
return r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
return true, resp, err
}
}
return nil, nil
return false, nil, nil
}
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
@ -393,8 +402,8 @@ func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, err
groupsToCheck := r.groupsToCheckForClient(request)
if len(groupsToCheck) > 0 {
resp, err := r.handleBlacklist(groupsToCheck, request, logger)
if resp != nil || err != nil {
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
if handled {
return resp, err
}
}
@ -541,6 +550,7 @@ func (b zeroIPBlockHandler) handleBlock(question dns.Question, response *dns.Msg
zeroIP = net.IPv4zero
default:
response.Rcode = dns.RcodeNameError
return
}
@ -603,7 +613,7 @@ func (r *BlockingResolver) initFQDNIPCache() {
identifiers = append(identifiers, identifier)
}
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(5*time.Second),
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(defaultBlockingCleanUpInterval),
expirationcache.WithOnExpiredFn(func(key string) (val interface{}, ttl time.Duration) {
return r.queryForFQIdentifierIPs(key)
}))
@ -618,5 +628,6 @@ func (r *BlockingResolver) initFQDNIPCache() {
func isFQDN(in string) bool {
s := strings.Trim(in, ".")
return strings.Contains(s, ".")
}

View File

@ -130,8 +130,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
m.AnswerFn = func(t uint16, qName string) *dns.Msg {
if t == dns.TypeA && qName == "full.qualified.com." {
a, _ := util.NewMsgWithAnswer(qName, 60*60, dns.Type(dns.TypeA), "192.168.178.39")
return a
}
return nil
}
Bus().Publish(ApplicationStarted, "")

View File

@ -141,6 +141,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
host, port, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("dial error: %s", err)
return nil, err
}
@ -161,6 +162,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
ips, err := b.resolve(host, qTypes)
if err != nil {
log.Errorf("resolve error: %s", err)
return nil, err
}
@ -170,18 +172,20 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
// Use the standard dialer to actually connect
addrWithIP := net.JoinHostPort(ip.String(), port)
return dialer.DialContext(ctx, network, addrWithIP)
},
}
}
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
ips = make([]net.IP, 0, 2)
ips = make([]net.IP, 0, len(qTypes))
for _, qType := range qTypes {
qIPs, qErr := b.resolveType(hostname, qType)
if qErr != nil {
err = multierror.Append(err, qErr)
continue
}
@ -196,8 +200,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
}
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
ip := net.ParseIP(hostname)
if ip != nil {
if ip := net.ParseIP(hostname); ip != nil {
return []net.IP{ip}, nil
}
@ -236,14 +239,15 @@ type IPSet struct {
func (ips *IPSet) Current() net.IP {
idx := atomic.LoadUint32(&ips.index)
return ips.values[idx]
}
func (ips *IPSet) Next() {
old := ips.index
new := uint32(int(ips.index+1) % len(ips.values))
oldIP := ips.index
newIP := uint32(int(ips.index+1) % len(ips.values))
// We don't care about the result: if the call fails,
// it means the value was incremented by another goroutine
_ = atomic.CompareAndSwapUint32(&ips.index, old, new)
_ = atomic.CompareAndSwapUint32(&ips.index, oldIP, newIP)
}

View File

@ -58,6 +58,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
usedSystemResolver <- true
return nil, errors.New("don't actually do anything")
},
}

View File

@ -17,6 +17,8 @@ import (
"github.com/sirupsen/logrus"
)
const defaultCachingCleanUpInterval = 5 * time.Second
// CachingResolver caches answers from dns queries with their TTL time,
// to avoid external resolver calls for recurrent queries
type CachingResolver struct {
@ -58,7 +60,7 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedRe
}
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
cleanupOption := expirationcache.WithCleanUpInterval(5 * time.Second)
cleanupOption := expirationcache.WithCleanUpInterval(defaultCachingCleanUpInterval)
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
if cfg.Prefetching {
@ -91,6 +93,7 @@ func setupRedisCacheSubscriber(c *CachingResolver) {
// check if domain was queried > threshold in the time window
func (r *CachingResolver) isPrefetchingDomain(cacheKey string) bool {
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
return cnt != nil && cnt.(int) > r.prefetchThreshold
}
@ -108,6 +111,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
return cacheValue{response.Res.Answer, true}, time.Duration(r.adjustTTLs(response.Res.Answer)) * time.Second
}
} else {
@ -122,6 +126,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
func (r *CachingResolver) Configuration() (result []string) {
if r.maxCacheTimeSec < 0 {
result = []string{"deactivated"}
return
}
@ -146,12 +151,12 @@ func (r *CachingResolver) Configuration() (result []string) {
// Resolve checks if the current query result is already in the cache and returns it
// or delegates to the next resolver
//nolint:gocognit,funlen
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
logger := withPrefix(request.Log, "caching_resolver")
if r.maxCacheTimeSec < 0 {
logger.Debug("skip cache")
return r.next.Resolve(request)
}

View File

@ -547,6 +547,7 @@ var _ = Describe("CachingResolver", func() {
Eventually(func() error {
resp, err = sut.Resolve(request)
return err
}, "50ms").Should(Succeed())
})

View File

@ -105,6 +105,21 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
return names
}
func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNames []string) {
for _, answer := range answer {
if t, ok := answer.(*dns.PTR); ok {
hostName := strings.TrimSuffix(t.Ptr, ".")
clientNames = append(clientNames, hostName)
}
}
if len(clientNames) == 0 {
clientNames = []string{fallbackIP.String()}
}
return
}
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
// try client mapping first
@ -114,49 +129,40 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
return
}
if r.externalResolver != nil {
reverse, _ := dns.ReverseAddr(ip.String())
resp, err := r.externalResolver.Resolve(&model.Request{
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
Log: logger,
})
if err != nil {
logger.Error("can't resolve client name: ", err)
return []string{ip.String()}
}
var clientNames []string
for _, answer := range resp.Res.Answer {
if t, ok := answer.(*dns.PTR); ok {
hostName := strings.TrimSuffix(t.Ptr, ".")
clientNames = append(clientNames, hostName)
}
}
if len(clientNames) == 0 {
clientNames = []string{ip.String()}
}
// optional: if singleNameOrder is set, use only one name in the defined order
if len(r.singleNameOrder) > 0 {
for _, i := range r.singleNameOrder {
if i > 0 && int(i) <= len(clientNames) {
result = []string{clientNames[i-1]}
break
}
}
} else {
result = clientNames
}
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
} else {
result = []string{ip.String()}
if r.externalResolver == nil {
return []string{ip.String()}
}
reverse, _ := dns.ReverseAddr(ip.String())
resp, err := r.externalResolver.Resolve(&model.Request{
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
Log: logger,
})
if err != nil {
logger.Error("can't resolve client name: ", err)
return []string{ip.String()}
}
clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip)
// optional: if singleNameOrder is set, use only one name in the defined order
if len(r.singleNameOrder) > 0 {
for _, i := range r.singleNameOrder {
if i > 0 && int(i) <= len(clientNames) {
result = []string{clientNames[i-1]}
break
}
}
} else {
result = clientNames
}
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
return result
}

View File

@ -51,31 +51,42 @@ func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
return
}
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion
if strings.Contains(domainFromQuestion, ".") {
// try with domain with and without sub-domains
for len(domain) > 0 {
if resolver, found := r.mapping[domain]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
return true, resp, err
}
if i := strings.Index(domain, "."); i >= 0 {
domain = domain[i+1:]
} else {
break
}
}
} else if resolver, found := r.mapping["."]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
return true, resp, err
}
return false, nil, nil
}
// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := withPrefix(request.Log, "conditional_resolver")
if len(r.mapping) > 0 {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion
if !strings.Contains(domainFromQuestion, ".") {
if resolver, found := r.mapping["."]; found {
return r.internalResolve(resolver, domainFromQuestion, domain, request)
}
} else {
// try with domain with and without sub-domains
for len(domain) > 0 {
if resolver, found := r.mapping[domain]; found {
return r.internalResolve(resolver, domainFromQuestion, domain, request)
}
if i := strings.Index(domain, "."); i >= 0 {
domain = domain[i+1:]
} else {
break
}
}
resolved, resp, err := r.processRequest(request)
if resolved {
return resp, err
}
}

View File

@ -88,6 +88,54 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
return nil
}
func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response {
logger := withPrefix(request.Log, "custom_dns_resolver")
response := new(dns.Msg)
response.SetReply(request.Req)
question := request.Req.Question[0]
domain := util.ExtractDomain(question)
for len(domain) > 0 {
ips, found := r.mapping[domain]
if found {
for _, ip := range ips {
if isSupportedType(ip, question) {
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
response.Answer = append(response.Answer, rr)
}
}
if len(response.Answer) > 0 {
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": domain,
}).Debugf("returning custom dns entry")
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
}
// Mapping exists for this domain, but for another type
if !r.filterUnmappedTypes {
// go to next resolver
break
}
// return NOERROR with empty result
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
}
if i := strings.Index(domain, "."); i >= 0 {
domain = domain[i+1:]
} else {
break
}
}
return nil
}
// Resolve uses internal mapping to resolve the query
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
logger := withPrefix(request.Log, "custom_dns_resolver")
@ -98,46 +146,9 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
}
if len(r.mapping) > 0 {
response := new(dns.Msg)
response.SetReply(request.Req)
question := request.Req.Question[0]
domain := util.ExtractDomain(question)
for len(domain) > 0 {
ips, found := r.mapping[domain]
if found {
for _, ip := range ips {
if isSupportedType(ip, question) {
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
response.Answer = append(response.Answer, rr)
}
}
if len(response.Answer) > 0 {
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": domain,
}).Debugf("returning custom dns entry")
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
}
// Mapping exists for this domain, but for another type
if !r.filterUnmappedTypes {
// go to next resolver
break
}
// return NOERROR with empty result
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
}
if i := strings.Index(domain, "."); i >= 0 {
domain = domain[i+1:]
} else {
break
}
resp := r.processRequest(request)
if resp != nil {
return resp, nil
}
}

View File

@ -76,21 +76,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
domain := util.ExtractDomain(question)
for _, host := range r.hosts {
if host.Hostname == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
response.Answer = append(response.Answer, rr)
}
}
for _, alias := range host.Aliases {
if alias == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
response.Answer = append(response.Answer, rr)
}
}
}
response.Answer = append(response.Answer, r.processHostEntry(host, domain, question)...)
}
if len(response.Answer) > 0 {
@ -108,6 +94,26 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
return r.next.Resolve(request)
}
func (r *HostsFileResolver) processHostEntry(host host, domain string, question dns.Question) (result []dns.RR) {
if host.Hostname == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
result = append(result, rr)
}
}
for _, alias := range host.Aliases {
if alias == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
result = append(result, rr)
}
}
}
return
}
func (r *HostsFileResolver) Configuration() (result []string) {
if r.HostsFilePath != "" && len(r.hosts) != 0 {
result = append(result, fmt.Sprintf("hosts file path: %s", r.HostsFilePath))
@ -127,9 +133,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig) ChainedResolver {
refreshPeriod: time.Duration(cfg.RefreshPeriod),
}
err := r.parseHostsFile()
if err != nil {
if err := r.parseHostsFile(); err != nil {
logger := logger(hostsFileResolverLogger)
logger.Warnf("cannot parse hosts file: %s, hosts file resolving is disabled", r.HostsFilePath)
r.HostsFilePath = ""
@ -147,6 +151,8 @@ type host struct {
}
func (r *HostsFileResolver) parseHostsFile() error {
const minColumnCount = 2
if r.HostsFilePath == "" {
return nil
}
@ -177,7 +183,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
fields = strings.Fields(trimmed[:end])
}
if len(fields) < 2 {
if len(fields) < minColumnCount {
// Skip invalid entry
continue
}
@ -191,7 +197,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
h.IP = net.ParseIP(fields[0])
h.Hostname = fields[1]
if len(fields) > 2 {
if len(fields) > minColumnCount {
for i := 2; i < len(fields); i++ {
h.Aliases = append(h.Aliases, fields[i])
}

View File

@ -5,7 +5,6 @@ import (
"net"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync/atomic"
@ -29,6 +28,7 @@ type MockResolver struct {
func (r *MockResolver) Configuration() []string {
args := r.Called()
return args.Get(0).([]string)
}
@ -182,6 +182,7 @@ func (t *MockUDPUpstreamServer) WithAnswerError(errorCode int) *MockUDPUpstreamS
func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response *dns.Msg)) *MockUDPUpstreamServer {
t.answerFn = fn
return t
}
@ -195,26 +196,33 @@ func (t *MockUDPUpstreamServer) Close() {
}
}
func (t *MockUDPUpstreamServer) Start() config.Upstream {
func CreateConnection() *net.UDPConn {
a, err := net.ResolveUDPAddr("udp4", ":0")
util.FatalOnError("can't resolve address: ", err)
ln, err := net.ListenUDP("udp4", a)
util.FatalOnError("can't create connection: ", err)
t.ln = ln
return ln
}
func (t *MockUDPUpstreamServer) Start() config.Upstream {
ln := CreateConnection()
ladr := ln.LocalAddr().String()
host := strings.Split(ladr, ":")[0]
p, err := strconv.ParseUint(strings.Split(ladr, ":")[1], 10, 16)
p, err := config.ConvertPort(strings.Split(ladr, ":")[1])
util.FatalOnError("can't convert port: ", err)
port := uint16(p)
port := p
t.ln = ln
go func() {
const bufferSize = 1024
for {
buffer := make([]byte, 1024)
buffer := make([]byte, bufferSize)
n, addr, err := ln.ReadFromUDP(buffer)
if err != nil {
@ -233,6 +241,7 @@ func (t *MockUDPUpstreamServer) Start() config.Upstream {
// nil should indicate an error
if response == nil {
_, _ = ln.WriteToUDP([]byte("dummy"), addr)
continue
}

View File

@ -18,6 +18,7 @@ import (
const (
upstreamDefaultCfgName = "default"
parallelResolverLogger = "parallel_best_resolver"
resolverCount = 2
)
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
@ -133,13 +134,14 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
if len(resolvers) == 1 {
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
return resolvers[0].resolver.Resolve(request)
}
r1, r2 := pickRandom(resolvers)
logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver)
ch := make(chan requestResponse, 2)
ch := make(chan requestResponse, resolverCount)
var collectedErrors []error
@ -152,7 +154,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
go resolve(request, r2, ch)
//nolint: gosimple
for len(collectedErrors) < 2 {
for len(collectedErrors) < resolverCount {
select {
case result := <-ch:
if result.err != nil {
@ -163,6 +165,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
"resolver": r1.resolver,
"answer": util.AnswerToString(result.response.Res.Answer),
}).Debug("using response from resolver")
return result.response, nil
}
}
@ -181,14 +184,17 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst
}
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
const errorWindowInSec = 60
var choices []weightedrand.Choice
for _, res := range in {
var weight float64 = 60
var weight float64 = errorWindowInSec
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
// reduce weight: consider last error time
weight = math.Max(1, weight-(60-time.Since(res.lastErrorTime.Load().(time.Time)).Minutes()))
lastErrorTime := res.lastErrorTime.Load().(time.Time)
weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes()))
}
if exclude != res.resolver {

View File

@ -43,6 +43,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(slowTestUpstream.Close)
@ -71,6 +72,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
time.Sleep(50 * time.Millisecond)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(slowTestUpstream.Close)

View File

@ -14,6 +14,7 @@ const (
cleanUpRunPeriod = 12 * time.Hour
queryLoggingResolverPrefix = "query_logging_resolver"
logChanCap = 1000
defaultFlushPeriod = 30 * time.Second
)
// QueryLoggingResolver writes query information (question, answer, duration, ...)
@ -40,14 +41,15 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
case config.QueryLogTypeCsvClient:
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
case config.QueryLogTypeMysql:
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
case config.QueryLogTypePostgresql:
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
case config.QueryLogTypeConsole:
writer = querylog.NewLoggerWriter()
case config.QueryLogTypeNone:
writer = querylog.NewNoneWriter()
}
return err
},
retry.Attempts(uint(cfg.CreationAttempts)),
@ -130,7 +132,7 @@ func (r *QueryLoggingResolver) writeLog() {
r.writer.Write(logEntry)
halfCap := cap(r.logChan) / 2
halfCap := cap(r.logChan) / 2 // nolint:gomnd
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
if len(r.logChan) > halfCap {

View File

@ -80,6 +80,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
if response == NoResponse {
// Inner resolver had no response, continue with the normal chain
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
return r.next.Resolve(request)
}
@ -128,6 +129,7 @@ func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
for k, v := range r.rewrite {
if strings.HasSuffix(domain, "."+k) {
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
return newDomain, k
}
}

View File

@ -122,6 +122,7 @@ var _ = Describe("RewriterResolver", func() {
mNext.On("Resolve", mock.Anything)
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
return mNextResponse, nil
}
})

View File

@ -23,7 +23,8 @@ import (
)
const (
dnsContentType = "application/dns-message"
dnsContentType = "application/dns-message"
defaultTLSHandshakeTimeout = 5 * time.Second
)
// nolint:gochecknoglobals
@ -66,7 +67,7 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
TLSHandshakeTimeout: 5 * time.Second,
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
},
Timeout: timeout,
},
@ -246,6 +247,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
"response_time_ms": rtt.Milliseconds(),
}).Debugf("received response from upstream")
}
return err
},
retry.Attempts(retryAttempts),

View File

@ -131,6 +131,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
Expect(err).Should(Succeed())
return response
}
})

View File

@ -25,6 +25,10 @@ import (
"github.com/sirupsen/logrus"
)
const (
maxUDPBufferSize = 65535
)
// Server controls the endpoints for DNS and HTTP
type Server struct {
dnsServers []*dns.Server
@ -213,7 +217,7 @@ func createUDPServer(address string) (*dns.Server, error) {
NotifyStartedFunc: func() {
logger().Infof("UDP server is up and running on address %s", address)
},
UDPSize: 65535,
UDPSize: maxUDPBufferSize,
}, nil
}
@ -304,7 +308,9 @@ func (s *Server) printConfiguration() {
}
func toMB(b uint64) uint64 {
return b / 1024 / 1024
const bytesInKB = 1024
return b / bytesInKB / bytesInKB
}
// Start starts the server

View File

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"strings"
"time"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
@ -26,6 +27,7 @@ import (
const (
dohMessageLimit = 512
dnsContentType = "application/dns-message"
corsMaxAge = 5 * time.Minute
)
func (s *Server) registerAPIEndpoints(router *chi.Mux) {
@ -109,6 +111,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
return
}
@ -120,6 +123,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
b, err := resResponse.Res.Pack()
if err != nil {
logAndResponseWithError(err, "can't serialize message: ", rw)
return
}
@ -163,6 +167,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
if err != nil {
logAndResponseWithError(err, "can't read request: ", rw)
return
}
@ -189,6 +194,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
return
}
@ -201,6 +207,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
if err != nil {
logAndResponseWithError(err, "unable to marshal response: ", rw)
return
}
@ -289,7 +296,7 @@ func configureCorsHandler(router *chi.Mux) {
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300,
MaxAge: int(corsMaxAge.Seconds()),
})
router.Use(crs.Handler)
}

View File

@ -44,6 +44,7 @@ var _ = BeforeSuite(func() {
)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(googleMockUpstream.Close)
@ -54,6 +55,7 @@ var _ = BeforeSuite(func() {
)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(fritzboxMockUpstream.Close)
@ -64,6 +66,7 @@ var _ = BeforeSuite(func() {
)
Expect(err).Should(Succeed())
return response
})
DeferCleanup(clientMockUpstream.Close)
@ -153,6 +156,7 @@ var _ = Describe("Running DNS server", func() {
for res != nil {
if t, ok := res.(*resolver.ClientNamesResolver); ok {
t.FlushCache()
break
}
if c, ok := res.(resolver.ChainedResolver); ok {
@ -561,10 +565,9 @@ var _ = Describe("Running DNS server", func() {
Expect(err).Should(Succeed())
errChan := make(chan error, 10)
// start server
go func() {
server.Start(errChan)
}()
go server.Start(errChan)
DeferCleanup(server.Stop)

View File

@ -193,7 +193,8 @@ func Chunks(s string, chunkSize int) []string {
// GenerateCacheKey return cacheKey by query type/domain
func GenerateCacheKey(qType dns.Type, qName string) string {
b := make([]byte, 2+len(qName))
const qTypeLength = 2
b := make([]byte, qTypeLength+len(qName))
binary.BigEndian.PutUint16(b, uint16(qType))
copy(b[2:], strings.ToLower(qName))
@ -224,5 +225,6 @@ func CidrContainsIP(cidr string, ip net.IP) bool {
// ClientNameMatchesGroupName checks if a group with optional wildcards contains a client name
func ClientNameMatchesGroupName(group string, clientName string) bool {
match, _ := filepath.Match(group, clientName)
return match
}