mirror of https://github.com/0xERR0R/blocky.git
update golangci-lint (#510)
* update golangci-lint * enable gomnd linter * enable asciicheck linter * enable bidichk linter * enable durationcheck linter * enable errchkjson linter * enable errorlint linter * enable exhaustive linter * enable gomoddirectives linter * enable gomodguard guard * enable grouper linter * enable grouper and ifshort linters * enable importas linter * enable makezero linter * enable nestif linter * enable nilerr linter * enable nilnil linter * enable nlreturn linter * enable nolintlint linter * enable predeclared linter * enable sqlclosecheck linter * enable tenv linter * enable wastedassign linter
This commit is contained in:
parent
41febafd41
commit
a4b89537db
|
@ -16,5 +16,5 @@ jobs:
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v3
|
||||||
with:
|
with:
|
||||||
version: v1.43.0
|
version: v1.45.2
|
||||||
args: --timeout 5m0s
|
args: --timeout 5m0s
|
||||||
|
|
|
@ -1,41 +1,63 @@
|
||||||
linters:
|
linters:
|
||||||
enable:
|
enable:
|
||||||
- govet
|
- asciicheck
|
||||||
- errcheck
|
- bidichk
|
||||||
- staticcheck
|
|
||||||
- unused
|
|
||||||
- gosimple
|
|
||||||
- structcheck
|
|
||||||
- varcheck
|
|
||||||
- ineffassign
|
|
||||||
- deadcode
|
|
||||||
- typecheck
|
|
||||||
- bodyclose
|
- bodyclose
|
||||||
- revive
|
- deadcode
|
||||||
- stylecheck
|
|
||||||
- gosec
|
|
||||||
- unconvert
|
|
||||||
- dupl
|
|
||||||
- goconst
|
|
||||||
- gocyclo
|
|
||||||
- gocognit
|
|
||||||
- gofmt
|
|
||||||
- goimports
|
|
||||||
- depguard
|
- depguard
|
||||||
- misspell
|
|
||||||
- lll
|
|
||||||
- unparam
|
|
||||||
- dogsled
|
- dogsled
|
||||||
|
- dupl
|
||||||
|
- durationcheck
|
||||||
|
- errcheck
|
||||||
|
- errchkjson
|
||||||
|
- errorlint
|
||||||
|
- exhaustive
|
||||||
|
- exportloopref
|
||||||
- funlen
|
- funlen
|
||||||
- gochecknoglobals
|
- gochecknoglobals
|
||||||
- gochecknoinits
|
- gochecknoinits
|
||||||
|
- gocognit
|
||||||
|
- goconst
|
||||||
- gocritic
|
- gocritic
|
||||||
|
- gocyclo
|
||||||
- godox
|
- godox
|
||||||
|
- gofmt
|
||||||
|
- goimports
|
||||||
|
- gomnd
|
||||||
|
- gomoddirectives
|
||||||
|
- gomodguard
|
||||||
|
- gosec
|
||||||
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- grouper
|
||||||
|
- ifshort
|
||||||
|
- importas
|
||||||
|
- ineffassign
|
||||||
|
- lll
|
||||||
|
- makezero
|
||||||
|
- misspell
|
||||||
- nakedret
|
- nakedret
|
||||||
|
- nestif
|
||||||
|
- nilerr
|
||||||
|
- nilnil
|
||||||
|
- nlreturn
|
||||||
|
- nolintlint
|
||||||
- prealloc
|
- prealloc
|
||||||
|
- predeclared
|
||||||
|
- revive
|
||||||
|
- sqlclosecheck
|
||||||
|
- staticcheck
|
||||||
|
- structcheck
|
||||||
|
- stylecheck
|
||||||
|
- tenv
|
||||||
|
- typecheck
|
||||||
|
- unconvert
|
||||||
|
- unparam
|
||||||
|
- unused
|
||||||
|
- varcheck
|
||||||
|
- wastedassign
|
||||||
- whitespace
|
- whitespace
|
||||||
- wsl
|
- wsl
|
||||||
- exportloopref
|
|
||||||
disable:
|
disable:
|
||||||
- noctx
|
- noctx
|
||||||
- scopelint
|
- scopelint
|
||||||
|
@ -46,6 +68,11 @@ linters:
|
||||||
- unused
|
- unused
|
||||||
fast: false
|
fast: false
|
||||||
|
|
||||||
|
linters-settings:
|
||||||
|
gomnd:
|
||||||
|
ignored-numbers:
|
||||||
|
- '0666'
|
||||||
|
|
||||||
issues:
|
issues:
|
||||||
exclude-rules:
|
exclude-rules:
|
||||||
# Exclude some linters from running on tests files.
|
# Exclude some linters from running on tests files.
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -38,7 +38,7 @@ race: ## run tests with race detector
|
||||||
go test -race -short ./...
|
go test -race -short ./...
|
||||||
|
|
||||||
lint: build ## run golangcli-lint checks
|
lint: build ## run golangcli-lint checks
|
||||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.43.0
|
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.45.2
|
||||||
$(shell go env GOPATH)/bin/golangci-lint run
|
$(shell go env GOPATH)/bin/golangci-lint run
|
||||||
|
|
||||||
run: build ## Build and run binary
|
run: build ## Build and run binary
|
||||||
|
|
|
@ -30,6 +30,7 @@ func (b *BlockingControlMock) EnableBlocking() {
|
||||||
}
|
}
|
||||||
func (b *BlockingControlMock) DisableBlocking(_ time.Duration, disableGroups []string) error {
|
func (b *BlockingControlMock) DisableBlocking(_ time.Duration, disableGroups []string) error {
|
||||||
b.enabled = false
|
b.enabled = false
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -139,8 +139,7 @@ func isExpired(el *element) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateRemainTTL(expiresEpoch int64) time.Duration {
|
func calculateRemainTTL(expiresEpoch int64) time.Duration {
|
||||||
now := time.Now().UnixMilli()
|
if now := time.Now().UnixMilli(); now < expiresEpoch {
|
||||||
if now < expiresEpoch {
|
|
||||||
return time.Duration(expiresEpoch-now) * time.Millisecond
|
return time.Duration(expiresEpoch-now) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -99,6 +99,7 @@ func (cache regexCache) Contains(searchString string) bool {
|
||||||
for _, regex := range cache {
|
for _, regex := range cache {
|
||||||
if regex.MatchString(searchString) {
|
if regex.MatchString(searchString) {
|
||||||
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,6 +40,7 @@ func refreshList(_ *cobra.Command, _ []string) error {
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := ioutil.ReadAll(resp.Body)
|
body, _ := ioutil.ReadAll(resp.Body)
|
||||||
|
|
||||||
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ func query(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := ioutil.ReadAll(resp.Body)
|
body, _ := ioutil.ReadAll(resp.Body)
|
||||||
|
|
||||||
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
return fmt.Errorf("response NOK, %s %s", resp.Status, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
20
cmd/root.go
20
cmd/root.go
|
@ -3,7 +3,6 @@ package cmd
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -20,6 +19,12 @@ var (
|
||||||
apiPort uint16
|
apiPort uint16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPort = 4000
|
||||||
|
defaultHost = "localhost"
|
||||||
|
defaultConfigPath = "./config.yml"
|
||||||
|
)
|
||||||
|
|
||||||
// NewRootCommand creates a new root cli command instance
|
// NewRootCommand creates a new root cli command instance
|
||||||
func NewRootCommand() *cobra.Command {
|
func NewRootCommand() *cobra.Command {
|
||||||
c := &cobra.Command{
|
c := &cobra.Command{
|
||||||
|
@ -35,9 +40,9 @@ Complete documentation is available at https://github.com/0xERR0R/blocky`,
|
||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.PersistentFlags().StringVarP(&configPath, "config", "c", "./config.yml", "path to config file")
|
c.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "path to config file")
|
||||||
c.PersistentFlags().StringVar(&apiHost, "apiHost", "localhost", "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
c.PersistentFlags().StringVar(&apiHost, "apiHost", defaultHost, "host of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
||||||
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", 4000, "port of blocky (API). Default overridden by config and CLI.")
|
c.PersistentFlags().Uint16Var(&apiPort, "apiPort", defaultPort, "port of blocky (API). Default overridden by config and CLI.") // nolint:lll
|
||||||
|
|
||||||
c.AddCommand(newRefreshCommand(),
|
c.AddCommand(newRefreshCommand(),
|
||||||
NewQueryCommand(),
|
NewQueryCommand(),
|
||||||
|
@ -73,15 +78,14 @@ func initConfig() {
|
||||||
|
|
||||||
apiHost = strings.Join(split[:lastIdx], ":")
|
apiHost = strings.Join(split[:lastIdx], ":")
|
||||||
|
|
||||||
var p uint64
|
port, err := config.ConvertPort(split[lastIdx])
|
||||||
p, err := strconv.ParseUint(strings.TrimSpace(split[lastIdx]), 10, 16)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.FatalOnError("can't convert port to number (1 - 65535)", err)
|
util.FatalOnError("can't convert port to number (1 - 65535)", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiPort = uint16(p)
|
apiPort = port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,8 @@ func startServer(_ *cobra.Command, _ []string) error {
|
||||||
return fmt.Errorf("can't start server: %w", err)
|
return fmt.Errorf("can't start server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
errChan := make(chan error, 10)
|
const errChanSize = 10
|
||||||
|
errChan := make(chan error, errChanSize)
|
||||||
|
|
||||||
srv.Start(errChan)
|
srv.Start(errChan)
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,12 @@ import (
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpPort = 53
|
||||||
|
tlsPort = 853
|
||||||
|
httpsPort = 443
|
||||||
|
)
|
||||||
|
|
||||||
// NetProtocol resolver protocol ENUM(
|
// NetProtocol resolver protocol ENUM(
|
||||||
// tcp+udp // TCP and UDP protocols
|
// tcp+udp // TCP and UDP protocols
|
||||||
// tcp-tls // TCP-TLS protocol
|
// tcp-tls // TCP-TLS protocol
|
||||||
|
@ -60,6 +66,7 @@ func NewQTypeSet(qTypes ...dns.Type) QTypeSet {
|
||||||
|
|
||||||
func (s QTypeSet) Contains(qType dns.Type) bool {
|
func (s QTypeSet) Contains(qType dns.Type) bool {
|
||||||
_, found := s[QType(qType)]
|
_, found := s[QType(qType)]
|
||||||
|
|
||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,9 +86,9 @@ func (c *Duration) String() string {
|
||||||
|
|
||||||
// nolint:gochecknoglobals
|
// nolint:gochecknoglobals
|
||||||
var netDefaultPort = map[NetProtocol]uint16{
|
var netDefaultPort = map[NetProtocol]uint16{
|
||||||
NetProtocolTcpUdp: 53,
|
NetProtocolTcpUdp: udpPort,
|
||||||
NetProtocolTcpTls: 853,
|
NetProtocolTcpTls: tlsPort,
|
||||||
NetProtocolHttps: 443,
|
NetProtocolHttps: httpsPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream is the definition of external DNS server
|
// Upstream is the definition of external DNS server
|
||||||
|
@ -252,12 +259,14 @@ func (c *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
// duration is defined as number without unit
|
// duration is defined as number without unit
|
||||||
// use minutes to ensure back compatibility
|
// use minutes to ensure back compatibility
|
||||||
*c = Duration(time.Duration(minutes) * time.Minute)
|
*c = Duration(time.Duration(minutes) * time.Minute)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
duration, err := time.ParseDuration(input)
|
duration, err := time.ParseDuration(input)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
*c = Duration(duration)
|
*c = Duration(duration)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,15 +329,15 @@ func ParseUpstream(upstream string) (Upstream, error) {
|
||||||
|
|
||||||
// string contains host:port
|
// string contains host:port
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var p uint64
|
p, err := ConvertPort(portString)
|
||||||
p, err = strconv.ParseUint(strings.TrimSpace(portString), 10, 16)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
|
err = fmt.Errorf("can't convert port to number (1 - 65535) %w", err)
|
||||||
|
|
||||||
return Upstream{}, err
|
return Upstream{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
port = uint16(p)
|
port = p
|
||||||
} else {
|
} else {
|
||||||
// only host, use default port
|
// only host, use default port
|
||||||
host = upstream
|
host = upstream
|
||||||
|
@ -340,9 +349,7 @@ func ParseUpstream(upstream string) (Upstream, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate hostname or ip
|
// validate hostname or ip
|
||||||
ip := net.ParseIP(host)
|
if ip := net.ParseIP(host); ip == nil {
|
||||||
|
|
||||||
if ip == nil {
|
|
||||||
// is not IP
|
// is not IP
|
||||||
if !validDomain.MatchString(host) {
|
if !validDomain.MatchString(host) {
|
||||||
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
|
return Upstream{}, fmt.Errorf("wrong host name '%s'", host)
|
||||||
|
@ -544,6 +551,7 @@ func LoadConfig(path string, mandatory bool) (*Config, error) {
|
||||||
// config file does not exist
|
// config file does not exist
|
||||||
// return config with default values
|
// return config with default values
|
||||||
config = &cfg
|
config = &cfg
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -596,3 +604,20 @@ func validateConfig(cfg *Config) (err error) {
|
||||||
func GetConfig() *Config {
|
func GetConfig() *Config {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||||
|
func ConvertPort(in string) (uint16, error) {
|
||||||
|
const (
|
||||||
|
base = 10
|
||||||
|
bitSize = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
var p uint64
|
||||||
|
p, err := strconv.ParseUint(strings.TrimSpace(in), base, bitSize)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint16(p), nil
|
||||||
|
}
|
||||||
|
|
|
@ -269,6 +269,7 @@ bootstrapDns:
|
||||||
u := &Upstream{}
|
u := &Upstream{}
|
||||||
err := u.UnmarshalYAML(func(i interface{}) error {
|
err := u.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "tcp+udp:1.2.3.4"
|
*i.(*string) = "tcp+udp:1.2.3.4"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
})
|
})
|
||||||
|
@ -292,6 +293,7 @@ bootstrapDns:
|
||||||
l := &ListenConfig{}
|
l := &ListenConfig{}
|
||||||
err := l.UnmarshalYAML(func(i interface{}) error {
|
err := l.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "55,:56"
|
*i.(*string) = "55,:56"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -311,6 +313,7 @@ bootstrapDns:
|
||||||
d := Duration(0)
|
d := Duration(0)
|
||||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "1m20s"
|
*i.(*string) = "1m20s"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -321,6 +324,7 @@ bootstrapDns:
|
||||||
d := Duration(0)
|
d := Duration(0)
|
||||||
err := d.UnmarshalYAML(func(i interface{}) error {
|
err := d.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "wrong"
|
*i.(*string) = "wrong"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
|
@ -342,6 +346,7 @@ bootstrapDns:
|
||||||
c := &ConditionalUpstreamMapping{}
|
c := &ConditionalUpstreamMapping{}
|
||||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -364,6 +369,7 @@ bootstrapDns:
|
||||||
c := &CustomDNSMapping{}
|
c := &CustomDNSMapping{}
|
||||||
err := c.UnmarshalYAML(func(i interface{}) error {
|
err := c.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
*i.(*map[string]string) = map[string]string{"key": "1.2.3.4"}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -385,6 +391,7 @@ bootstrapDns:
|
||||||
t := QType(0)
|
t := QType(0)
|
||||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "AAAA"
|
*i.(*string) = "AAAA"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
@ -395,6 +402,7 @@ bootstrapDns:
|
||||||
t := QType(0)
|
t := QType(0)
|
||||||
err := t.UnmarshalYAML(func(i interface{}) error {
|
err := t.UnmarshalYAML(func(i interface{}) error {
|
||||||
*i.(*string) = "WRONGTYPE"
|
*i.(*string) = "WRONGTYPE"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
Expect(err).Should(HaveOccurred())
|
Expect(err).Should(HaveOccurred())
|
||||||
|
|
|
@ -103,10 +103,10 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||||
func() error {
|
func() error {
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
var httpErr error
|
var httpErr error
|
||||||
//nolint:bodyclose
|
|
||||||
if resp, httpErr = client.Get(link); httpErr == nil {
|
if resp, httpErr = client.Get(link); httpErr == nil {
|
||||||
if resp.StatusCode == http.StatusOK {
|
if resp.StatusCode == http.StatusOK {
|
||||||
body = resp.Body
|
body = resp.Body
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +118,7 @@ func (d *HTTPDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||||
if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
|
if errors.As(httpErr, &netErr) && (netErr.Timeout() || netErr.Temporary()) {
|
||||||
return &TransientError{inner: netErr}
|
return &TransientError{inner: netErr}
|
||||||
}
|
}
|
||||||
|
|
||||||
return httpErr
|
return httpErr
|
||||||
},
|
},
|
||||||
retry.Attempts(d.downloadAttempts),
|
retry.Attempts(d.downloadAttempts),
|
||||||
|
|
|
@ -163,6 +163,7 @@ Loop:
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
close(c)
|
close(c)
|
||||||
|
|
||||||
break Loop
|
break Loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -307,9 +308,7 @@ func processLine(line string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Fields(line)
|
if parts := strings.Fields(line); len(parts) > 0 {
|
||||||
|
|
||||||
if len(parts) > 0 {
|
|
||||||
host := parts[len(parts)-1]
|
host := parts[len(parts)-1]
|
||||||
|
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
|
|
|
@ -152,6 +152,7 @@ var _ = Describe("ListCache", func() {
|
||||||
By("List couldn't be loaded due to 404 err", func() {
|
By("List couldn't be loaded due to 404 err", func() {
|
||||||
Eventually(func() bool {
|
Eventually(func() bool {
|
||||||
found, _ := sut.Match("blocked1.com", []string{"gr1"})
|
found, _ := sut.Match("blocked1.com", []string{"gr1"})
|
||||||
|
|
||||||
return found
|
return found
|
||||||
}, "1s").Should(BeFalse())
|
}, "1s").Should(BeFalse())
|
||||||
})
|
})
|
||||||
|
@ -313,5 +314,6 @@ type MockDownloader struct {
|
||||||
|
|
||||||
func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
func (m *MockDownloader) DownloadFile(link string) (io.ReadCloser, error) {
|
||||||
fn := <-m.data
|
fn := <-m.data
|
||||||
|
|
||||||
return fn()
|
return fn()
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ var _ = Describe("DatabaseWriter", func() {
|
||||||
result := writer.db.Find(&logEntry{})
|
result := writer.db.Find(&logEntry{})
|
||||||
|
|
||||||
result.Count(&res)
|
result.Count(&res)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}, "1s").Should(BeNumerically("==", 1))
|
}, "1s").Should(BeNumerically("==", 1))
|
||||||
})
|
})
|
||||||
|
@ -91,6 +92,7 @@ var _ = Describe("DatabaseWriter", func() {
|
||||||
result := writer.db.Find(&logEntry{})
|
result := writer.db.Find(&logEntry{})
|
||||||
|
|
||||||
result.Count(&res)
|
result.Count(&res)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}, "1s").Should(BeNumerically("==", 2))
|
}, "1s").Should(BeNumerically("==", 2))
|
||||||
|
|
||||||
|
@ -102,6 +104,7 @@ var _ = Describe("DatabaseWriter", func() {
|
||||||
result := writer.db.Find(&logEntry{})
|
result := writer.db.Find(&logEntry{})
|
||||||
|
|
||||||
result.Count(&res)
|
result.Count(&res)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
}, "1s").Should(BeNumerically("==", 1))
|
}, "1s").Should(BeNumerically("==", 1))
|
||||||
})
|
})
|
||||||
|
|
|
@ -72,6 +72,8 @@ func (d *FileWriter) Write(entry *LogEntry) {
|
||||||
|
|
||||||
// CleanUp deletes old log files
|
// CleanUp deletes old log files
|
||||||
func (d *FileWriter) CleanUp() {
|
func (d *FileWriter) CleanUp() {
|
||||||
|
const hoursPerDay = 24
|
||||||
|
|
||||||
logger := log.PrefixedLog(loggerPrefixFileWriter)
|
logger := log.PrefixedLog(loggerPrefixFileWriter)
|
||||||
|
|
||||||
logger.Trace("starting clean up")
|
logger.Trace("starting clean up")
|
||||||
|
@ -85,7 +87,7 @@ func (d *FileWriter) CleanUp() {
|
||||||
if strings.HasSuffix(f.Name(), ".log") && len(f.Name()) > 10 {
|
if strings.HasSuffix(f.Name(), ".log") && len(f.Name()) > 10 {
|
||||||
t, err := time.Parse("2006-01-02", f.Name()[:10])
|
t, err := time.Parse("2006-01-02", f.Name()[:10])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
differenceDays := uint64(time.Since(t).Hours() / 24)
|
differenceDays := uint64(time.Since(t).Hours() / hoursPerDay)
|
||||||
if d.logRetentionDays > 0 && differenceDays > d.logRetentionDays {
|
if d.logRetentionDays > 0 && differenceDays > d.logRetentionDays {
|
||||||
logger.WithFields(logrus.Fields{
|
logger.WithFields(logrus.Fields{
|
||||||
"file": f.Name(),
|
"file": f.Name(),
|
||||||
|
|
|
@ -69,7 +69,7 @@ type Client struct {
|
||||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
// disable redis if no address is provided
|
// disable redis if no address is provided
|
||||||
if cfg == nil || len(cfg.Address) == 0 {
|
if cfg == nil || len(cfg.Address) == 0 {
|
||||||
return nil, nil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
|
|
||||||
rdb := redis.NewClient(&redis.Options{
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
@ -164,7 +164,12 @@ func (c *Client) startup() error {
|
||||||
select {
|
select {
|
||||||
// received message from subscription
|
// received message from subscription
|
||||||
case msg := <-ps.Channel():
|
case msg := <-ps.Channel():
|
||||||
err = c.processReceivedMessage(msg)
|
c.l.Debug("Received message: ", msg)
|
||||||
|
|
||||||
|
if msg != nil && len(msg.Payload) > 0 {
|
||||||
|
// message is not empty
|
||||||
|
err = c.processReceivedMessage(msg)
|
||||||
|
}
|
||||||
// publish message from buffer
|
// publish message from buffer
|
||||||
case s := <-c.sendBuffer:
|
case s := <-c.sendBuffer:
|
||||||
c.publishMessageFromBuffer(s)
|
c.publishMessageFromBuffer(s)
|
||||||
|
@ -201,34 +206,30 @@ func (c *Client) publishMessageFromBuffer(s *bufferMessage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) processReceivedMessage(msg *redis.Message) (err error) {
|
func (c *Client) processReceivedMessage(msg *redis.Message) (err error) {
|
||||||
c.l.Debug("Received message: ", msg)
|
var rm redisMessage
|
||||||
// message is not empty
|
|
||||||
if msg != nil && len(msg.Payload) > 0 {
|
|
||||||
var rm redisMessage
|
|
||||||
|
|
||||||
err = json.Unmarshal([]byte(msg.Payload), &rm)
|
err = json.Unmarshal([]byte(msg.Payload), &rm)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// message was sent from a different blocky instance
|
// message was sent from a different blocky instance
|
||||||
if !bytes.Equal(rm.Client, c.id) {
|
if !bytes.Equal(rm.Client, c.id) {
|
||||||
switch rm.Type {
|
switch rm.Type {
|
||||||
case messageTypeCache:
|
case messageTypeCache:
|
||||||
var cm *CacheMessage
|
var cm *CacheMessage
|
||||||
|
|
||||||
cm, err = convertMessage(&rm, 0)
|
cm, err = convertMessage(&rm, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.CacheChannel <- cm
|
c.CacheChannel <- cm
|
||||||
}
|
|
||||||
case messageTypeEnable:
|
|
||||||
err = c.processEnabledMessage(&rm)
|
|
||||||
default:
|
|
||||||
c.l.Warn("Unknown message type: ", rm.Type)
|
|
||||||
}
|
}
|
||||||
|
case messageTypeEnable:
|
||||||
|
err = c.processEnabledMessage(&rm)
|
||||||
|
default:
|
||||||
|
c.l.Warn("Unknown message type: ", rm.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.l.Error("Processing error: ", err)
|
c.l.Error("Processing error: ", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -25,6 +25,8 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultBlockingCleanUpInterval = 5 * time.Second
|
||||||
|
|
||||||
func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
|
func createBlockHandler(cfg config.BlockingConfig) (blockHandler, error) {
|
||||||
cfgBlockType := cfg.BlockType
|
cfgBlockType := cfg.BlockType
|
||||||
|
|
||||||
|
@ -362,7 +364,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||||
request *model.Request, logger *logrus.Entry) (*model.Response, error) {
|
request *model.Request, logger *logrus.Entry) (bool, *model.Response, error) {
|
||||||
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
|
||||||
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)
|
whitelistOnlyAllowed := r.hasWhiteListOnlyAllowed(groupsToCheck)
|
||||||
|
|
||||||
|
@ -372,19 +374,26 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
|
||||||
|
|
||||||
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted {
|
if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted {
|
||||||
logger.WithField("group", group).Debugf("domain is whitelisted")
|
logger.WithField("group", group).Debugf("domain is whitelisted")
|
||||||
return r.next.Resolve(request)
|
|
||||||
|
resp, err := r.next.Resolve(request)
|
||||||
|
|
||||||
|
return true, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if whitelistOnlyAllowed {
|
if whitelistOnlyAllowed {
|
||||||
return r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
|
resp, err := r.handleBlocked(logger, request, question, "BLOCKED (WHITELIST ONLY)")
|
||||||
|
|
||||||
|
return true, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked {
|
if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked {
|
||||||
return r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
|
resp, err := r.handleBlocked(logger, request, question, fmt.Sprintf("BLOCKED (%s)", group))
|
||||||
|
|
||||||
|
return true, resp, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return false, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
|
||||||
|
@ -393,8 +402,8 @@ func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
groupsToCheck := r.groupsToCheckForClient(request)
|
groupsToCheck := r.groupsToCheckForClient(request)
|
||||||
|
|
||||||
if len(groupsToCheck) > 0 {
|
if len(groupsToCheck) > 0 {
|
||||||
resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
|
||||||
if resp != nil || err != nil {
|
if handled {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -541,6 +550,7 @@ func (b zeroIPBlockHandler) handleBlock(question dns.Question, response *dns.Msg
|
||||||
zeroIP = net.IPv4zero
|
zeroIP = net.IPv4zero
|
||||||
default:
|
default:
|
||||||
response.Rcode = dns.RcodeNameError
|
response.Rcode = dns.RcodeNameError
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -603,7 +613,7 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
||||||
identifiers = append(identifiers, identifier)
|
identifiers = append(identifiers, identifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(5*time.Second),
|
r.fqdnIPCache = expirationcache.NewCache(expirationcache.WithCleanUpInterval(defaultBlockingCleanUpInterval),
|
||||||
expirationcache.WithOnExpiredFn(func(key string) (val interface{}, ttl time.Duration) {
|
expirationcache.WithOnExpiredFn(func(key string) (val interface{}, ttl time.Duration) {
|
||||||
return r.queryForFQIdentifierIPs(key)
|
return r.queryForFQIdentifierIPs(key)
|
||||||
}))
|
}))
|
||||||
|
@ -618,5 +628,6 @@ func (r *BlockingResolver) initFQDNIPCache() {
|
||||||
|
|
||||||
func isFQDN(in string) bool {
|
func isFQDN(in string) bool {
|
||||||
s := strings.Trim(in, ".")
|
s := strings.Trim(in, ".")
|
||||||
|
|
||||||
return strings.Contains(s, ".")
|
return strings.Contains(s, ".")
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,8 +130,10 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() {
|
||||||
m.AnswerFn = func(t uint16, qName string) *dns.Msg {
|
m.AnswerFn = func(t uint16, qName string) *dns.Msg {
|
||||||
if t == dns.TypeA && qName == "full.qualified.com." {
|
if t == dns.TypeA && qName == "full.qualified.com." {
|
||||||
a, _ := util.NewMsgWithAnswer(qName, 60*60, dns.Type(dns.TypeA), "192.168.178.39")
|
a, _ := util.NewMsgWithAnswer(qName, 60*60, dns.Type(dns.TypeA), "192.168.178.39")
|
||||||
|
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
Bus().Publish(ApplicationStarted, "")
|
Bus().Publish(ApplicationStarted, "")
|
||||||
|
|
|
@ -141,6 +141,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
||||||
host, port, err := net.SplitHostPort(addr)
|
host, port, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("dial error: %s", err)
|
log.Errorf("dial error: %s", err)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,6 +162,7 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
||||||
ips, err := b.resolve(host, qTypes)
|
ips, err := b.resolve(host, qTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("resolve error: %s", err)
|
log.Errorf("resolve error: %s", err)
|
||||||
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,18 +172,20 @@ func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
||||||
|
|
||||||
// Use the standard dialer to actually connect
|
// Use the standard dialer to actually connect
|
||||||
addrWithIP := net.JoinHostPort(ip.String(), port)
|
addrWithIP := net.JoinHostPort(ip.String(), port)
|
||||||
|
|
||||||
return dialer.DialContext(ctx, network, addrWithIP)
|
return dialer.DialContext(ctx, network, addrWithIP)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
|
||||||
ips = make([]net.IP, 0, 2)
|
ips = make([]net.IP, 0, len(qTypes))
|
||||||
|
|
||||||
for _, qType := range qTypes {
|
for _, qType := range qTypes {
|
||||||
qIPs, qErr := b.resolveType(hostname, qType)
|
qIPs, qErr := b.resolveType(hostname, qType)
|
||||||
if qErr != nil {
|
if qErr != nil {
|
||||||
err = multierror.Append(err, qErr)
|
err = multierror.Append(err, qErr)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,8 +200,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
|
func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
|
||||||
ip := net.ParseIP(hostname)
|
if ip := net.ParseIP(hostname); ip != nil {
|
||||||
if ip != nil {
|
|
||||||
return []net.IP{ip}, nil
|
return []net.IP{ip}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,14 +239,15 @@ type IPSet struct {
|
||||||
|
|
||||||
func (ips *IPSet) Current() net.IP {
|
func (ips *IPSet) Current() net.IP {
|
||||||
idx := atomic.LoadUint32(&ips.index)
|
idx := atomic.LoadUint32(&ips.index)
|
||||||
|
|
||||||
return ips.values[idx]
|
return ips.values[idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ips *IPSet) Next() {
|
func (ips *IPSet) Next() {
|
||||||
old := ips.index
|
oldIP := ips.index
|
||||||
new := uint32(int(ips.index+1) % len(ips.values))
|
newIP := uint32(int(ips.index+1) % len(ips.values))
|
||||||
|
|
||||||
// We don't care about the result: if the call fails,
|
// We don't care about the result: if the call fails,
|
||||||
// it means the value was incremented by another goroutine
|
// it means the value was incremented by another goroutine
|
||||||
_ = atomic.CompareAndSwapUint32(&ips.index, old, new)
|
_ = atomic.CompareAndSwapUint32(&ips.index, oldIP, newIP)
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,6 +58,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
||||||
PreferGo: true,
|
PreferGo: true,
|
||||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
usedSystemResolver <- true
|
usedSystemResolver <- true
|
||||||
|
|
||||||
return nil, errors.New("don't actually do anything")
|
return nil, errors.New("don't actually do anything")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,8 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultCachingCleanUpInterval = 5 * time.Second
|
||||||
|
|
||||||
// CachingResolver caches answers from dns queries with their TTL time,
|
// CachingResolver caches answers from dns queries with their TTL time,
|
||||||
// to avoid external resolver calls for recurrent queries
|
// to avoid external resolver calls for recurrent queries
|
||||||
type CachingResolver struct {
|
type CachingResolver struct {
|
||||||
|
@ -58,7 +60,7 @@ func NewCachingResolver(cfg config.CachingConfig, redis *redis.Client) ChainedRe
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
func configureCaches(c *CachingResolver, cfg *config.CachingConfig) {
|
||||||
cleanupOption := expirationcache.WithCleanUpInterval(5 * time.Second)
|
cleanupOption := expirationcache.WithCleanUpInterval(defaultCachingCleanUpInterval)
|
||||||
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
|
maxSizeOption := expirationcache.WithMaxSize(uint(cfg.MaxItemsCount))
|
||||||
|
|
||||||
if cfg.Prefetching {
|
if cfg.Prefetching {
|
||||||
|
@ -91,6 +93,7 @@ func setupRedisCacheSubscriber(c *CachingResolver) {
|
||||||
// check if domain was queried > threshold in the time window
|
// check if domain was queried > threshold in the time window
|
||||||
func (r *CachingResolver) isPrefetchingDomain(cacheKey string) bool {
|
func (r *CachingResolver) isPrefetchingDomain(cacheKey string) bool {
|
||||||
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
cnt, _ := r.prefetchingNameCache.Get(cacheKey)
|
||||||
|
|
||||||
return cnt != nil && cnt.(int) > r.prefetchThreshold
|
return cnt != nil && cnt.(int) > r.prefetchThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,6 +111,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if response.Res.Rcode == dns.RcodeSuccess {
|
if response.Res.Rcode == dns.RcodeSuccess {
|
||||||
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
|
evt.Bus().Publish(evt.CachingDomainPrefetched, domainName)
|
||||||
|
|
||||||
return cacheValue{response.Res.Answer, true}, time.Duration(r.adjustTTLs(response.Res.Answer)) * time.Second
|
return cacheValue{response.Res.Answer, true}, time.Duration(r.adjustTTLs(response.Res.Answer)) * time.Second
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -122,6 +126,7 @@ func (r *CachingResolver) onExpired(cacheKey string) (val interface{}, ttl time.
|
||||||
func (r *CachingResolver) Configuration() (result []string) {
|
func (r *CachingResolver) Configuration() (result []string) {
|
||||||
if r.maxCacheTimeSec < 0 {
|
if r.maxCacheTimeSec < 0 {
|
||||||
result = []string{"deactivated"}
|
result = []string{"deactivated"}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,12 +151,12 @@ func (r *CachingResolver) Configuration() (result []string) {
|
||||||
|
|
||||||
// Resolve checks if the current query result is already in the cache and returns it
|
// Resolve checks if the current query result is already in the cache and returns it
|
||||||
// or delegates to the next resolver
|
// or delegates to the next resolver
|
||||||
//nolint:gocognit,funlen
|
|
||||||
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
|
||||||
logger := withPrefix(request.Log, "caching_resolver")
|
logger := withPrefix(request.Log, "caching_resolver")
|
||||||
|
|
||||||
if r.maxCacheTimeSec < 0 {
|
if r.maxCacheTimeSec < 0 {
|
||||||
logger.Debug("skip cache")
|
logger.Debug("skip cache")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -547,6 +547,7 @@ var _ = Describe("CachingResolver", func() {
|
||||||
|
|
||||||
Eventually(func() error {
|
Eventually(func() error {
|
||||||
resp, err = sut.Resolve(request)
|
resp, err = sut.Resolve(request)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}, "50ms").Should(Succeed())
|
}, "50ms").Should(Succeed())
|
||||||
})
|
})
|
||||||
|
|
|
@ -105,6 +105,21 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNames []string) {
|
||||||
|
for _, answer := range answer {
|
||||||
|
if t, ok := answer.(*dns.PTR); ok {
|
||||||
|
hostName := strings.TrimSuffix(t.Ptr, ".")
|
||||||
|
clientNames = append(clientNames, hostName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clientNames) == 0 {
|
||||||
|
clientNames = []string{fallbackIP.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
|
||||||
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
|
||||||
// try client mapping first
|
// try client mapping first
|
||||||
|
@ -114,49 +129,40 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.externalResolver != nil {
|
if r.externalResolver == nil {
|
||||||
reverse, _ := dns.ReverseAddr(ip.String())
|
return []string{ip.String()}
|
||||||
|
|
||||||
resp, err := r.externalResolver.Resolve(&model.Request{
|
|
||||||
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
|
||||||
Log: logger,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("can't resolve client name: ", err)
|
|
||||||
return []string{ip.String()}
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientNames []string
|
|
||||||
|
|
||||||
for _, answer := range resp.Res.Answer {
|
|
||||||
if t, ok := answer.(*dns.PTR); ok {
|
|
||||||
hostName := strings.TrimSuffix(t.Ptr, ".")
|
|
||||||
clientNames = append(clientNames, hostName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(clientNames) == 0 {
|
|
||||||
clientNames = []string{ip.String()}
|
|
||||||
}
|
|
||||||
|
|
||||||
// optional: if singleNameOrder is set, use only one name in the defined order
|
|
||||||
if len(r.singleNameOrder) > 0 {
|
|
||||||
for _, i := range r.singleNameOrder {
|
|
||||||
if i > 0 && int(i) <= len(clientNames) {
|
|
||||||
result = []string{clientNames[i-1]}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result = clientNames
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
|
|
||||||
} else {
|
|
||||||
result = []string{ip.String()}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reverse, _ := dns.ReverseAddr(ip.String())
|
||||||
|
|
||||||
|
resp, err := r.externalResolver.Resolve(&model.Request{
|
||||||
|
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
|
||||||
|
Log: logger,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("can't resolve client name: ", err)
|
||||||
|
|
||||||
|
return []string{ip.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
clientNames := extractClientNamesFromAnswer(resp.Res.Answer, ip)
|
||||||
|
|
||||||
|
// optional: if singleNameOrder is set, use only one name in the defined order
|
||||||
|
if len(r.singleNameOrder) > 0 {
|
||||||
|
for _, i := range r.singleNameOrder {
|
||||||
|
if i > 0 && int(i) <= len(clientNames) {
|
||||||
|
result = []string{clientNames[i-1]}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result = clientNames
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s) from external resolver")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,31 +51,42 @@ func (r *ConditionalUpstreamResolver) Configuration() (result []string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
|
||||||
|
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
||||||
|
domain := domainFromQuestion
|
||||||
|
|
||||||
|
if strings.Contains(domainFromQuestion, ".") {
|
||||||
|
// try with domain with and without sub-domains
|
||||||
|
for len(domain) > 0 {
|
||||||
|
if resolver, found := r.mapping[domain]; found {
|
||||||
|
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||||
|
|
||||||
|
return true, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := strings.Index(domain, "."); i >= 0 {
|
||||||
|
domain = domain[i+1:]
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if resolver, found := r.mapping["."]; found {
|
||||||
|
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
|
||||||
|
|
||||||
|
return true, resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve uses the conditional resolver to resolve the query
|
// Resolve uses the conditional resolver to resolve the query
|
||||||
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||||
logger := withPrefix(request.Log, "conditional_resolver")
|
logger := withPrefix(request.Log, "conditional_resolver")
|
||||||
|
|
||||||
if len(r.mapping) > 0 {
|
if len(r.mapping) > 0 {
|
||||||
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
|
resolved, resp, err := r.processRequest(request)
|
||||||
domain := domainFromQuestion
|
if resolved {
|
||||||
|
return resp, err
|
||||||
if !strings.Contains(domainFromQuestion, ".") {
|
|
||||||
if resolver, found := r.mapping["."]; found {
|
|
||||||
return r.internalResolve(resolver, domainFromQuestion, domain, request)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// try with domain with and without sub-domains
|
|
||||||
for len(domain) > 0 {
|
|
||||||
if resolver, found := r.mapping[domain]; found {
|
|
||||||
return r.internalResolve(resolver, domainFromQuestion, domain, request)
|
|
||||||
}
|
|
||||||
|
|
||||||
if i := strings.Index(domain, "."); i >= 0 {
|
|
||||||
domain = domain[i+1:]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,6 +88,54 @@ func (r *CustomDNSResolver) handleReverseDNS(request *model.Request) *model.Resp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *CustomDNSResolver) processRequest(request *model.Request) *model.Response {
|
||||||
|
logger := withPrefix(request.Log, "custom_dns_resolver")
|
||||||
|
|
||||||
|
response := new(dns.Msg)
|
||||||
|
response.SetReply(request.Req)
|
||||||
|
|
||||||
|
question := request.Req.Question[0]
|
||||||
|
domain := util.ExtractDomain(question)
|
||||||
|
|
||||||
|
for len(domain) > 0 {
|
||||||
|
ips, found := r.mapping[domain]
|
||||||
|
if found {
|
||||||
|
for _, ip := range ips {
|
||||||
|
if isSupportedType(ip, question) {
|
||||||
|
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
||||||
|
response.Answer = append(response.Answer, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response.Answer) > 0 {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"answer": util.AnswerToString(response.Answer),
|
||||||
|
"domain": domain,
|
||||||
|
}).Debugf("returning custom dns entry")
|
||||||
|
|
||||||
|
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapping exists for this domain, but for another type
|
||||||
|
if !r.filterUnmappedTypes {
|
||||||
|
// go to next resolver
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// return NOERROR with empty result
|
||||||
|
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := strings.Index(domain, "."); i >= 0 {
|
||||||
|
domain = domain[i+1:]
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Resolve uses internal mapping to resolve the query
|
// Resolve uses internal mapping to resolve the query
|
||||||
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, error) {
|
||||||
logger := withPrefix(request.Log, "custom_dns_resolver")
|
logger := withPrefix(request.Log, "custom_dns_resolver")
|
||||||
|
@ -98,46 +146,9 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.mapping) > 0 {
|
if len(r.mapping) > 0 {
|
||||||
response := new(dns.Msg)
|
resp := r.processRequest(request)
|
||||||
response.SetReply(request.Req)
|
if resp != nil {
|
||||||
|
return resp, nil
|
||||||
question := request.Req.Question[0]
|
|
||||||
domain := util.ExtractDomain(question)
|
|
||||||
|
|
||||||
for len(domain) > 0 {
|
|
||||||
ips, found := r.mapping[domain]
|
|
||||||
if found {
|
|
||||||
for _, ip := range ips {
|
|
||||||
if isSupportedType(ip, question) {
|
|
||||||
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
|
|
||||||
response.Answer = append(response.Answer, rr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(response.Answer) > 0 {
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"answer": util.AnswerToString(response.Answer),
|
|
||||||
"domain": domain,
|
|
||||||
}).Debugf("returning custom dns entry")
|
|
||||||
|
|
||||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mapping exists for this domain, but for another type
|
|
||||||
if !r.filterUnmappedTypes {
|
|
||||||
// go to next resolver
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// return NOERROR with empty result
|
|
||||||
return &model.Response{Res: response, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if i := strings.Index(domain, "."); i >= 0 {
|
|
||||||
domain = domain[i+1:]
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -76,21 +76,7 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
domain := util.ExtractDomain(question)
|
domain := util.ExtractDomain(question)
|
||||||
|
|
||||||
for _, host := range r.hosts {
|
for _, host := range r.hosts {
|
||||||
if host.Hostname == domain {
|
response.Answer = append(response.Answer, r.processHostEntry(host, domain, question)...)
|
||||||
if isSupportedType(host.IP, question) {
|
|
||||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
|
||||||
response.Answer = append(response.Answer, rr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, alias := range host.Aliases {
|
|
||||||
if alias == domain {
|
|
||||||
if isSupportedType(host.IP, question) {
|
|
||||||
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
|
||||||
response.Answer = append(response.Answer, rr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(response.Answer) > 0 {
|
if len(response.Answer) > 0 {
|
||||||
|
@ -108,6 +94,26 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *HostsFileResolver) processHostEntry(host host, domain string, question dns.Question) (result []dns.RR) {
|
||||||
|
if host.Hostname == domain {
|
||||||
|
if isSupportedType(host.IP, question) {
|
||||||
|
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||||
|
result = append(result, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, alias := range host.Aliases {
|
||||||
|
if alias == domain {
|
||||||
|
if isSupportedType(host.IP, question) {
|
||||||
|
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
|
||||||
|
result = append(result, rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (r *HostsFileResolver) Configuration() (result []string) {
|
func (r *HostsFileResolver) Configuration() (result []string) {
|
||||||
if r.HostsFilePath != "" && len(r.hosts) != 0 {
|
if r.HostsFilePath != "" && len(r.hosts) != 0 {
|
||||||
result = append(result, fmt.Sprintf("hosts file path: %s", r.HostsFilePath))
|
result = append(result, fmt.Sprintf("hosts file path: %s", r.HostsFilePath))
|
||||||
|
@ -127,9 +133,7 @@ func NewHostsFileResolver(cfg config.HostsFileConfig) ChainedResolver {
|
||||||
refreshPeriod: time.Duration(cfg.RefreshPeriod),
|
refreshPeriod: time.Duration(cfg.RefreshPeriod),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.parseHostsFile()
|
if err := r.parseHostsFile(); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger := logger(hostsFileResolverLogger)
|
logger := logger(hostsFileResolverLogger)
|
||||||
logger.Warnf("cannot parse hosts file: %s, hosts file resolving is disabled", r.HostsFilePath)
|
logger.Warnf("cannot parse hosts file: %s, hosts file resolving is disabled", r.HostsFilePath)
|
||||||
r.HostsFilePath = ""
|
r.HostsFilePath = ""
|
||||||
|
@ -147,6 +151,8 @@ type host struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HostsFileResolver) parseHostsFile() error {
|
func (r *HostsFileResolver) parseHostsFile() error {
|
||||||
|
const minColumnCount = 2
|
||||||
|
|
||||||
if r.HostsFilePath == "" {
|
if r.HostsFilePath == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -177,7 +183,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
|
||||||
fields = strings.Fields(trimmed[:end])
|
fields = strings.Fields(trimmed[:end])
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fields) < 2 {
|
if len(fields) < minColumnCount {
|
||||||
// Skip invalid entry
|
// Skip invalid entry
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -191,7 +197,7 @@ func (r *HostsFileResolver) parseHostsFile() error {
|
||||||
h.IP = net.ParseIP(fields[0])
|
h.IP = net.ParseIP(fields[0])
|
||||||
h.Hostname = fields[1]
|
h.Hostname = fields[1]
|
||||||
|
|
||||||
if len(fields) > 2 {
|
if len(fields) > minColumnCount {
|
||||||
for i := 2; i < len(fields); i++ {
|
for i := 2; i < len(fields); i++ {
|
||||||
h.Aliases = append(h.Aliases, fields[i])
|
h.Aliases = append(h.Aliases, fields[i])
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
@ -29,6 +28,7 @@ type MockResolver struct {
|
||||||
|
|
||||||
func (r *MockResolver) Configuration() []string {
|
func (r *MockResolver) Configuration() []string {
|
||||||
args := r.Called()
|
args := r.Called()
|
||||||
|
|
||||||
return args.Get(0).([]string)
|
return args.Get(0).([]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,6 +182,7 @@ func (t *MockUDPUpstreamServer) WithAnswerError(errorCode int) *MockUDPUpstreamS
|
||||||
|
|
||||||
func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response *dns.Msg)) *MockUDPUpstreamServer {
|
func (t *MockUDPUpstreamServer) WithAnswerFn(fn func(request *dns.Msg) (response *dns.Msg)) *MockUDPUpstreamServer {
|
||||||
t.answerFn = fn
|
t.answerFn = fn
|
||||||
|
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,26 +196,33 @@ func (t *MockUDPUpstreamServer) Close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
func CreateConnection() *net.UDPConn {
|
||||||
a, err := net.ResolveUDPAddr("udp4", ":0")
|
a, err := net.ResolveUDPAddr("udp4", ":0")
|
||||||
util.FatalOnError("can't resolve address: ", err)
|
util.FatalOnError("can't resolve address: ", err)
|
||||||
|
|
||||||
ln, err := net.ListenUDP("udp4", a)
|
ln, err := net.ListenUDP("udp4", a)
|
||||||
util.FatalOnError("can't create connection: ", err)
|
util.FatalOnError("can't create connection: ", err)
|
||||||
|
|
||||||
t.ln = ln
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
||||||
|
ln := CreateConnection()
|
||||||
|
|
||||||
ladr := ln.LocalAddr().String()
|
ladr := ln.LocalAddr().String()
|
||||||
host := strings.Split(ladr, ":")[0]
|
host := strings.Split(ladr, ":")[0]
|
||||||
p, err := strconv.ParseUint(strings.Split(ladr, ":")[1], 10, 16)
|
p, err := config.ConvertPort(strings.Split(ladr, ":")[1])
|
||||||
|
|
||||||
util.FatalOnError("can't convert port: ", err)
|
util.FatalOnError("can't convert port: ", err)
|
||||||
|
|
||||||
port := uint16(p)
|
port := p
|
||||||
|
t.ln = ln
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
const bufferSize = 1024
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buffer := make([]byte, 1024)
|
buffer := make([]byte, bufferSize)
|
||||||
n, addr, err := ln.ReadFromUDP(buffer)
|
n, addr, err := ln.ReadFromUDP(buffer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -233,6 +241,7 @@ func (t *MockUDPUpstreamServer) Start() config.Upstream {
|
||||||
// nil should indicate an error
|
// nil should indicate an error
|
||||||
if response == nil {
|
if response == nil {
|
||||||
_, _ = ln.WriteToUDP([]byte("dummy"), addr)
|
_, _ = ln.WriteToUDP([]byte("dummy"), addr)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
const (
|
const (
|
||||||
upstreamDefaultCfgName = "default"
|
upstreamDefaultCfgName = "default"
|
||||||
parallelResolverLogger = "parallel_best_resolver"
|
parallelResolverLogger = "parallel_best_resolver"
|
||||||
|
resolverCount = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
|
// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer
|
||||||
|
@ -133,13 +134,14 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
|
|
||||||
if len(resolvers) == 1 {
|
if len(resolvers) == 1 {
|
||||||
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
|
logger.WithField("resolver", resolvers[0].resolver).Debug("delegating to resolver")
|
||||||
|
|
||||||
return resolvers[0].resolver.Resolve(request)
|
return resolvers[0].resolver.Resolve(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
r1, r2 := pickRandom(resolvers)
|
r1, r2 := pickRandom(resolvers)
|
||||||
logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver)
|
logger.Debugf("using %s and %s as resolver", r1.resolver, r2.resolver)
|
||||||
|
|
||||||
ch := make(chan requestResponse, 2)
|
ch := make(chan requestResponse, resolverCount)
|
||||||
|
|
||||||
var collectedErrors []error
|
var collectedErrors []error
|
||||||
|
|
||||||
|
@ -152,7 +154,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
go resolve(request, r2, ch)
|
go resolve(request, r2, ch)
|
||||||
|
|
||||||
//nolint: gosimple
|
//nolint: gosimple
|
||||||
for len(collectedErrors) < 2 {
|
for len(collectedErrors) < resolverCount {
|
||||||
select {
|
select {
|
||||||
case result := <-ch:
|
case result := <-ch:
|
||||||
if result.err != nil {
|
if result.err != nil {
|
||||||
|
@ -163,6 +165,7 @@ func (r *ParallelBestResolver) Resolve(request *model.Request) (*model.Response,
|
||||||
"resolver": r1.resolver,
|
"resolver": r1.resolver,
|
||||||
"answer": util.AnswerToString(result.response.Res.Answer),
|
"answer": util.AnswerToString(result.response.Res.Answer),
|
||||||
}).Debug("using response from resolver")
|
}).Debug("using response from resolver")
|
||||||
|
|
||||||
return result.response, nil
|
return result.response, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,14 +184,17 @@ func pickRandom(resolvers []*upstreamResolverStatus) (resolver1, resolver2 *upst
|
||||||
}
|
}
|
||||||
|
|
||||||
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
|
func weightedRandom(in []*upstreamResolverStatus, exclude Resolver) *upstreamResolverStatus {
|
||||||
|
const errorWindowInSec = 60
|
||||||
|
|
||||||
var choices []weightedrand.Choice
|
var choices []weightedrand.Choice
|
||||||
|
|
||||||
for _, res := range in {
|
for _, res := range in {
|
||||||
var weight float64 = 60
|
var weight float64 = errorWindowInSec
|
||||||
|
|
||||||
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
|
if time.Since(res.lastErrorTime.Load().(time.Time)) < time.Hour {
|
||||||
// reduce weight: consider last error time
|
// reduce weight: consider last error time
|
||||||
weight = math.Max(1, weight-(60-time.Since(res.lastErrorTime.Load().(time.Time)).Minutes()))
|
lastErrorTime := res.lastErrorTime.Load().(time.Time)
|
||||||
|
weight = math.Max(1, weight-(errorWindowInSec-time.Since(lastErrorTime).Minutes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if exclude != res.resolver {
|
if exclude != res.resolver {
|
||||||
|
|
|
@ -43,6 +43,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
})
|
})
|
||||||
DeferCleanup(slowTestUpstream.Close)
|
DeferCleanup(slowTestUpstream.Close)
|
||||||
|
@ -71,6 +72,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() {
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
})
|
})
|
||||||
DeferCleanup(slowTestUpstream.Close)
|
DeferCleanup(slowTestUpstream.Close)
|
||||||
|
|
|
@ -14,6 +14,7 @@ const (
|
||||||
cleanUpRunPeriod = 12 * time.Hour
|
cleanUpRunPeriod = 12 * time.Hour
|
||||||
queryLoggingResolverPrefix = "query_logging_resolver"
|
queryLoggingResolverPrefix = "query_logging_resolver"
|
||||||
logChanCap = 1000
|
logChanCap = 1000
|
||||||
|
defaultFlushPeriod = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryLoggingResolver writes query information (question, answer, duration, ...)
|
// QueryLoggingResolver writes query information (question, answer, duration, ...)
|
||||||
|
@ -40,14 +41,15 @@ func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver {
|
||||||
case config.QueryLogTypeCsvClient:
|
case config.QueryLogTypeCsvClient:
|
||||||
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
|
writer, err = querylog.NewCSVWriter(cfg.Target, true, cfg.LogRetentionDays)
|
||||||
case config.QueryLogTypeMysql:
|
case config.QueryLogTypeMysql:
|
||||||
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
|
writer, err = querylog.NewDatabaseWriter("mysql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
|
||||||
case config.QueryLogTypePostgresql:
|
case config.QueryLogTypePostgresql:
|
||||||
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, 30*time.Second)
|
writer, err = querylog.NewDatabaseWriter("postgresql", cfg.Target, cfg.LogRetentionDays, defaultFlushPeriod)
|
||||||
case config.QueryLogTypeConsole:
|
case config.QueryLogTypeConsole:
|
||||||
writer = querylog.NewLoggerWriter()
|
writer = querylog.NewLoggerWriter()
|
||||||
case config.QueryLogTypeNone:
|
case config.QueryLogTypeNone:
|
||||||
writer = querylog.NewNoneWriter()
|
writer = querylog.NewNoneWriter()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
retry.Attempts(uint(cfg.CreationAttempts)),
|
retry.Attempts(uint(cfg.CreationAttempts)),
|
||||||
|
@ -130,7 +132,7 @@ func (r *QueryLoggingResolver) writeLog() {
|
||||||
|
|
||||||
r.writer.Write(logEntry)
|
r.writer.Write(logEntry)
|
||||||
|
|
||||||
halfCap := cap(r.logChan) / 2
|
halfCap := cap(r.logChan) / 2 // nolint:gomnd
|
||||||
|
|
||||||
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
// if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.)
|
||||||
if len(r.logChan) > halfCap {
|
if len(r.logChan) > halfCap {
|
||||||
|
|
|
@ -80,6 +80,7 @@ func (r *RewriterResolver) Resolve(request *model.Request) (*model.Response, err
|
||||||
if response == NoResponse {
|
if response == NoResponse {
|
||||||
// Inner resolver had no response, continue with the normal chain
|
// Inner resolver had no response, continue with the normal chain
|
||||||
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")
|
||||||
|
|
||||||
return r.next.Resolve(request)
|
return r.next.Resolve(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,6 +129,7 @@ func (r *RewriterResolver) rewriteDomain(domain string) (string, string) {
|
||||||
for k, v := range r.rewrite {
|
for k, v := range r.rewrite {
|
||||||
if strings.HasSuffix(domain, "."+k) {
|
if strings.HasSuffix(domain, "."+k) {
|
||||||
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
|
newDomain := strings.TrimSuffix(domain, "."+k) + "." + v
|
||||||
|
|
||||||
return newDomain, k
|
return newDomain, k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,6 +122,7 @@ var _ = Describe("RewriterResolver", func() {
|
||||||
mNext.On("Resolve", mock.Anything)
|
mNext.On("Resolve", mock.Anything)
|
||||||
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
mNext.ResolveFn = func(req *model.Request) (*model.Response, error) {
|
||||||
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||||
|
|
||||||
return mNextResponse, nil
|
return mNextResponse, nil
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -23,7 +23,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsContentType = "application/dns-message"
|
dnsContentType = "application/dns-message"
|
||||||
|
defaultTLSHandshakeTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint:gochecknoglobals
|
// nolint:gochecknoglobals
|
||||||
|
@ -66,7 +67,7 @@ func createUpstreamClient(cfg config.Upstream) upstreamClient {
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
TLSClientConfig: &tlsConfig,
|
TLSClientConfig: &tlsConfig,
|
||||||
TLSHandshakeTimeout: 5 * time.Second,
|
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
||||||
},
|
},
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
},
|
},
|
||||||
|
@ -246,6 +247,7 @@ func (r *UpstreamResolver) Resolve(request *model.Request) (response *model.Resp
|
||||||
"response_time_ms": rtt.Milliseconds(),
|
"response_time_ms": rtt.Milliseconds(),
|
||||||
}).Debugf("received response from upstream")
|
}).Debugf("received response from upstream")
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
},
|
},
|
||||||
retry.Attempts(retryAttempts),
|
retry.Attempts(retryAttempts),
|
||||||
|
|
|
@ -131,6 +131,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
||||||
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
|
response, err := util.NewMsgWithAnswer("example.com", 123, dns.Type(dns.TypeA), "123.124.122.122")
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -25,6 +25,10 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxUDPBufferSize = 65535
|
||||||
|
)
|
||||||
|
|
||||||
// Server controls the endpoints for DNS and HTTP
|
// Server controls the endpoints for DNS and HTTP
|
||||||
type Server struct {
|
type Server struct {
|
||||||
dnsServers []*dns.Server
|
dnsServers []*dns.Server
|
||||||
|
@ -213,7 +217,7 @@ func createUDPServer(address string) (*dns.Server, error) {
|
||||||
NotifyStartedFunc: func() {
|
NotifyStartedFunc: func() {
|
||||||
logger().Infof("UDP server is up and running on address %s", address)
|
logger().Infof("UDP server is up and running on address %s", address)
|
||||||
},
|
},
|
||||||
UDPSize: 65535,
|
UDPSize: maxUDPBufferSize,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,7 +308,9 @@ func (s *Server) printConfiguration() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func toMB(b uint64) uint64 {
|
func toMB(b uint64) uint64 {
|
||||||
return b / 1024 / 1024
|
const bytesInKB = 1024
|
||||||
|
|
||||||
|
return b / bytesInKB / bytesInKB
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the server
|
// Start starts the server
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/api"
|
"github.com/0xERR0R/blocky/api"
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
@ -26,6 +27,7 @@ import (
|
||||||
const (
|
const (
|
||||||
dohMessageLimit = 512
|
dohMessageLimit = 512
|
||||||
dnsContentType = "application/dns-message"
|
dnsContentType = "application/dns-message"
|
||||||
|
corsMaxAge = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) registerAPIEndpoints(router *chi.Mux) {
|
func (s *Server) registerAPIEndpoints(router *chi.Mux) {
|
||||||
|
@ -109,6 +111,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,6 +123,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
||||||
b, err := resResponse.Res.Pack()
|
b, err := resResponse.Res.Pack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "can't serialize message: ", rw)
|
logAndResponseWithError(err, "can't serialize message: ", rw)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,6 +167,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "can't read request: ", rw)
|
logAndResponseWithError(err, "can't read request: ", rw)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,6 +194,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -201,6 +207,7 @@ func (s *Server) apiQuery(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "unable to marshal response: ", rw)
|
logAndResponseWithError(err, "unable to marshal response: ", rw)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,7 +296,7 @@ func configureCorsHandler(router *chi.Mux) {
|
||||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||||
ExposedHeaders: []string{"Link"},
|
ExposedHeaders: []string{"Link"},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
MaxAge: 300,
|
MaxAge: int(corsMaxAge.Seconds()),
|
||||||
})
|
})
|
||||||
router.Use(crs.Handler)
|
router.Use(crs.Handler)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,7 @@ var _ = BeforeSuite(func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
})
|
})
|
||||||
DeferCleanup(googleMockUpstream.Close)
|
DeferCleanup(googleMockUpstream.Close)
|
||||||
|
@ -54,6 +55,7 @@ var _ = BeforeSuite(func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
})
|
})
|
||||||
DeferCleanup(fritzboxMockUpstream.Close)
|
DeferCleanup(fritzboxMockUpstream.Close)
|
||||||
|
@ -64,6 +66,7 @@ var _ = BeforeSuite(func() {
|
||||||
)
|
)
|
||||||
|
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
return response
|
return response
|
||||||
})
|
})
|
||||||
DeferCleanup(clientMockUpstream.Close)
|
DeferCleanup(clientMockUpstream.Close)
|
||||||
|
@ -153,6 +156,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
for res != nil {
|
for res != nil {
|
||||||
if t, ok := res.(*resolver.ClientNamesResolver); ok {
|
if t, ok := res.(*resolver.ClientNamesResolver); ok {
|
||||||
t.FlushCache()
|
t.FlushCache()
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if c, ok := res.(resolver.ChainedResolver); ok {
|
if c, ok := res.(resolver.ChainedResolver); ok {
|
||||||
|
@ -561,10 +565,9 @@ var _ = Describe("Running DNS server", func() {
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
|
|
||||||
errChan := make(chan error, 10)
|
errChan := make(chan error, 10)
|
||||||
|
|
||||||
// start server
|
// start server
|
||||||
go func() {
|
go server.Start(errChan)
|
||||||
server.Start(errChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
DeferCleanup(server.Stop)
|
DeferCleanup(server.Stop)
|
||||||
|
|
||||||
|
|
|
@ -193,7 +193,8 @@ func Chunks(s string, chunkSize int) []string {
|
||||||
|
|
||||||
// GenerateCacheKey return cacheKey by query type/domain
|
// GenerateCacheKey return cacheKey by query type/domain
|
||||||
func GenerateCacheKey(qType dns.Type, qName string) string {
|
func GenerateCacheKey(qType dns.Type, qName string) string {
|
||||||
b := make([]byte, 2+len(qName))
|
const qTypeLength = 2
|
||||||
|
b := make([]byte, qTypeLength+len(qName))
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(b, uint16(qType))
|
binary.BigEndian.PutUint16(b, uint16(qType))
|
||||||
copy(b[2:], strings.ToLower(qName))
|
copy(b[2:], strings.ToLower(qName))
|
||||||
|
@ -224,5 +225,6 @@ func CidrContainsIP(cidr string, ip net.IP) bool {
|
||||||
// ClientNameMatchesGroupName checks if a group with optional wildcards contains a client name
|
// ClientNameMatchesGroupName checks if a group with optional wildcards contains a client name
|
||||||
func ClientNameMatchesGroupName(group string, clientName string) bool {
|
func ClientNameMatchesGroupName(group string, clientName string) bool {
|
||||||
match, _ := filepath.Match(group, clientName)
|
match, _ := filepath.Match(group, clientName)
|
||||||
|
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue