mirror of https://github.com/0xERR0R/blocky.git
update golangci-lint (#510)
* update golangci-lint * enable gomnd linter * enable asciicheck linter * enable bidichk linter * enable durationcheck linter * enable errchkjson linter * enable errorlint linter * enable exhaustive linter * enable gomoddirectives linter * enable gomodguard guard * enable grouper linter * enable grouper and ifshort linters * enable importas linter * enable makezero linter * enable nestif linter * enable nilerr linter * enable nilnil linter * enable nlreturn linter * enable nolintlint linter * enable predeclared linter * enable sqlclosecheck linter * enable tenv linter * enable wastedassign linter
This commit is contained in:
parent
41febafd41
commit
a4b89537db
|
@ -16,5 +16,5 @@ jobs:
|
|||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: v1.43.0
|
||||
version: v1.45.2
|
||||
args: --timeout 5m0s
|
||||
|
|
|
@ -1,41 +1,63 @@
|
|||
linters:
|
||||
enable:
|
||||
- govet
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- unused
|
||||
- gosimple
|
||||
- structcheck
|
||||
- varcheck
|
||||
- ineffassign
|
||||
- deadcode
|
||||
- typecheck
|
||||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- revive
|
||||
- stylecheck
|
||||
- gosec
|
||||
- unconvert
|
||||
- dupl
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gocognit
|
||||
- gofmt
|
||||
- goimports
|
||||
- deadcode
|
||||
- depguard
|
||||
- misspell
|
||||
- lll
|
||||
- unparam
|
||||
- dogsled
|
||||
- dupl
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- errchkjson
|
||||
- errorlint
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- funlen
|
||||
- gochecknoglobals
|
||||
- gochecknoinits
|
||||
- gocognit
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- godox
|
||||
- gofmt
|
||||
- goimports
|
||||
- gomnd
|
||||
- gomoddirectives
|
||||
- gomodguard
|
||||
- gosec
|
||||
- gosimple
|
||||
- govet
|
||||
- grouper
|
||||
- ifshort
|
||||
- importas
|
||||
- ineffassign
|
||||
- lll
|
||||
- makezero
|
||||
- misspell
|
||||
- nakedret
|
||||
- nestif
|
||||
- nilerr
|
||||
- nilnil
|
||||
- nlreturn
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- predeclared
|
||||
- revive
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- structcheck
|
||||
- stylecheck
|
||||
- tenv
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- varcheck
|
||||
- wastedassign
|
||||
- whitespace
|
||||
- wsl
|
||||
- exportloopref
|
||||
disable:
|
||||
- noctx
|
||||
- scopelint
|
||||
|
@ -46,6 +68,11 @@ linters:
|
|||
- unused
|
||||
fast: false
|
||||
|
||||
linters-settings:
|
||||
gomnd:
|
||||
ignored-numbers:
|
||||
- '0666'
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
# Exclude some linters from running on tests files.
|
||||
|
|
2
Makefile
2
Makefile
|
@ -38,7 +38,7 @@ race: ## run tests with race detector
|
|||
go test -race -short ./...
|
||||
|
||||
lint: build ## run golangcli-lint checks
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.43.0
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.45.2
|
||||
$(shell go env GOPATH)/bin/golangci-lint run
|
||||
|
||||
run: build ## Build and run binary
|
||||
|
|
|
@ -30,6 +30,7 @@ func (b *BlockingControlMock) EnableBlocking() {
|
|||
}
|
||||
func (b *BlockingControlMock) DisableBlocking(_ time.Duration, disableGroups []string) error {
|
||||
b.enabled = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -139,8 +139,7 @@ func isExpired(el *element) bool {
|
|||
}
|
||||
|
||||
func calculateRemainTTL(expiresEpoch int64) time.Duration {
|
||||
now := time.Now().UnixMilli()
|
||||
if now < expiresEpoch {
|
||||
if now := time.Now().UnixMilli(); now < expiresEpoch {
|
||||
return time.Duration(expiresEpoch-now) * time.Millisecond
|
||||
}
|
||||
|
||||
|
|
|
@ -99,6 +99,7 @@ func (cache regexCache) Contains(searchString string) bool {
|
|||
for _, regex := range cache {
|
||||
if regex.MatchString(searchString) {
|
||||
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ func refreshList(_ *cobra.Command, _ []string) error {
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
|
||||
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ func query(cmd *cobra.Command, args []string) error {
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
|
||||
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
|
|
20
cmd/root.go
20
cmd/root.go
|
@ -3,7 +3,6 @@ package cmd
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -20,6 +19,12 @@ var (
|
|||
apiPort uint16
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 4000
|
||||
defaultHost = "localhost"
|
||||
defaultConfigPath = "./config.yml"
|
||||
)
|
||||
|
||||
// NewRootCommand creates a new root cli command instance
|
||||
func NewRootCommand() *cobra.Command {
|
||||
c := &cobra.Command{
|
||||
|
@ -35,9 +40,9 @@ Complete documentation is available at https://github.com/0xERR0R/blocky`,
|
|||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
c.PersistentFlags().StringVarP(&configPath, "config", "c", "./config.yml", "path to config file")
|
||||
c.PersistentFlags().StringVar(&apiHost, "apiHost", "localhost", "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
||||
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", 4000, "port of blocky (API). Default overridden by config and CLI.")
|
||||
c.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "path to config file")
|
||||
c.PersistentFlags().StringVar(&apiHost, "apiHost", defaultHost, "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
||||
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", defaultPort, "port of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
||||
|
||||
c.AddCommand(newRefreshCommand(),
|
||||
NewQueryCommand(),
|
||||
|
@ -73,15 +78,14 @@ func initConfig() {
|
|||
|
||||
apiHost = strings.Join(split[:lastIdx], ":")
|
||||
|
||||
var p uint64
|
||||
p, err := strconv.ParseUint(strings.TrimSpace(split[lastIdx]), 10, 16)
|
||||
|
||||
port, err := config.ConvertPort(split[lastIdx])
|
||||
if err != nil {
|
||||
util.FatalOnError("can't convert port to number (1 - 65535)", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
apiPort = uint16(p)
|
||||
apiPort = port
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,8 @@ func startServer(_ *cobra.Command, _ []string) error {
|
|||
return fmt.Errorf("can't start server: %w", err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 10)
|
||||
const errChanSize = 10
|
||||
errChan := make(chan error, errChanSize)
|
||||
|
||||
srv.Start(errChan)
|
||||
|
||||
|
|
|
@ -23,6 +23,12 @@ import (
|
|||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
udpPort = 53
|
||||
tlsPort = 853
|
||||
httpsPort = 443
|
||||
)
|
||||
|
||||
// NetProtocol resolver protocol ENUM(
|
||||
// tcp+udp // TCP and UDP protocols
|
||||
// tcp-tls // TCP-TLS protocol
|
||||
|
@ -60,6 +66,7 @@ func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
|
|||
|
||||
func (s QTypeSet) Contains(qType dns.Type) bool {
|
||||
_, found := s[QType(qType)]
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
|
@ -79,9 +86,9 @@ func (c *Duration) String() string {
|
|||
|
||||
// nolint:gochecknoglobals
|
||||
var netDefaultPort = map[NetProtocol]uint16{
|
||||
NetProtocolTcpUdp: 53,
|
||||
NetProtocolTcpTls: 853,
|
||||
NetProtocolHttps: 443,
|
||||
NetProtocolTcpUdp: udpPort,
|
||||
NetProtocolTcpTls: tlsPort,
|
||||
NetProtocolHttps: httpsPort,
|
||||
}
|
||||
|
||||
// Upstream is the definition of external DNS server
|
||||
|
@ -252,12 +259,14 @@ func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
|||
// duration is defined as number without unit
|
||||
// use minutes to ensure back compatibility
|
||||
*c = Duration(time.Duration(minutes) * time.Minute)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(input)
|
||||
if err == nil {
|
||||
*c = Duration(duration)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -320,15 +329,15 @@ func ParseUpstream(upstream string) (Upstream, error) {
|
|||
|
||||
// string contains host:port
|
||||
if err == nil {
|
||||
var p uint64
|
||||
p, err = strconv.ParseUint(strings.TrimSpace(portString), 10, 16)
|
||||
p, err := ConvertPort(portString)
|
||||
|
||||
if err != nil {
|
||||
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
|
||||
|
||||
return Upstream{}, err
|
||||
}
|
||||
|
||||
port = uint16(p)
|
||||
port = p
|
||||
} else {
|
||||
// only host, use default port
|
||||
host = upstream
|
||||
|
@ -340,9 +349,7 @@ func ParseUpstream(upstream string) (Upstream, error) {
|
|||
}
|
||||
|
||||
// validate hostname or ip
|
||||
ip := net.ParseIP(host)
|
||||
|
||||
if ip == nil {
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
// is not IP
|
||||
if !validDomain.MatchString(host) {
|
||||
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
|
||||
|
@ -544,6 +551,7 @@ func LoadConfig(path string, mandatory bool) (*Config, error) {
|
|||
// config file does not exist
|
||||
// return config with default values
|
||||
config = &cfg
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
@ -596,3 +604,20 @@ func validateConfig(cfg *Config) (err error) {
|
|||
func GetConfig() *Config {
|
||||
return config
|
||||
}
|
||||
|
||||
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||
func ConvertPort(in string) (uint16, error) {
|
||||
const (
|
||||
base = 10
|
||||
bitSize = 16
|
||||
)
|
||||
|
||||
var p uint64
|
||||
p, err := strconv.ParseUint(strings.TrimSpace(in), base, bitSize)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint16(p), nil
|
||||
}
|
||||
|
|
|
@ -269,6 +269,7 @@ bootstrapDns:
|
|||
u := &Upstream{}
|
||||
err := u.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "tcp+udp:1.2.3.4"
|
||||
|
||||
return nil
|
||||
|
||||
})
|
||||
|
@ -292,6 +293,7 @@ bootstrapDns:
|
|||
l := &ListenConfig{}
|
||||
err := l.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "55,:56"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -311,6 +313,7 @@ bootstrapDns:
|
|||
d := Duration(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "1m20s"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -321,6 +324,7 @@ bootstrapDns:
|
|||
d := Duration(0)
|
||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "wrong"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
@ -342,6 +346,7 @@ bootstrapDns:
|
|||
c := &ConditionalUpstreamMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -364,6 +369,7 @@ bootstrapDns:
|
|||
c := &CustomDNSMapping{}
|
||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -385,6 +391,7 @@ bootstrapDns:
|
|||
t := QType(0)
|
||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "AAAA"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(Succeed())
|
||||
|
@ -395,6 +402,7 @@ bootstrapDns:
|
|||
t := QType(0)
|
||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||
*i.(*string) = "WRONGTYPE"
|
||||
|
||||
return nil
|
||||
})
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
|
|
@ -103,10 +103,10 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
|||
func() error {
|
||||
var resp *http.Response
|
||||
var httpErr error
|
||||
//nolint:bodyclose
|
||||
if resp, httpErr = client.Get(link); httpErr == nil {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
body = resp.Body
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -118,6 +118,7 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
|||
if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
|
||||
return &TransientError{inner: netErr}
|
||||
}
|
||||
|
||||
return httpErr
|
||||
},
|
||||
retry.Attempts(d.downloadAttempts),
|
||||
|
|
|
@ -163,6 +163,7 @@ Loop:
|
|||
}
|
||||
default:
|
||||
close(c)
|
||||
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
|
@ -307,9 +308,7 @@ func processLine(line string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Fields(line)
|
||||
|
||||
if len(parts) > 0 {
|
||||
if parts := strings.Fields(line); len(parts) > 0 {
|
||||
host := parts[len(parts)-1]
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
|
|
|
@ -152,6 +152,7 @@ var _ = Describe("ListCache", func() {
|
|||
By("List couldn't be loaded due to 404 err", func() {
|
||||
Eventually(func() bool {
|
||||
found, _ := sut.Match("blocked1.com", []string{"gr1"})
|
||||
|
||||
return found
|
||||
}, "1s").Should(BeFalse())
|
||||
})
|
||||
|
@ -313,5 +314,6 @@ type MockDownloader struct {
|
|||
|
||||
func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||
fn := <-m.data
|
||||
|
||||
return fn()
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
result := writer.db.Find(&logEntry{})
|
||||
|
||||
result.Count(&res)
|
||||
|
||||
return res
|
||||
}, "1s").Should(BeNumerically("==", 1))
|
||||
})
|
||||
|
@ -91,6 +92,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
result := writer.db.Find(&logEntry{})
|
||||
|
||||
result.Count(&res)
|
||||
|
||||
return res
|
||||
}, "1s").Should(BeNumerically("==", 2))
|
||||
|
||||
|
@ -102,6 +104,7 @@ var _ = Describe("DatabaseWriter", func() {
|
|||
result := writer.db.Find(&logEntry{})
|
||||
|
||||
result.Count(&res)
|
||||
|
||||
return res
|
||||
}, "1s").Should(BeNumerically("==", 1))
|
||||
})
|
||||
|
|
|
@ -72,6 +72,8 @@ func (d *FileWriter) Write(entry *LogEntry) {
|
|||
|
||||
// CleanUp deletes old log files
|
||||
func (d *FileWriter) CleanUp() {
|
||||
const hoursPerDay = 24
|
||||
|
||||
logger := log.PrefixedLog(loggerPrefixFileWriter)
|
||||
|
||||
logger.Trace("starting clean up")
|
||||
|
@ -85,7 +87,7 @@ func (d *FileWriter) CleanUp() {
|
|||
if strings.HasSuffix(f.Name(), ".log") && len(f.Name()) > 10 {
|
||||
t, err := time.Parse("2006-01-02", f.Name()[:10])
|
||||
if err == nil {
|
||||
differenceDays := uint64(time.Since(t).Hours() / 24)
|
||||
differenceDays := uint64(time.Since(t).Hours() / hoursPerDay)
|
||||
if d.logRetentionDays > 0 && differenceDays > d.logRetentionDays {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"file": f.Name(),
|
||||
|
|
|
@ -69,7 +69,7 @@ type Client struct {
|
|||
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||
// disable redis if no address is provided
|
||||
if cfg == nil || len(cfg.Address) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil // nolint:nilnil
|
||||
}
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
|
@ -164,7 +164,12 @@ func (c *Client) startup() error {
|
|||
select {
|
||||
// received message from subscription
|
||||
case msg := <-ps.Channel():
|
||||
err = c.processReceivedMessage(msg)
|
||||
c.l.Debug("Received message: ", msg)
|
||||
|
||||
if msg != nil && len(msg.Payload) > 0 {
|
||||
// message is not empty
|
||||
err = c.processReceivedMessage(msg)
|
||||
}
|
||||
// publish message from buffer
|
||||
case s := <-c.sendBuffer:
|
||||
c.publishMessageFromBuffer(s)
|
||||
|
@ -201,34 +206,30 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
|||
}
|
||||
|
||||
func (c *Client) processReceivedMessage(msg *redis.Message) (err error) {
|
||||
c.l.Debug("Received message: ", msg)
|
||||
// message is not empty
|
||||
if msg != nil && len(msg.Payload) > 0 {
|
||||
var rm redisMessage
|
||||
var rm redisMessage
|
||||
|
||||
err = json.Unmarshal([]byte(msg.Payload), &rm)
|
||||
if err == nil {
|
||||
// message was sent from a different blocky instance
|
||||
if !bytes.Equal(rm.Client, c.id) {
|
||||
switch rm.Type {
|
||||
case messageTypeCache:
|
||||
var cm *CacheMessage
|
||||
err = json.Unmarshal([]byte(msg.Payload), &rm)
|
||||
if err == nil {
|
||||
// message was sent from a different blocky instance
|
||||
if !bytes.Equal(rm.Client, c.id) {
|
||||
switch rm.Type {
|
||||
case messageTypeCache:
|
||||
var cm *CacheMessage
|
||||
|
||||
cm, err = convertMessage(&rm, 0)
|
||||
if err == nil {
|
||||
c.CacheChannel <- cm
|
||||
}
|
||||
case messageTypeEnable:
|
||||
err = c.processEnabledMessage(&rm)
|
||||
default:
|
||||
c.l.Warn("Unknown message type: ", rm.Type)
|
||||
cm, err = convertMessage(&rm, 0)
|
||||
if err == nil {
|
||||
c.CacheChannel <- cm
|
||||
}
|
||||
case messageTypeEnable:
|
||||
err = c.processEnabledMessage(&rm)
|
||||
default:
|
||||
c.l.Warn("Unknown message type: ", rm.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.l.Error("Processing error: ", err)
|
||||
}
|
||||
if err != nil {
|
||||
c.l.Error("Processing error: ", err)
|
||||
}
|
||||
|
||||
return err
|
||||
|
|
|
@ -25,6 +25,8 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultBlockingCleanUpInterval = 5 * time.Second
|
||||
|
||||
func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
|
||||
cfgBlockType := cfg.BlockType
|
||||
|
||||
|
@ -362,7 +364,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
|
|||
}
|
||||
|
||||
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||
request *model.Request, logger *logrus.Entry) (*model.Response, error) {
|
||||
request *model.Request, logger *logrus.Entry) (bool, *model.Response, error) {
|
||||
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
||||
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)
|
||||
|
||||
|
@ -372,19 +374,26 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
|||
|
||||
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted {
|
||||
logger.WithField("group", group).Debugf("domain is whitelisted")
|
||||
return r.next.Resolve(request)
|
||||
|
||||
resp, err := r.next.Resolve(request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
||||
if whitelistOnlyAllowed {
|
||||
return r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
|
||||
resp, err := r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
||||
if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked {
|
||||
return r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
|
||||
resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
||||
|
@ -393,8 +402,8 @@ func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
groupsToCheck := r.groupsToCheckForClient(request)
|
||||
|
||||
if len(groupsToCheck) > 0 {
|
||||
resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
||||
if resp != nil || err != nil {
|
||||
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
||||
if handled {
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
@ -541,6 +550,7 @@ func (b zeroIPBlockHandler) handleBlock(question dns.Question, response *dns.Msg
|
|||
zeroIP = net.IPv4zero
|
||||
default:
|
||||
response.Rcode = dns.RcodeNameError
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -603,7 +613,7 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
|||
identifiers = append(identifiers, identifier)
|
||||
}
|
||||
|
||||
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(5*time.Second),
|
||||
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(defaultBlockingCleanUpInterval),
|
||||
expirationcache.WithOnExpiredFn(func(key string) (val interface{}, ttl time.Duration) {
|
||||
return r.queryForFQIdentifierIPs(key)
|
||||
}))
|
||||
|
@ -618,5 +628,6 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
|||
|
||||
func isFQDN(in string) bool {
|
||||
s := strings.Trim(in, ".")
|
||||
|
||||
return strings.Contains(s, ".")
|
||||
}
|
||||
|
|
|
@ -130,8 +130,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
|||
m.AnswerFn = func(t uint16, qName string) *dns.Msg {
|
||||
if t == dns.TypeA && qName == "full.qualified.com." {
|
||||
a, _ := util.NewMsgWithAnswer(qName, 60*60, dns.Type(dns.TypeA), "192.168.178.39")
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Bus().Publish(ApplicationStarted, "")
|
||||
|
|
|
@ -141,6 +141,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
|||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("dial error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -161,6 +162,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
|||
ips, err := b.resolve(host, qTypes)
|
||||
if err != nil {
|
||||
log.Errorf("resolve error: %s", err)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -170,18 +172,20 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
|||
|
||||
// Use the standard dialer to actually connect
|
||||
addrWithIP := net.JoinHostPort(ip.String(), port)
|
||||
|
||||
return dialer.DialContext(ctx, network, addrWithIP)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||
ips = make([]net.IP, 0, 2)
|
||||
ips = make([]net.IP, 0, len(qTypes))
|
||||
|
||||
for _, qType := range qTypes {
|
||||
qIPs, qErr := b.resolveType(hostname, qType)
|
||||
if qErr != nil {
|
||||
err = multierror.Append(err, qErr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -196,8 +200,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
|
|||
}
|
||||
|
||||
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
|
||||
ip := net.ParseIP(hostname)
|
||||
if ip != nil {
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return []net.IP{ip}, nil
|
||||
}
|
||||
|
||||
|
@ -236,14 +239,15 @@ type IPSet struct {
|
|||
|
||||
func (ips *IPSet) Current() net.IP {
|
||||
idx := atomic.LoadUint32(&ips.index)
|
||||
|
||||
return ips.values[idx]
|
||||
}
|
||||
|
||||
func (ips *IPSet) Next() {
|
||||
old := ips.index
|
||||
new := uint32(int(ips.index+1) % len(ips.values))
|
||||
oldIP := ips.index
|
||||
newIP := uint32(int(ips.index+1) % len(ips.values))
|
||||
|
||||
// We don't care about the result: if the call fails,
|
||||
// it means the value was incremented by another goroutine
|
||||
_ = atomic.CompareAndSwapUint32(&ips.index, old, new)
|
||||
_ = atomic.CompareAndSwapUint32(&ips.index, oldIP, newIP)
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
usedSystemResolver <- true
|
||||
|
||||
return nil, errors.New("don't actually do anything")
|
||||
},
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultCachingCleanUpInterval = 5 * time.Second
|
||||
|
||||
// CachingResolver caches answers from dns queries with their TTL time,
|
||||
// to avoid external resolver calls for recurrent queries
|
||||
type CachingResolver struct {
|
||||
|
@ -58,7 +60,7 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedRe
|
|||
}
|
||||
|
||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||
cleanupOption := expirationcache.WithCleanUpInterval(5 * time.Second)
|
||||
cleanupOption := expirationcache.WithCleanUpInterval(defaultCachingCleanUpInterval)
|
||||
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
|
||||
|
||||
if cfg.Prefetching {
|
||||
|
@ -91,6 +93,7 @@ func setupRedisCacheSubscriber(c *CachingResolver) {
|
|||
// check if domain was queried > threshold in the time window
|
||||
func (r *CachingResolver) isPrefetchingDomain(cacheKey string) bool {
|
||||
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
||||
|
||||
return cnt != nil && cnt.(int) > r.prefetchThreshold
|
||||
}
|
||||
|
||||
|
@ -108,6 +111,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
|
|||
if err == nil {
|
||||
if response.Res.Rcode == dns.RcodeSuccess {
|
||||
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
|
||||
|
||||
return cacheValue{response.Res.Answer, true}, time.Duration(r.adjustTTLs(response.Res.Answer)) * time.Second
|
||||
}
|
||||
} else {
|
||||
|
@ -122,6 +126,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
|
|||
func (r *CachingResolver) Configuration() (result []string) {
|
||||
if r.maxCacheTimeSec < 0 {
|
||||
result = []string{"deactivated"}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -146,12 +151,12 @@ func (r *CachingResolver) Configuration() (result []string) {
|
|||
|
||||
// Resolve checks if the current query result is already in the cache and returns it
|
||||
// or delegates to the next resolver
|
||||
//nolint:gocognit,funlen
|
||||
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||
logger := withPrefix(request.Log, "caching_resolver")
|
||||
|
||||
if r.maxCacheTimeSec < 0 {
|
||||
logger.Debug("skip cache")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
|
|
|
@ -547,6 +547,7 @@ var _ = Describe("CachingResolver", func() {
|
|||
|
||||
Eventually(func() error {
|
||||
resp, err = sut.Resolve(request)
|
||||
|
||||
return err
|
||||
}, "50ms").Should(Succeed())
|
||||
})
|
||||
|
|
|
@ -105,6 +105,21 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
|||
return names
|
||||
}
|
||||
|
||||
func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNames []string) {
|
||||
for _, answer := range answer {
|
||||
if t, ok := answer.(*dns.PTR); ok {
|
||||
hostName := strings.TrimSuffix(t.Ptr, ".")
|
||||
clientNames = append(clientNames, hostName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(clientNames) == 0 {
|
||||
clientNames = []string{fallbackIP.String()}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
||||
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
||||
// try client mapping first
|
||||
|
@ -114,49 +129,40 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
|||
return
|
||||
}
|
||||
|
||||
if r.externalResolver != nil {
|
||||
reverse, _ := dns.ReverseAddr(ip.String())
|
||||
|
||||
resp, err := r.externalResolver.Resolve(&model.Request{
|
||||
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
||||
Log: logger,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("can't resolve client name: ", err)
|
||||
return []string{ip.String()}
|
||||
}
|
||||
|
||||
var clientNames []string
|
||||
|
||||
for _, answer := range resp.Res.Answer {
|
||||
if t, ok := answer.(*dns.PTR); ok {
|
||||
hostName := strings.TrimSuffix(t.Ptr, ".")
|
||||
clientNames = append(clientNames, hostName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(clientNames) == 0 {
|
||||
clientNames = []string{ip.String()}
|
||||
}
|
||||
|
||||
// optional: if singleNameOrder is set, use only one name in the defined order
|
||||
if len(r.singleNameOrder) > 0 {
|
||||
for _, i := range r.singleNameOrder {
|
||||
if i > 0 && int(i) <= len(clientNames) {
|
||||
result = []string{clientNames[i-1]}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result = clientNames
|
||||
}
|
||||
|
||||
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
|
||||
} else {
|
||||
result = []string{ip.String()}
|
||||
if r.externalResolver == nil {
|
||||
return []string{ip.String()}
|
||||
}
|
||||
|
||||
reverse, _ := dns.ReverseAddr(ip.String())
|
||||
|
||||
resp, err := r.externalResolver.Resolve(&model.Request{
|
||||
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
||||
Log: logger,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("can't resolve client name: ", err)
|
||||
|
||||
return []string{ip.String()}
|
||||
}
|
||||
|
||||
clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip)
|
||||
|
||||
// optional: if singleNameOrder is set, use only one name in the defined order
|
||||
if len(r.singleNameOrder) > 0 {
|
||||
for _, i := range r.singleNameOrder {
|
||||
if i > 0 && int(i) <= len(clientNames) {
|
||||
result = []string{clientNames[i-1]}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result = clientNames
|
||||
}
|
||||
|
||||
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
|
|
@ -51,31 +51,42 @@ func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
|
||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||
domain := domainFromQuestion
|
||||
|
||||
if strings.Contains(domainFromQuestion, ".") {
|
||||
// try with domain with and without sub-domains
|
||||
for len(domain) > 0 {
|
||||
if resolver, found := r.mapping[domain]; found {
|
||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
||||
if i := strings.Index(domain, "."); i >= 0 {
|
||||
domain = domain[i+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if resolver, found := r.mapping["."]; found {
|
||||
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
|
||||
return true, resp, err
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Resolve uses the conditional resolver to resolve the query
|
||||
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := withPrefix(request.Log, "conditional_resolver")
|
||||
|
||||
if len(r.mapping) > 0 {
|
||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||
domain := domainFromQuestion
|
||||
|
||||
if !strings.Contains(domainFromQuestion, ".") {
|
||||
if resolver, found := r.mapping["."]; found {
|
||||
return r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
}
|
||||
} else {
|
||||
// try with domain with and without sub-domains
|
||||
for len(domain) > 0 {
|
||||
if resolver, found := r.mapping[domain]; found {
|
||||
return r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||
}
|
||||
|
||||
if i := strings.Index(domain, "."); i >= 0 {
|
||||
domain = domain[i+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
resolved, resp, err := r.processRequest(request)
|
||||
if resolved {
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -88,6 +88,54 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response {
|
||||
logger := withPrefix(request.Log, "custom_dns_resolver")
|
||||
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(request.Req)
|
||||
|
||||
question := request.Req.Question[0]
|
||||
domain := util.ExtractDomain(question)
|
||||
|
||||
for len(domain) > 0 {
|
||||
ips, found := r.mapping[domain]
|
||||
if found {
|
||||
for _, ip := range ips {
|
||||
if isSupportedType(ip, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(response.Answer) > 0 {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(response.Answer),
|
||||
"domain": domain,
|
||||
}).Debugf("returning custom dns entry")
|
||||
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||
}
|
||||
|
||||
// Mapping exists for this domain, but for another type
|
||||
if !r.filterUnmappedTypes {
|
||||
// go to next resolver
|
||||
break
|
||||
}
|
||||
|
||||
// return NOERROR with empty result
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||
}
|
||||
|
||||
if i := strings.Index(domain, "."); i >= 0 {
|
||||
domain = domain[i+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve uses internal mapping to resolve the query
|
||||
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||
logger := withPrefix(request.Log, "custom_dns_resolver")
|
||||
|
@ -98,46 +146,9 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
}
|
||||
|
||||
if len(r.mapping) > 0 {
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(request.Req)
|
||||
|
||||
question := request.Req.Question[0]
|
||||
domain := util.ExtractDomain(question)
|
||||
|
||||
for len(domain) > 0 {
|
||||
ips, found := r.mapping[domain]
|
||||
if found {
|
||||
for _, ip := range ips {
|
||||
if isSupportedType(ip, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(response.Answer) > 0 {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"answer": util.AnswerToString(response.Answer),
|
||||
"domain": domain,
|
||||
}).Debugf("returning custom dns entry")
|
||||
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
||||
}
|
||||
|
||||
// Mapping exists for this domain, but for another type
|
||||
if !r.filterUnmappedTypes {
|
||||
// go to next resolver
|
||||
break
|
||||
}
|
||||
|
||||
// return NOERROR with empty result
|
||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(domain, "."); i >= 0 {
|
||||
domain = domain[i+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
resp := r.processRequest(request)
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -76,21 +76,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
domain := util.ExtractDomain(question)
|
||||
|
||||
for _, host := range r.hosts {
|
||||
if host.Hostname == domain {
|
||||
if isSupportedType(host.IP, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, alias := range host.Aliases {
|
||||
if alias == domain {
|
||||
if isSupportedType(host.IP, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
}
|
||||
}
|
||||
response.Answer = append(response.Answer, r.processHostEntry(host, domain, question)...)
|
||||
}
|
||||
|
||||
if len(response.Answer) > 0 {
|
||||
|
@ -108,6 +94,26 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
|||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) processHostEntry(host host, domain string, question dns.Question) (result []dns.RR) {
|
||||
if host.Hostname == domain {
|
||||
if isSupportedType(host.IP, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||
result = append(result, rr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, alias := range host.Aliases {
|
||||
if alias == domain {
|
||||
if isSupportedType(host.IP, question) {
|
||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||
result = append(result, rr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *HostsFileResolver) Configuration() (result []string) {
|
||||
if r.HostsFilePath != "" && len(r.hosts) != 0 {
|
||||
result = append(result, fmt.Sprintf("hosts file path: %s", r.HostsFilePath))
|
||||
|
@ -127,9 +133,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig) ChainedResolver {
|
|||
refreshPeriod: time.Duration(cfg.RefreshPeriod),
|
||||
}
|
||||
|
||||
err := r.parseHostsFile()
|
||||
|
||||
if err != nil {
|
||||
if err := r.parseHostsFile(); err != nil {
|
||||
logger := logger(hostsFileResolverLogger)
|
||||
logger.Warnf("cannot parse hosts file: %s, hosts file resolving is disabled", r.HostsFilePath)
|
||||
r.HostsFilePath = ""
|
||||
|
@ -147,6 +151,8 @@ type host struct {
|
|||
}
|
||||
|
||||
func (r *HostsFileResolver) parseHostsFile() error {
|
||||
const minColumnCount = 2
|
||||
|
||||
if r.HostsFilePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
@ -177,7 +183,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
|
|||
fields = strings.Fields(trimmed[:end])
|
||||
}
|
||||
|
||||
if len(fields) < 2 {
|
||||
if len(fields) < minColumnCount {
|
||||
// Skip invalid entry
|
||||
continue
|
||||
}
|
||||
|
@ -191,7 +197,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
|
|||
h.IP = net.ParseIP(fields[0])
|
||||
h.Hostname = fields[1]
|
||||
|
||||
if len(fields) > 2 {
|
||||
if len(fields) > minColumnCount {
|
||||
for i := 2; i < len(fields); i++ {
|
||||
h.Aliases = append(h.Aliases, fields[i])
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -29,6 +28,7 @@ type MockResolver struct {
|
|||
|
||||
func (r *MockResolver) Configuration() []string {
|
||||
args := r.Called()
|
||||
|
||||
return args.Get(0).([]string)
|
||||
}
|
||||
|
||||
|
@ -182,6 +182,7 @@ func (t *MockUDPUpstreamServer) WithAnswerError(errorCode int) *MockUDPUpstreamS
|
|||
|
||||
func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response *dns.Msg)) *MockUDPUpstreamServer {
|
||||
t.answerFn = fn
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
|
@ -195,26 +196,33 @@ func (t *MockUDPUpstreamServer) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
||||
func CreateConnection() *net.UDPConn {
|
||||
a, err := net.ResolveUDPAddr("udp4", ":0")
|
||||
util.FatalOnError("can't resolve address: ", err)
|
||||
|
||||
ln, err := net.ListenUDP("udp4", a)
|
||||
util.FatalOnError("can't create connection: ", err)
|
||||
|
||||
t.ln = ln
|
||||
return ln
|
||||
}
|
||||
|
||||
func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
||||
ln := CreateConnection()
|
||||
|
||||
ladr := ln.LocalAddr().String()
|
||||
host := strings.Split(ladr, ":")[0]
|
||||
p, err := strconv.ParseUint(strings.Split(ladr, ":")[1], 10, 16)
|
||||
p, err := config.ConvertPort(strings.Split(ladr, ":")[1])
|
||||
|
||||
util.FatalOnError("can't convert port: ", err)
|
||||
|
||||
port := uint16(p)
|
||||
port := p
|
||||
t.ln = ln
|
||||
|
||||
go func() {
|
||||
const bufferSize = 1024
|
||||
|
||||
for {
|
||||
buffer := make([]byte, 1024)
|
||||
buffer := make([]byte, bufferSize)
|
||||
n, addr, err := ln.ReadFromUDP(buffer)
|
||||
|
||||
if err != nil {
|
||||
|
@ -233,6 +241,7 @@ func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
|||
// nil should indicate an error
|
||||
if response == nil {
|
||||
_, _ = ln.WriteToUDP([]byte("dummy"), addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
const (
|
||||
upstreamDefaultCfgName = "default"
|
||||
parallelResolverLogger = "parallel_best_resolver"
|
||||
resolverCount = 2
|
||||
)
|
||||
|
||||
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
|
||||
|
@ -133,13 +134,14 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
|
||||
if len(resolvers) == 1 {
|
||||
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
|
||||
|
||||
return resolvers[0].resolver.Resolve(request)
|
||||
}
|
||||
|
||||
r1, r2 := pickRandom(resolvers)
|
||||
logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver)
|
||||
|
||||
ch := make(chan requestResponse, 2)
|
||||
ch := make(chan requestResponse, resolverCount)
|
||||
|
||||
var collectedErrors []error
|
||||
|
||||
|
@ -152,7 +154,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
go resolve(request, r2, ch)
|
||||
|
||||
//nolint: gosimple
|
||||
for len(collectedErrors) < 2 {
|
||||
for len(collectedErrors) < resolverCount {
|
||||
select {
|
||||
case result := <-ch:
|
||||
if result.err != nil {
|
||||
|
@ -163,6 +165,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
|||
"resolver": r1.resolver,
|
||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||
}).Debug("using response from resolver")
|
||||
|
||||
return result.response, nil
|
||||
}
|
||||
}
|
||||
|
@ -181,14 +184,17 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst
|
|||
}
|
||||
|
||||
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
|
||||
const errorWindowInSec = 60
|
||||
|
||||
var choices []weightedrand.Choice
|
||||
|
||||
for _, res := range in {
|
||||
var weight float64 = 60
|
||||
var weight float64 = errorWindowInSec
|
||||
|
||||
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
|
||||
// reduce weight: consider last error time
|
||||
weight = math.Max(1, weight-(60-time.Since(res.lastErrorTime.Load().(time.Time)).Minutes()))
|
||||
lastErrorTime := res.lastErrorTime.Load().(time.Time)
|
||||
weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes()))
|
||||
}
|
||||
|
||||
if exclude != res.resolver {
|
||||
|
|
|
@ -43,6 +43,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
|
@ -71,6 +72,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
|||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(slowTestUpstream.Close)
|
||||
|
|
|
@ -14,6 +14,7 @@ const (
|
|||
cleanUpRunPeriod = 12 * time.Hour
|
||||
queryLoggingResolverPrefix = "query_logging_resolver"
|
||||
logChanCap = 1000
|
||||
defaultFlushPeriod = 30 * time.Second
|
||||
)
|
||||
|
||||
// QueryLoggingResolver writes query information (question, answer, duration, ...)
|
||||
|
@ -40,14 +41,15 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
|||
case config.QueryLogTypeCsvClient:
|
||||
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
|
||||
case config.QueryLogTypeMysql:
|
||||
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
|
||||
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
|
||||
case config.QueryLogTypePostgresql:
|
||||
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
|
||||
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
|
||||
case config.QueryLogTypeConsole:
|
||||
writer = querylog.NewLoggerWriter()
|
||||
case config.QueryLogTypeNone:
|
||||
writer = querylog.NewNoneWriter()
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
retry.Attempts(uint(cfg.CreationAttempts)),
|
||||
|
@ -130,7 +132,7 @@ func (r *QueryLoggingResolver) writeLog() {
|
|||
|
||||
r.writer.Write(logEntry)
|
||||
|
||||
halfCap := cap(r.logChan) / 2
|
||||
halfCap := cap(r.logChan) / 2 // nolint:gomnd
|
||||
|
||||
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
||||
if len(r.logChan) > halfCap {
|
||||
|
|
|
@ -80,6 +80,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
|||
if response == NoResponse {
|
||||
// Inner resolver had no response, continue with the normal chain
|
||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||
|
||||
return r.next.Resolve(request)
|
||||
}
|
||||
|
||||
|
@ -128,6 +129,7 @@ func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
|
|||
for k, v := range r.rewrite {
|
||||
if strings.HasSuffix(domain, "."+k) {
|
||||
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
|
||||
|
||||
return newDomain, k
|
||||
}
|
||||
}
|
||||
|
|
|
@ -122,6 +122,7 @@ var _ = Describe("RewriterResolver", func() {
|
|||
mNext.On("Resolve", mock.Anything)
|
||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
|
||||
return mNextResponse, nil
|
||||
}
|
||||
})
|
||||
|
|
|
@ -23,7 +23,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
dnsContentType = "application/dns-message"
|
||||
dnsContentType = "application/dns-message"
|
||||
defaultTLSHandshakeTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// nolint:gochecknoglobals
|
||||
|
@ -66,7 +67,7 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
|||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tlsConfig,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
||||
},
|
||||
Timeout: timeout,
|
||||
},
|
||||
|
@ -246,6 +247,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
|||
"response_time_ms": rtt.Milliseconds(),
|
||||
}).Debugf("received response from upstream")
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
retry.Attempts(retryAttempts),
|
||||
|
|
|
@ -131,6 +131,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
}
|
||||
})
|
||||
|
|
|
@ -25,6 +25,10 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
maxUDPBufferSize = 65535
|
||||
)
|
||||
|
||||
// Server controls the endpoints for DNS and HTTP
|
||||
type Server struct {
|
||||
dnsServers []*dns.Server
|
||||
|
@ -213,7 +217,7 @@ func createUDPServer(address string) (*dns.Server, error) {
|
|||
NotifyStartedFunc: func() {
|
||||
logger().Infof("UDP server is up and running on address %s", address)
|
||||
},
|
||||
UDPSize: 65535,
|
||||
UDPSize: maxUDPBufferSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -304,7 +308,9 @@ func (s *Server) printConfiguration() {
|
|||
}
|
||||
|
||||
func toMB(b uint64) uint64 {
|
||||
return b / 1024 / 1024
|
||||
const bytesInKB = 1024
|
||||
|
||||
return b / bytesInKB / bytesInKB
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/api"
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -26,6 +27,7 @@ import (
|
|||
const (
|
||||
dohMessageLimit = 512
|
||||
dnsContentType = "application/dns-message"
|
||||
corsMaxAge = 5 * time.Minute
|
||||
)
|
||||
|
||||
func (s *Server) registerAPIEndpoints(router *chi.Mux) {
|
||||
|
@ -109,6 +111,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
|||
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -120,6 +123,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
|||
b, err := resResponse.Res.Pack()
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "can't serialize message: ", rw)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -163,6 +167,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "can't read request: ", rw)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -189,6 +194,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -201,6 +207,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "unable to marshal response: ", rw)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -289,7 +296,7 @@ func configureCorsHandler(router *chi.Mux) {
|
|||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300,
|
||||
MaxAge: int(corsMaxAge.Seconds()),
|
||||
})
|
||||
router.Use(crs.Handler)
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ var _ = BeforeSuite(func() {
|
|||
)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(googleMockUpstream.Close)
|
||||
|
@ -54,6 +55,7 @@ var _ = BeforeSuite(func() {
|
|||
)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(fritzboxMockUpstream.Close)
|
||||
|
@ -64,6 +66,7 @@ var _ = BeforeSuite(func() {
|
|||
)
|
||||
|
||||
Expect(err).Should(Succeed())
|
||||
|
||||
return response
|
||||
})
|
||||
DeferCleanup(clientMockUpstream.Close)
|
||||
|
@ -153,6 +156,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
for res != nil {
|
||||
if t, ok := res.(*resolver.ClientNamesResolver); ok {
|
||||
t.FlushCache()
|
||||
|
||||
break
|
||||
}
|
||||
if c, ok := res.(resolver.ChainedResolver); ok {
|
||||
|
@ -561,10 +565,9 @@ var _ = Describe("Running DNS server", func() {
|
|||
Expect(err).Should(Succeed())
|
||||
|
||||
errChan := make(chan error, 10)
|
||||
|
||||
// start server
|
||||
go func() {
|
||||
server.Start(errChan)
|
||||
}()
|
||||
go server.Start(errChan)
|
||||
|
||||
DeferCleanup(server.Stop)
|
||||
|
||||
|
|
|
@ -193,7 +193,8 @@ func Chunks(s string, chunkSize int) []string {
|
|||
|
||||
// GenerateCacheKey return cacheKey by query type/domain
|
||||
func GenerateCacheKey(qType dns.Type, qName string) string {
|
||||
b := make([]byte, 2+len(qName))
|
||||
const qTypeLength = 2
|
||||
b := make([]byte, qTypeLength+len(qName))
|
||||
|
||||
binary.BigEndian.PutUint16(b, uint16(qType))
|
||||
copy(b[2:], strings.ToLower(qName))
|
||||
|
@ -224,5 +225,6 @@ func CidrContainsIP(cidr string, ip net.IP) bool {
|
|||
// ClientNameMatchesGroupName checks if a group with optional wildcards contains a client name
|
||||
func ClientNameMatchesGroupName(group string, clientName string) bool {
|
||||
match, _ := filepath.Match(group, clientName)
|
||||
|
||||
return match
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue