From 01a8a402dc76994bd38e4dafbd46581df1198d69 Mon Sep 17 00:00:00 2001 From: Dimitri Herzog Date: Sun, 12 Jan 2020 18:23:35 +0100 Subject: [PATCH] initial commit --- .dockerignore | 4 + .github/workflows/ci-build.yml | 40 + .github/workflows/release.yml | 54 ++ .gitignore | 5 + .golangci.yml | 53 ++ Dockerfile | 38 + Makefile | 47 ++ config.yml | 47 ++ config/config.go | 148 ++++ config/config_test.go | 110 +++ docs/README.md | 127 +++ docs/blocky.svg | 768 ++++++++++++++++++ go.mod | 20 + go.sum | 82 ++ lists/list_cache.go | 253 ++++++ lists/list_cache_test.go | 132 +++ main.go | 88 ++ resolver/blocking_resolver.go | 220 +++++ resolver/blocking_resolver_test.go | 300 +++++++ resolver/caching_resolver.go | 120 +++ resolver/caching_resolver_test.go | 117 +++ resolver/client_names_resolver.go | 135 +++ resolver/client_names_resolver_test.go | 208 +++++ resolver/conditionall_upstream_resolver.go | 80 ++ .../conditionall_upstream_resolver_test.go | 86 ++ resolver/custom_dns_resolver.go | 94 +++ resolver/custom_dns_resolver_test.go | 101 +++ resolver/mocks.go | 85 ++ resolver/parallel_best_resolver.go | 113 +++ resolver/parallel_best_resolver_test.go | 39 + resolver/query_logging_resolver.go | 229 ++++++ resolver/query_logging_resolver_test.go | 211 +++++ resolver/resolver.go | 57 ++ resolver/upstream_resolver.go | 69 ++ resolver/upstream_resolver_test.go | 79 ++ server/server.go | 189 +++++ server/server_test.go | 323 ++++++++ testdata/config.yml | 44 + testdata/doubleclick.net.txt | 1 + testdata/heise.de.txt | 1 + testdata/www.bild.de.txt | 1 + testdata/youtube.com.txt | 1 + util/common.go | 80 ++ 43 files changed, 4999 insertions(+) create mode 100644 .dockerignore create mode 100644 .github/workflows/ci-build.yml create mode 100644 .github/workflows/release.yml create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 config.yml create mode 100644 config/config.go create mode 100644 config/config_test.go create mode 100644 docs/README.md create mode 100644 docs/blocky.svg create mode 100644 go.mod create mode 100644 go.sum create mode 100644 lists/list_cache.go create mode 100644 lists/list_cache_test.go create mode 100644 main.go create mode 100644 resolver/blocking_resolver.go create mode 100644 resolver/blocking_resolver_test.go create mode 100644 resolver/caching_resolver.go create mode 100644 resolver/caching_resolver_test.go create mode 100644 resolver/client_names_resolver.go create mode 100644 resolver/client_names_resolver_test.go create mode 100644 resolver/conditionall_upstream_resolver.go create mode 100644 resolver/conditionall_upstream_resolver_test.go create mode 100644 resolver/custom_dns_resolver.go create mode 100644 resolver/custom_dns_resolver_test.go create mode 100644 resolver/mocks.go create mode 100644 resolver/parallel_best_resolver.go create mode 100644 resolver/parallel_best_resolver_test.go create mode 100644 resolver/query_logging_resolver.go create mode 100644 resolver/query_logging_resolver_test.go create mode 100644 resolver/resolver.go create mode 100644 resolver/upstream_resolver.go create mode 100644 resolver/upstream_resolver_test.go create mode 100644 server/server.go create mode 100644 server/server_test.go create mode 100644 testdata/config.yml create mode 100644 testdata/doubleclick.net.txt create mode 100644 testdata/heise.de.txt create mode 100644 testdata/www.bild.de.txt create mode 100644 testdata/youtube.com.txt create mode 100644 util/common.go diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..7471db55 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +bin/ +.idea +.github +testdata/ \ No newline at end of file diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml new file mode 100644 index 00000000..12219b27 --- /dev/null +++ b/.github/workflows/ci-build.yml @@ -0,0 +1,40 @@ +name: CI Build +on: [push] +jobs: + + build: + name: Build + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.13 + uses: actions/setup-go@v1 + with: + go-version: 1.13 + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + + - name: Get dependencies + run: | + go get -v -t -d ./... + if [ -f Gopkg.toml ]; then + curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh + dep ensure + fi + + - name: Install golangci-lint + run: curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b $(go env GOPATH)/bin v1.21.0 + + - name: Run golangci-lint + run: make lint + + - name: Test + run: make test + + - name: Build + run: make build + + - name: Docker images + run: make docker-build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..0c626645 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,54 @@ +name: Release + +on: + push: + tags: + - v* + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - name: Set up Go 1.13 + uses: actions/setup-go@v1 + with: + go-version: 1.13 + id: go + + - uses: actions/checkout@v1 + + - name: Build + run: make build + + - name: Test + run: make test + + - name: Build multiarch binaries + run: make buildMultiArchRelease + + - name: Upload amd64 binary to release + uses: svenstaro/upload-release-action@v1-release + with: + repo_token: ${{ secrets.GITHUB_TOKEN }} + file: bin/blocky_amd64 + asset_name: blocky_amd64 + tag: ${{ github.ref }} + overwrite: true + + - name: Upload arm32v6 binary to release + uses: svenstaro/upload-release-action@v1-release + with: + repo_token: ${{ secrets.GITHUB_TOKEN }} + file: bin/blocky_arm32v6 + asset_name: blocky_arm32v6 + tag: ${{ github.ref }} + overwrite: true + + - name: Build the Docker image and push + run: | + mkdir -p ~/.docker && echo "{\"experimental\": \"enabled\"}" > ~/.docker/config.json + echo ${{ secrets.DOCKER_PASSWORD }} | docker login -u ${{ secrets.DOCKER_USERNAME }} --password-stdin + make docker-build + make dockerManifestAndPush \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..6d0e921e --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea/ +*.iml +bin/ +config.yml +todo.txt \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..bfe660d1 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,53 @@ +linters: + enable: + - govet + - errcheck + - staticcheck + - unused + - gosimple + - structcheck + - varcheck + - ineffassign + - deadcode + - typecheck + - bodyclose + - golint + - stylecheck + - gosec + - interfacer + - unconvert + - dupl + - goconst + - gocyclo + - gocognit + - gofmt + - goimports + - maligned + - depguard + - misspell + - lll + - unparam + - dogsled + - funlen + - gochecknoglobals + - gochecknoinits + - gocritic + - godox + - nakedret + - prealloc + - whitespace + - wsl + + disable-all: false + presets: + - bugs + - unused + fast: false + +issues: + exclude-rules: + # Exclude some linters from running on tests files. + - path: _test\.go + linters: + - gochecknoglobals + - dupl diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..0f8e8121 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,38 @@ +# build stage +FROM golang:alpine AS build-env +RUN apk add --no-cache \ + git \ + make \ + gcc \ + libc-dev \ + tzdata \ + zip \ + ca-certificates + +ENV GO111MODULE=on \ + CGO_ENABLED=0 + +WORKDIR /src + +COPY go.mod . +COPY go.sum . +RUN go mod download + +# add source +ADD . . + +ARG opts +RUN env ${opts} make build + +# final stage +FROM scratch +COPY --from=build-env /src/bin/blocky /app/blocky + +# the timezone data: +COPY --from=build-env /usr/share/zoneinfo /usr/share/zoneinfo +# the tls certificates: +COPY --from=build-env /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ + +WORKDIR /app + +ENTRYPOINT ["/app/blocky"] \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..53e457ad --- /dev/null +++ b/Makefile @@ -0,0 +1,47 @@ +.PHONY: all clean build test lint run buildMultiArchRelease docker-build dockerManifestAndPush help +.DEFAULT_GOAL := help + +VERSION := $(shell git describe --always --tags) +BUILD_TIME=$(shell date '+%Y%m%d-%H%M%S') +DOCKER_IMAGE_NAME="spx01/blocky" +BINARY_NAME=blocky +BIN_OUT_DIR=bin + +all: test lint build ## Build binary (with tests) + +clean: ## cleans output directory + $(shell rm -rf $(BIN_OUT_DIR)/*) + +build: ## Build binary + go build -v -ldflags="-w -s -X main.version=${VERSION} -X main.buildTime=${BUILD_TIME}" -o $(BIN_OUT_DIR)/$(BINARY_NAME)$(BINARY_SUFFIX) + +test: ## run tests + go test -v -cover ./... + +lint: ## run golangcli-lint checks + $(shell go env GOPATH)/bin/golangci-lint run + +run: build ## Build and run binary + ./$(BIN_OUT_DIR)/$(BINARY_NAME) + +buildMultiArchRelease: ## builds binary for multiple archs + $(MAKE) build GOARCH=arm GOARM=6 BINARY_SUFFIX=_arm32v6 + $(MAKE) build GOARCH=amd64 BINARY_SUFFIX=_amd64 + +docker-build: ## Build multi arch docker images + docker build --build-arg opts="GOARCH=arm GOARM=6" --pull --tag ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 . + docker build --build-arg opts="GOARCH=amd64" --pull --tag ${DOCKER_IMAGE_NAME}:${VERSION}-amd64 . + +dockerManifestAndPush: ## create manifest for multi arch images and push to docker hub + docker push ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 + docker push ${DOCKER_IMAGE_NAME}:${VERSION}-amd64 + + docker manifest create ${DOCKER_IMAGE_NAME}:${VERSION} ${DOCKER_IMAGE_NAME}:${VERSION}-amd64 ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 + docker manifest annotate ${DOCKER_IMAGE_NAME}:${VERSION} ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 --os linux --arch arm + docker manifest push ${DOCKER_IMAGE_NAME}:${VERSION} --purge + docker manifest create ${DOCKER_IMAGE_NAME}:latest ${DOCKER_IMAGE_NAME}:${VERSION}-amd64 ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 + docker manifest annotate ${DOCKER_IMAGE_NAME}:latest ${DOCKER_IMAGE_NAME}:${VERSION}-arm32v6 --os linux --arch arm + docker manifest push ${DOCKER_IMAGE_NAME}:latest --purge + +help: ## Shows help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' \ No newline at end of file diff --git a/config.yml b/config.yml new file mode 100644 index 00000000..db724587 --- /dev/null +++ b/config.yml @@ -0,0 +1,47 @@ +upstream: + externalResolvers: + # - udp:8.8.8.8 + # - udp:8.8.4.4 + - tcp-tls:1.1.1.1:853 + - tcp-tls:1.0.0.1:853 +customDNS: + mapping: + spx.duckdns.org: 192.168.178.3 +conditional: + mapping: + fritz.box: udp:192.168.178.1 +blocking: + blackLists: + ads: + - https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt + - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts + - https://mirror1.malwaredomains.com/files/justdomains + - http://sysctl.org/cameleon/hosts + - https://zeustracker.abuse.ch/blocklist.php?download=domainblocklist + - https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt + special: + - https://hosts-file.net/ad_servers.txt + whiteLists: + ads: + - whitelist.txt + clientGroupsBlock: + default: + - ads + - special + Laptop-D.fritz.box: + - ads + blockType: zeroIp + +clientLookup: + upstream: udp:192.168.178.1 + singleNameOrder: + - 2 + - 1 + +#queryLog: +# dir: /tmp +# perClient: true +# logRetentionDays: 7 + +port: 55555 +logLevel: info \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 00000000..6b7dbdb7 --- /dev/null +++ b/config/config.go @@ -0,0 +1,148 @@ +package config + +import ( + "fmt" + "io/ioutil" + "log" + "net" + "reflect" + "strconv" + "strings" + + "gopkg.in/yaml.v2" +) + +// nolint:gochecknoglobals +var netDefaultPort = map[string]uint16{ + "udp": 53, + "tcp": 53, + "tcp-tls": 853, +} + +// Upstream is the definition of external DNS server +type Upstream struct { + Net string + Host string + Port uint16 +} + +func (u *Upstream) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + upstream, err := parseUpstream(s) + if err != nil { + return err + } + + *u = upstream + + return nil +} + +// parseUpstream creates new Upstream from passed string in format net:host:port +func parseUpstream(upstream string) (result Upstream, err error) { + if strings.Trim(upstream, " ") == "" { + return Upstream{}, nil + } + + parts := strings.Split(upstream, ":") + + if len(parts) < 2 || len(parts) > 3 { + err = fmt.Errorf("wrong configuration, couldn't parse input '%s', please enter net:host[:port]", upstream) + return + } + + net := strings.TrimSpace(parts[0]) + + if _, ok := netDefaultPort[net]; !ok { + err = fmt.Errorf("wrong configuration, couldn't parse net '%s', please user one of %s", + net, reflect.ValueOf(netDefaultPort).MapKeys()) + return + } + + var port uint16 + + host := strings.TrimSpace(parts[1]) + + if len(parts) == 3 { + var p int + p, err = strconv.Atoi(strings.TrimSpace(parts[2])) + + if err != nil { + err = fmt.Errorf("can't convert port to number %v", err) + return + } + + if p < 1 || p > 65535 { + err = fmt.Errorf("invalid port %d", p) + return + } + + port = uint16(p) + } else { + port = netDefaultPort[net] + } + + return Upstream{Net: net, Host: host, Port: port}, nil +} + +// main configuration +type Config struct { + Upstream UpstreamConfig `yaml:"upstream"` + CustomDNS CustomDNSConfig `yaml:"customDNS"` + Conditional ConditionalUpstreamConfig `yaml:"conditional"` + Blocking BlockingConfig `yaml:"blocking"` + ClientLookup ClientLookupConfig `yaml:"clientLookup"` + QueryLog QueryLogConfig `yaml:"queryLog"` + Port uint16 + LogLevel string `yaml:"logLevel"` +} + +type UpstreamConfig struct { + ExternalResolvers []Upstream `yaml:"externalResolvers"` +} + +type CustomDNSConfig struct { + Mapping map[string]net.IP `yaml:"mapping"` +} + +type ConditionalUpstreamConfig struct { + Mapping map[string]Upstream `yaml:"mapping"` +} + +type BlockingConfig struct { + BlackLists map[string][]string `yaml:"blackLists"` + WhiteLists map[string][]string `yaml:"whiteLists"` + ClientGroupsBlock map[string][]string `yaml:"clientGroupsBlock"` + BlockType string `yaml:"blockType"` +} + +type ClientLookupConfig struct { + Upstream Upstream `yaml:"upstream"` + SingleNameOrder []uint `yaml:"singleNameOrder"` +} + +type QueryLogConfig struct { + Dir string `yaml:"dir"` + PerClient bool `yaml:"perClient"` + LogRetentionDays uint64 `yaml:"logRetentionDays"` +} + +func NewConfig() Config { + cfg := Config{} + data, err := ioutil.ReadFile("config.yml") + + if err != nil { + log.Fatal("Can't read config file: ", err) + } + + err = yaml.UnmarshalStrict(data, &cfg) + if err != nil { + log.Fatal("wrong file structure: ", err) + } + + return cfg +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 00000000..493e4bed --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,110 @@ +package config + +import ( + "net" + "os" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewConfig(t *testing.T) { + err := os.Chdir("../testdata") + assert.NoError(t, err) + + cfg := NewConfig() + + assert.Equal(t, uint16(55555), cfg.Port) + assert.Len(t, cfg.Upstream.ExternalResolvers, 3) + assert.Equal(t, "8.8.8.8", cfg.Upstream.ExternalResolvers[0].Host) + assert.Equal(t, "8.8.4.4", cfg.Upstream.ExternalResolvers[1].Host) + assert.Equal(t, "1.1.1.1", cfg.Upstream.ExternalResolvers[2].Host) + assert.Len(t, cfg.CustomDNS.Mapping, 1) + assert.Equal(t, net.ParseIP("192.168.178.3"), cfg.CustomDNS.Mapping["my.duckdns.org"]) + assert.Len(t, cfg.Conditional.Mapping, 1) + assert.Equal(t, "192.168.178.1", cfg.ClientLookup.Upstream.Host) + assert.Equal(t, []uint{2, 1}, cfg.ClientLookup.SingleNameOrder) + assert.Len(t, cfg.Blocking.BlackLists, 2) + assert.Len(t, cfg.Blocking.WhiteLists, 1) + assert.Len(t, cfg.Blocking.ClientGroupsBlock, 2) +} + +var tests = []struct { + name string + args string + wantResult Upstream + wantErr bool +}{ + { + name: "udpWithPort", + args: "udp:4.4.4.4:531", + wantResult: Upstream{Net: "udp", Host: "4.4.4.4", Port: 531}, + }, + { + name: "udpDefault", + args: "udp:4.4.4.4", + wantResult: Upstream{Net: "udp", Host: "4.4.4.4", Port: 53}, + }, + { + name: "tcpWithPort", + args: "tcp:4.4.4.4:4711", + wantResult: Upstream{Net: "tcp", Host: "4.4.4.4", Port: 4711}, + }, + { + name: "tcpDefault", + args: "tcp:4.4.4.4", + wantResult: Upstream{Net: "tcp", Host: "4.4.4.4", Port: 53}, + }, + { + name: "tcpTlsDefault", + args: "tcp-tls:4.4.4.4", + wantResult: Upstream{Net: "tcp-tls", Host: "4.4.4.4", Port: 853}, + }, + { + name: "empty", + args: "", + wantResult: Upstream{}, + }, + { + name: "negativePort", + args: "tcp:4.4.4.4:-1", + wantErr: true, + }, + { + name: "invalidPort", + args: "tcp:4.4.4.4:65536", + wantErr: true, + }, + { + name: "notNumericPort", + args: "tcp:4.4.4.4:A53", + wantErr: true, + }, + { + name: "wrongProtocol", + args: "bla:4.4.4.4:53", + wantErr: true, + }, + { + name: "wrongFormat", + args: "tcp-4.4.4.4", + wantErr: true, + }, +} + +func Test_parseUpstream(t *testing.T) { + for _, tt := range tests { + rr := tt + t.Run(tt.name, func(t *testing.T) { + gotResult, err := parseUpstream(rr.args) + if (err != nil) != rr.wantErr { + t.Errorf("parseUpstream() error = %v, wantErr %v", err, rr.wantErr) + return + } + if !reflect.DeepEqual(gotResult, rr.wantResult) { + t.Errorf("parseUpstream() = %v, want %v", gotResult, rr.wantResult) + } + }) + } +} diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..ed3ecab0 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,127 @@ +![](https://github.com/0xERR0R/blocky/workflows/CI%20Build/badge.svg) ![](https://github.com/0xERR0R/blocky/workflows/Docker%20Image%20Release/badge.svg) + +

+ +

+ +# Blocky +Blocky is a DNS proxy for local network written in Go with following features: +- Blocking of DNS queries with external lists (Ad-block) with whitelisting + - Definition of black and white lists per client group (Kids, Smart home devices etc) -> for example: you can block some domains for you Kids and allow your network camera only domains from a whitelist + - periodical reload of external black and white lists +- Caching of DNS answers for queries -> improves DNS resolution speed and reduces amount of external DNS queries +- Custom DNS resolution for certain domain names +- Supports UDP, TCP and TCP over TLS DNS resolvers +- Delegates DNS query to 2 external resolver from a list of configured resolvers, uses the answer from the fastest one -> improves you privacy and resolution time +- Logging of all DNS queries per day / per client in a text file +- Simple configuration in a single file +- Only one binary in docker container, low memory footprint +- Runs fine on raspbery pi + +## Installation and configuration +Create `config.yml` file with your configuration: +```yml +upstream: + # these external DNS resolvers will be used. Blocky picks 2 random resolvers from the list for each query + # format for resolver: net:host:port. net could be tcp, udp or tcp-tls. If port is empty, default port will be used (53 for udp and tcp, 853 for tcp-tls) + externalResolvers: + - udp:8.8.8.8 + - udp:8.8.4.4 + - udp:1.1.1.1 + - tcp-tls:1.0.0.1:853 + +# optional: custom IP address for domain name (with all sub-domains) +# example: query "printer.lan" or "my.printer.lan" will return 192.168.178.3 +customDNS: + mapping: + printer.lan: 192.168.178.3 + +# optional: definition, which DNS resolver should be used for queries to the domain (with all sub-domains). +# Example: Query client.fritz.box will ask DNS server 192.168.178.1. This is necessary for local network, to resolve clients by host name +conditional: + mapping: + fritz.box: udp:192.168.178.1 + +# optional: use black and white lists to block queries (for example ads, trackers, adult pages etc.) +blocking: + # definition of blacklist groups. Can be external link (http/https) or local file + blackLists: + ads: + - https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt + - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts + - https://mirror1.malwaredomains.com/files/justdomains + - http://sysctl.org/cameleon/hosts + - https://zeustracker.abuse.ch/blocklist.php?download=domainblocklist + - https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt + special: + - https://hosts-file.net/ad_servers.txt + # definition of whitelist groups. Attention: if the same group has black and whitelists, whitelists will be used to disable particular blacklist entries. If a group has only whitelist entries -> this means only domains from this list are allowed, all other domains will be blocked + whiteLists: + ads: + - whitelist.txt + # definition: which groups should be appied for which client + clientGroupsBlock: + # default will be used, if no special definition for a client name exists + default: + - ads + - special + # use client name or ip address + laptop.fritz.box: + - ads + # which response will be sent, if query is blocked: + # zeroIp: 0.0.0.0 will be returned (default) + # nxDomain: return NXDOMAIN as return code + blockType: zeroIp + +#optional: configuration of client name resolution +clientLookup: + # this DNS resolver will be used to perform reverse DNS lookup (typically local router) + upstream: udp:192.168.178.1 + # optional: some routers return multiple names for client (host name and user defined name). Define which single name should be used. + # Example: take second name if present, if not take first name + singleNameOrder: + - 2 + - 1 + +# optional: write query information (question, answer, client, duration etc) to daily csv file +queryLog: + # directory (should be mounted as volume in docker) + dir: /logs + # if true, write one file per client. Writes all queries to single file otherwise + perClient: true + # if > 0, deletes log files which are older than ... days + logRetentionDays: 7 + +# Port, should be 53 (UDP and TCP) +port: 53 +# Log level (one from debug, info, warn, error) +logLevel: info +``` + +### Run with docker +Start docker container with following `docker-compose.yml` file: +```yml +version: "2.1" +services: + blocky: + image: spx01/blocky + container_name: blocky + restart: unless-stopped + ports: + - "53:53/tcp" + - "53:53/udp" + environment: + - TZ=Europe/Berlin + volumes: + # config file + - ./config.yml:/app/config.yml + # write query logs in this directory. You can also use a volume + - ./logs:/logs +``` + +### Run standalone +Download binary file for your architecture, put it in one directory with config file. Please be aware, you must run the binary with root privileges if you want to use port 53 or 953. + +### Additional information +To print runtime configuration and statistics, you can send SIGUSR1 signal to running process: +`kill -s USR1 ` or `docker kill -s SIGUSR1 blocky` for docker setup diff --git a/docs/blocky.svg b/docs/blocky.svg new file mode 100644 index 00000000..7bb10e12 --- /dev/null +++ b/docs/blocky.svg @@ -0,0 +1,768 @@ + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..77bd39c0 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module blocky + +go 1.13 + +require ( + github.com/golang/protobuf v1.3.2 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.4 // indirect + github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect + github.com/miekg/dns v1.1.22 + github.com/onsi/ginkgo v1.11.0 // indirect + github.com/onsi/gomega v1.8.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/testify v1.4.0 + github.com/x-cray/logrus-prefixed-formatter v0.5.2 + golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v2 v2.2.4 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..5b43982f --- /dev/null +++ b/go.sum @@ -0,0 +1,82 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/miekg/dns v1.1.22 h1:Jm64b3bO9kP43ddLjL2EY3Io6bmy1qGb9Xxz6TqS6rc= +github.com/miekg/dns v1.1.22/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.11.0 h1:JAKSXpt1YjtLA7YpPiqO9ss6sNXEsPfSGdwN0UHqzrw= +github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.8.1 h1:C5Dqfs/LeauYDX0jJXIe2SWmwCbGzx9yF8C8xy3Lh34= +github.com/onsi/gomega v1.8.1/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392 h1:ACG4HJsFiNMf47Y4PeRoebLNy/2lXT9EtprMuTFWt1M= +golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190907020128-2ca718005c18 h1:xFbv3LvlvQAmbNJFCBKRv1Ccvnh9FVsW0FX2kTWWowE= +golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/lists/list_cache.go b/lists/list_cache.go new file mode 100644 index 00000000..76519953 --- /dev/null +++ b/lists/list_cache.go @@ -0,0 +1,253 @@ +package lists + +import ( + "bufio" + "fmt" + "io" + "net/http" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +const ( + timeout = 30 * time.Second + listUpdatePeriod = 4 * time.Hour +) + +type Matcher interface { + // matches passed domain name against cached list entries + Match(domain string, groupsToCheck []string) (found bool, group string) + + // returns current configuration and stats + Configuration() []string +} + +type ListCache struct { + groupCaches map[string][]string + lock sync.RWMutex + + groupToLinks map[string][]string +} + +func (b *ListCache) Configuration() (result []string) { + result = append(result, "group links:") + for group, links := range b.groupToLinks { + result = append(result, fmt.Sprintf(" %s:", group)) + for _, link := range links { + result = append(result, fmt.Sprintf(" - %s", link)) + } + } + + result = append(result, "group caches:") + + var total int + + for group, cache := range b.groupCaches { + result = append(result, fmt.Sprintf(" %s: %d entries", group, len(cache))) + total += len(cache) + } + + result = append(result, fmt.Sprintf(" TOTAL: %d entries", total)) + + return +} + +// removes duplicates +func unique(in []string) []string { + keys := make(map[string]bool) + + var list []string + + for _, entry := range in { + if _, value := keys[entry]; !value { + keys[entry] = true + + list = append(list, entry) + } + } + + return list +} + +func NewListCache(groupToLinks map[string][]string) *ListCache { + groupCaches := make(map[string][]string) + + b := &ListCache{ + groupToLinks: groupToLinks, + groupCaches: groupCaches, + } + b.refresh() + + go periodicUpdate(b) + + return b +} + +// triggers periodical refresh (and download) of list entries +func periodicUpdate(cache *ListCache) { + ticker := time.NewTicker(listUpdatePeriod) + defer ticker.Stop() + + for { + <-ticker.C + cache.refresh() + } +} + +func logger() *logrus.Entry { + return logrus.WithField("prefix", "list_cache") +} + +// downloads and reads files with domain names and creates cache for them +func createCacheForGroup(links []string) []string { + cache := make([]string, 0) + + var wg sync.WaitGroup + + c := make(chan []string, len(links)) + + for _, link := range links { + wg.Add(1) + + go processFile(link, c, &wg) + } + + wg.Wait() + +Loop: + for { + select { + case res := <-c: + cache = append(cache, res...) + default: + close(c) + break Loop + } + } + + cache = unique(cache) + sort.Strings(cache) + + return cache +} + +func (b *ListCache) Match(domain string, groupsToCheck []string) (found bool, group string) { + b.lock.RLock() + defer b.lock.RUnlock() + + for _, g := range groupsToCheck { + if contains(domain, b.groupCaches[g]) { + return true, g + } + } + + return false, "" +} + +func contains(domain string, cache []string) bool { + idx := sort.SearchStrings(cache, domain) + if idx < len(cache) { + return cache[idx] == strings.ToLower(domain) + } + + return false +} + +func (b *ListCache) refresh() { + b.lock.Lock() + defer b.lock.Unlock() + + for group, links := range b.groupToLinks { + b.groupCaches[group] = createCacheForGroup(links) + + logger().WithFields(logrus.Fields{ + "group": group, + "total_count": len(b.groupCaches[group]), + }).Info("group import finished") + } +} + +func downloadFile(link string) (io.ReadCloser, error) { + client := http.Client{ + Timeout: timeout, + } + + logger().WithField("link", link).Info("starting download") + + //nolint:bodyclose + resp, err := client.Get(link) + + if err != nil { + return nil, err + } + + return resp.Body, nil +} + +func readFile(file string) (io.ReadCloser, error) { + logger().WithField("file", file).Info("starting processing of file") + file = strings.TrimPrefix(file, "file://") + + return os.Open(file) +} + +// downloads file (or reads local file) and writes file content as string array in the channel +func processFile(link string, ch chan<- []string, wg *sync.WaitGroup) { + defer wg.Done() + + result := make([]string, 0) + + var r io.ReadCloser + + var err error + + if strings.HasPrefix(link, "http") { + r, err = downloadFile(link) + } else { + r, err = readFile(link) + } + + if err != nil { + logger().Warn("error during file processing: ", err) + return + } + defer r.Close() + + var count int + + scanner := bufio.NewScanner(r) + + for scanner.Scan() { + line := scanner.Text() + // skip comments + if !strings.HasPrefix(line, "#") { + result = append(result, processLine(line)) + count++ + } + } + + if err := scanner.Err(); err != nil { + logger().Warn("can't parse file: ", err) + } else { + logger().WithFields(logrus.Fields{ + "source": link, + "count": count, + }).Info("file imported") + } + ch <- result +} + +// return only first column (see hosts format) +func processLine(line string) string { + parts := strings.Fields(line) + if len(parts) > 0 { + return parts[len(parts)-1] + } + + return "" +} diff --git a/lists/list_cache_test.go b/lists/list_cache_test.go new file mode 100644 index 00000000..76b6d54d --- /dev/null +++ b/lists/list_cache_test.go @@ -0,0 +1,132 @@ +package lists + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func Test_NoMatch_With_Empty_List(t *testing.T) { + file1 := tempFile("#empty file") + defer os.Remove(file1.Name()) + + lists := map[string][]string{ + "gr1": {file1.Name()}, + } + + sut := NewListCache(lists) + + found, group := sut.Match("google.com", []string{"gr1"}) + assert.Equal(t, false, found) + assert.Equal(t, "", group) +} + +func Test_Match_Download_Multiple_Groups(t *testing.T) { + server1 := testServer("blocked1.com\nblocked1a.com") + defer server1.Close() + + server2 := testServer("blocked2.com") + defer server2.Close() + + server3 := testServer("blocked3.com\nblocked1a.com") + defer server3.Close() + + lists := map[string][]string{ + "gr1": {server1.URL, server2.URL}, + "gr2": {server3.URL}, + } + + sut := NewListCache(lists) + + found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr1", group) + + found, group = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr1", group) + + found, group = sut.Match("blocked1a.com", []string{"gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr2", group) +} + +func Test_Match_Download_No_Group(t *testing.T) { + server1 := testServer("blocked1.com\nblocked1a.com") + defer server1.Close() + + server2 := testServer("blocked2.com") + defer server2.Close() + + server3 := testServer("blocked3.com\nblocked1a.com") + defer server3.Close() + + lists := map[string][]string{ + "gr1": {server1.URL, server2.URL}, + "gr2": {server3.URL}, + } + + sut := NewListCache(lists) + + found, group := sut.Match("blocked1.com", []string{}) + assert.Equal(t, false, found) + assert.Equal(t, "", group) +} + +func Test_Match_Files_Multiple_Groups(t *testing.T) { + file1 := tempFile("blocked1.com\nblocked1a.com") + defer os.Remove(file1.Name()) + + file2 := tempFile("blocked2.com") + defer os.Remove(file2.Name()) + + file3 := tempFile("blocked3.com\nblocked1a.com") + defer os.Remove(file3.Name()) + + lists := map[string][]string{ + "gr1": {file1.Name(), file2.Name()}, + "gr2": {"file://" + file3.Name()}, + } + + sut := NewListCache(lists) + + found, group := sut.Match("blocked1.com", []string{"gr1", "gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr1", group) + + found, group = sut.Match("blocked1a.com", []string{"gr1", "gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr1", group) + + found, group = sut.Match("blocked1a.com", []string{"gr2"}) + assert.Equal(t, true, found) + assert.Equal(t, "gr2", group) +} + +func tempFile(data string) *os.File { + f, err := ioutil.TempFile("", "prefix") + if err != nil { + log.Fatal(err) + } + + _, err = f.WriteString(data) + if err != nil { + log.Fatal(err) + } + + return f +} + +func testServer(data string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + _, err := rw.Write([]byte(data)) + if err != nil { + log.Fatal("can't write to buffer:", err) + } + })) +} diff --git a/main.go b/main.go new file mode 100644 index 00000000..02b9c5aa --- /dev/null +++ b/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "blocky/config" + "blocky/server" + "os" + "os/signal" + "syscall" + + prefixed "github.com/x-cray/logrus-prefixed-formatter" + + "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" +) + +//nolint:gochecknoglobals +var version = "undefined" + +//nolint:gochecknoglobals +var buildTime = "undefined" + +func main() { + cfg := config.NewConfig() + configureLog(&cfg) + + printBanner() + + signals := make(chan os.Signal) + done := make(chan bool) + + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + server, err := server.NewServer(&cfg) + if err != nil { + log.Fatal("cant start server ", err) + } + + server.Start() + + go func() { + <-signals + log.Infof("Terminating...") + server.Stop() + done <- true + }() + + <-done +} + +func configureLog(cfg *config.Config) { + if level, err := log.ParseLevel(cfg.LogLevel); err != nil { + log.Fatalf("invalid log level %s %v", cfg.LogLevel, err) + } else { + log.SetLevel(level) + } + + logFormatter := &prefixed.TextFormatter{ + TimestampFormat: "2006-01-02 15:04:05", + FullTimestamp: true, + ForceFormatting: true, + ForceColors: true, + QuoteEmptyFields: true} + + logFormatter.SetColorScheme(&prefixed.ColorScheme{ + PrefixStyle: "blue+b", + TimestampStyle: "white+h", + }) + + logrus.SetFormatter(logFormatter) +} + +func printBanner() { + log.Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/") + log.Info("_/ _/") + log.Info("_/ _/") + log.Info("_/ _/ _/ _/ _/") + log.Info("_/ _/_/_/ _/ _/_/ _/_/_/ _/ _/ _/ _/ _/") + log.Info("_/ _/ _/ _/ _/ _/ _/ _/_/ _/ _/ _/") + log.Info("_/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/ _/") + log.Info("_/ _/_/_/ _/ _/_/ _/_/_/ _/ _/ _/_/_/ _/") + log.Info("_/ _/ _/") + log.Info("_/ _/_/ _/") + log.Info("_/ _/") + log.Info("_/ _/") + log.Infof("_/ Version: %-18s Build time: %-18s _/", version, buildTime) + log.Info("_/ _/") + log.Info("_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/_/") +} diff --git a/resolver/blocking_resolver.go b/resolver/blocking_resolver.go new file mode 100644 index 00000000..edc7f504 --- /dev/null +++ b/resolver/blocking_resolver.go @@ -0,0 +1,220 @@ +package resolver + +import ( + "blocky/config" + "blocky/lists" + "blocky/util" + "fmt" + "net" + "reflect" + "sort" + "strings" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" +) + +const ( + BlockTTL = 6 * 60 * 60 +) + +type BlockType uint8 + +const ( + ZeroIP BlockType = iota + NxDomain +) + +func (b BlockType) String() string { + return [...]string{"ZeroIP", "NxDomain"}[b] +} + +// nolint:gochecknoglobals +var typeToZeroIP = map[uint16]net.IP{ + dns.TypeA: net.IPv4zero, + dns.TypeAAAA: net.IPv6zero, +} + +func resolveBlockType(cfg config.BlockingConfig) BlockType { + cfgBlockType := strings.TrimSpace(strings.ToUpper(cfg.BlockType)) + if cfgBlockType == "" || cfgBlockType == "ZEROIP" { + return ZeroIP + } + + if cfgBlockType == "NXDOMAIN" { + return NxDomain + } + + logrus.Fatalf("unknown blockType, please use one of: ZeroIP, NxDomain") + + return ZeroIP +} + +// checks request's question (domain name) against black and white lists +type BlockingResolver struct { + NextResolver + blacklistMatcher lists.Matcher + whitelistMatcher lists.Matcher + clientGroupsBlock map[string][]string + blockType BlockType + whitelistOnlyGroups []string +} + +func NewBlockingResolver(cfg config.BlockingConfig) ChainedResolver { + bt := resolveBlockType(cfg) + blacklistMatcher := lists.NewListCache(cfg.BlackLists) + whitelistMatcher := lists.NewListCache(cfg.WhiteLists) + whitelistOnlyGroups := determineWhitelistOnlyGroups(&cfg) + + return &BlockingResolver{ + blockType: bt, + clientGroupsBlock: cfg.ClientGroupsBlock, + blacklistMatcher: blacklistMatcher, + whitelistMatcher: whitelistMatcher, + whitelistOnlyGroups: whitelistOnlyGroups, + } +} + +// returns groups, which have only whitelist entries +func determineWhitelistOnlyGroups(cfg *config.BlockingConfig) (result []string) { + for g, links := range cfg.WhiteLists { + if len(links) > 0 { + if _, found := cfg.BlackLists[g]; !found { + result = append(result, g) + } + } + } + + sort.Strings(result) + + return +} + +// sets answer and/or return code for DNS response, if request should be blocked +func (r *BlockingResolver) handleBlocked(question dns.Question, response *dns.Msg) (*dns.Msg, error) { + switch r.blockType { + case ZeroIP: + rr, err := util.CreateAnswerFromQuestion(question, typeToZeroIP[question.Qtype], BlockTTL) + if err != nil { + return nil, err + } + + response.Answer = append(response.Answer, rr) + + case NxDomain: + response.Rcode = dns.RcodeNameError + } + + return response, nil +} + +func (r *BlockingResolver) Configuration() (result []string) { + if len(r.clientGroupsBlock) > 0 { + result = append(result, "clientGroupsBlock") + for key, val := range r.clientGroupsBlock { + result = append(result, fmt.Sprintf(" %s = \"%s\"", key, strings.Join(val, ";"))) + } + + result = append(result, fmt.Sprintf("blockType = \"%s\"", r.blockType)) + + result = append(result, "blacklist:") + for _, c := range r.blacklistMatcher.Configuration() { + result = append(result, fmt.Sprintf(" %s", c)) + } + + result = append(result, "whitelist:") + for _, c := range r.whitelistMatcher.Configuration() { + result = append(result, fmt.Sprintf(" %s", c)) + } + } else { + result = []string{"deactivated"} + } + + return +} + +func (r *BlockingResolver) Resolve(request *Request) (*Response, error) { + logger := withPrefix(request.Log, "blacklist_resolver") + groupsToCheck := r.groupsToCheckForClient(request) + + if len(groupsToCheck) > 0 { + logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request") + + for _, question := range request.Req.Question { + domain := util.ExtractDomain(question) + logger := logger.WithField("domain", domain) + whitelistOnlyAlowed := reflect.DeepEqual(groupsToCheck, r.whitelistOnlyGroups) + + if whitelisted, group := r.matches(groupsToCheck, r.whitelistMatcher, domain); whitelisted { + logger.WithField("group", group).Debugf("domain is whitelisted") + } else { + if whitelistOnlyAlowed { + logger.WithField("client_groups", groupsToCheck).Debug("white list only for client group(s), blocking...") + response := new(dns.Msg) + response.SetReply(request.Req) + resp, err := r.handleBlocked(question, response) + + return &Response{Res: resp, Reason: fmt.Sprintf("BLOCKED (WHITELIST ONLY)")}, err + } + if blocked, group := r.matches(groupsToCheck, r.blacklistMatcher, domain); blocked { + logger.WithField("group", group).Debug("domain is blocked") + + response := new(dns.Msg) + response.SetReply(request.Req) + resp, err := r.handleBlocked(question, response) + + return &Response{Res: resp, Reason: fmt.Sprintf("BLOCKED (%s)", group)}, err + } + } + } + } + + logger.WithField("next_resolver", r.next).Trace("go to next resolver") + + return r.next.Resolve(request) +} + +// returns groups which should be checked for client's request +func (r *BlockingResolver) groupsToCheckForClient(request *Request) (groups []string) { + // try client names + for _, cName := range request.ClientNames { + groupsByName, found := r.clientGroupsBlock[cName] + if found { + groups = append(groups, groupsByName...) + } + } + + // try IP + groupsByIP, found := r.clientGroupsBlock[request.ClientIP.String()] + + if found { + groups = append(groups, groupsByIP...) + } + + if len(groups) == 0 { + if !found { + // return default + groups = r.clientGroupsBlock["default"] + } + } + + sort.Strings(groups) + + return +} + +func (r *BlockingResolver) matches(groupsToCheck []string, m lists.Matcher, + domain string) (blocked bool, group string) { + if len(groupsToCheck) > 0 { + found, group := m.Match(domain, groupsToCheck) + if found { + return true, group + } + } + + return false, "" +} + +func (r BlockingResolver) String() string { + return fmt.Sprintf("blacklist resolver") +} diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go new file mode 100644 index 00000000..a8abc5a6 --- /dev/null +++ b/resolver/blocking_resolver_test.go @@ -0,0 +1,300 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "net" + "testing" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var clientBlock = map[string][]string{ + "default": {"gr0"}, + "client1": {"gr1", "gr2"}, + "altName": {"gr4"}, + "192.168.178.55": {"gr3"}, +} + +type MatcherMock struct { + mock.Mock +} + +func (b *MatcherMock) Configuration() (result []string) { + return +} + +func (b *MatcherMock) Match(domain string, groupsToCheck []string) (found bool, group string) { + args := b.Called(domain, groupsToCheck) + return args.Bool(0), args.String(1) +} + +func Test_Resolve_ClientName_IpZero(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr1", "gr2", "gr3"}).Return(true, "gr1") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + } + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + + // A + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"client1"}, + ClientIP: net.ParseIP("192.168.178.55"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "blocked1.com. 21600 IN A 0.0.0.0", resp.Res.Answer[0].String()) + b.AssertExpectations(t) + + // AAAA + req = util.NewMsgWithQuestion("blocked1.com.", dns.TypeAAAA) + resp, err = sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"client1"}, + ClientIP: net.ParseIP("192.168.178.55"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "blocked1.com. 21600 IN AAAA ::", resp.Res.Answer[0].String()) + b.AssertExpectations(t) +} + +func Test_Resolve_ClientIp_A_IpZero(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr3"}).Return(true, "gr1") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + } + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.55"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "blocked1.com. 21600 IN A 0.0.0.0", resp.Res.Answer[0].String()) + b.AssertExpectations(t) +} + +func Test_Resolve_ClientWith2Names_A_IpZero(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr1", "gr2", "gr3", "gr4"}).Return(true, "gr1") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + } + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"client1", "altName"}, + ClientIP: net.ParseIP("192.168.178.55"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "blocked1.com. 21600 IN A 0.0.0.0", resp.Res.Answer[0].String()) + + b.AssertExpectations(t) +} + +func Test_Resolve_Default_A_IpZero(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr0"}).Return(true, "gr1") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + } + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "blocked1.com. 21600 IN A 0.0.0.0", resp.Res.Answer[0].String()) + b.AssertExpectations(t) +} + +func Test_Resolve_Default_Block_With_Whitelist(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr0"}).Return(true, "gr") + + w := &MatcherMock{} + w.On("Match", "blocked1.com", []string{"gr0"}).Return(true, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + } + + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(new(Response), nil) + sut.Next(m) + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + _, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + w.AssertExpectations(t) + assert.Equal(t, 0, len(b.Calls)) +} + +func Test_Resolve_Whitelist_Only(t *testing.T) { + b := &MatcherMock{} + + w := &MatcherMock{} + w.On("Match", "whitelisted.com", []string{"gr0"}).Return(true, "gr0") + w.On("Match", mock.Anything, []string{"gr0"}).Return(false, "gr0") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + whitelistMatcher: w, + whitelistOnlyGroups: []string{"gr0"}, + } + + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(new(Response), nil) + sut.Next(m) + + req := util.NewMsgWithQuestion("whitelisted.com.", dns.TypeA) + _, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + w.AssertExpectations(t) + assert.Equal(t, 0, len(b.Calls)) + + req = new(dns.Msg) + req.SetQuestion("google.com.", dns.TypeA) + + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "google.com. 21600 IN A 0.0.0.0", resp.Res.Answer[0].String()) + w.AssertExpectations(t) + b.AssertExpectations(t) +} + +func Test_determineWhitelistOnlyGroups(t *testing.T) { + assert.Equal(t, []string{"w1"}, determineWhitelistOnlyGroups(&config.BlockingConfig{ + BlackLists: map[string][]string{}, + WhiteLists: map[string][]string{"w1": {"l1"}}, + })) + + assert.Equal(t, []string{"b1", "default"}, determineWhitelistOnlyGroups(&config.BlockingConfig{ + BlackLists: map[string][]string{ + "w1": {"y"}, + }, + WhiteLists: map[string][]string{ + "w1": {"l1"}, + "default": {"s1"}, + "b1": {"x"}}, + })) +} + +func Test_Resolve_Default_A_NxRecord(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "blocked1.com", []string{"gr0"}).Return(true, "gr1") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + blockType: NxDomain, + whitelistMatcher: w, + } + + req := util.NewMsgWithQuestion("blocked1.com.", dns.TypeA) + resp, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeNameError, resp.Res.Rcode) + b.AssertExpectations(t) +} + +func Test_Resolve_NoBlock(t *testing.T) { + b := &MatcherMock{} + b.On("Match", "example.com", []string{"gr0"}).Return(false, "") + + w := &MatcherMock{} + w.On("Match", mock.Anything, mock.Anything).Return(false, "gr1") + + sut := BlockingResolver{ + clientGroupsBlock: clientBlock, + blacklistMatcher: b, + blockType: NxDomain, + whitelistMatcher: w, + } + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(new(Response), nil) + sut.Next(m) + + req := util.NewMsgWithQuestion("example.com.", dns.TypeA) + _, err := sut.Resolve(&Request{ + Req: req, + ClientNames: []string{"unknown"}, + ClientIP: net.ParseIP("192.168.178.1"), + Log: logrus.NewEntry(logrus.New()), + }) + assert.NoError(t, err) + b.AssertExpectations(t) +} diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go new file mode 100644 index 00000000..f14f01bd --- /dev/null +++ b/resolver/caching_resolver.go @@ -0,0 +1,120 @@ +package resolver + +import ( + "blocky/util" + "fmt" + "time" + + "github.com/miekg/dns" + "github.com/patrickmn/go-cache" + log "github.com/sirupsen/logrus" +) + +// caches answers from dns queries with their TTL time, to avoid external resolver calls for recurrent queries +type CachingResolver struct { + NextResolver + cacheA *cache.Cache + cacheAAAA *cache.Cache +} + +const minTTL = 250 + +type Type uint8 + +const ( + A Type = iota + AAAA +) + +func NewCachingResolver() ChainedResolver { + return &CachingResolver{ + cacheA: cache.New(15*time.Minute, 5*time.Minute), + cacheAAAA: cache.New(15*time.Minute, 5*time.Minute), + } +} + +func (r *CachingResolver) getCache(queryType uint16) *cache.Cache { + switch queryType { + case dns.TypeA: + return r.cacheA + case dns.TypeAAAA: + return r.cacheAAAA + default: + log.Error("unknown type: ", queryType) + } + + return r.cacheA +} + +func (r *CachingResolver) Configuration() (result []string) { + result = append(result, fmt.Sprintf("A cache items count = %d", r.cacheA.ItemCount())) + result = append(result, fmt.Sprintf("AAAA cache items count = %d", r.cacheAAAA.ItemCount())) + + return +} + +func (r *CachingResolver) Resolve(request *Request) (response *Response, err error) { + logger := withPrefix(request.Log, "caching_resolver") + + resp := new(dns.Msg) + resp.SetReply(request.Req) + + for _, question := range request.Req.Question { + domain := util.ExtractDomain(question) + logger := logger.WithField("domain", domain) + + // we caching only A and AAAA queries + if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { + val, expiresAt, found := r.getCache(question.Qtype).GetWithExpiration(domain) + + if found { + logger.Debug("domain is cached") + + // calculate remaining TTL + remainingTTL := uint32(time.Until(expiresAt).Seconds()) + + resp.Answer = val.([]dns.RR) + for _, rr := range resp.Answer { + rr.Header().Ttl = remainingTTL + } + + return &Response{Res: resp, Reason: fmt.Sprintf("CACHED (ttl %d)", remainingTTL)}, nil + } + + logger.WithField("next_resolver", r.next).Debug("not in cache: go to next resolver") + response, err = r.next.Resolve(request) + + if err == nil { + var maxTTL uint32 + + for _, a := range response.Res.Answer { + // if TTL < mitTTL -> adjust the value, set minTTL + if a.Header().Ttl < minTTL { + logger.WithFields(log.Fields{ + "TTL": a.Header().Ttl, + "min_TTL": minTTL, + }).Debugf("ttl is < than min TTL, using min value") + + a.Header().Ttl = minTTL + } + + if maxTTL < a.Header().Ttl { + maxTTL = a.Header().Ttl + } + } + + // put value into cache + r.getCache(question.Qtype).Set(domain, response.Res.Answer, time.Duration(maxTTL)*time.Second) + } + } else { + logger.Debugf("not A/AAAA: go to next %s", r.next) + return r.next.Resolve(request) + } + } + + return response, err +} + +func (r CachingResolver) String() string { + return fmt.Sprintf("caching resolver") +} diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go new file mode 100644 index 00000000..a69f4f6a --- /dev/null +++ b/resolver/caching_resolver_test.go @@ -0,0 +1,117 @@ +package resolver + +import ( + "blocky/util" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Resolve_A_WithCachingAndMinTtl(t *testing.T) { + sut := NewCachingResolver() + m := &resolverMock{} + mockResp, err := util.NewMsgWithAnswer("example.com. 300 IN A 123.122.121.120") + + if err != nil { + t.Error(err) + } + + m.On("Resolve", mock.Anything).Return(&Response{Res: mockResp}, nil) + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + // first request + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "example.com. 300 IN A 123.122.121.120", resp.Res.Answer[0].String()) + assert.Equal(t, 1, len(m.Calls)) + + time.Sleep(500 * time.Millisecond) + + // second request + resp, err = sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + + // ttl is smaler + assert.Equal(t, "example.com. 299 IN A 123.122.121.120", resp.Res.Answer[0].String()) + + // still one call to resolver + assert.Equal(t, 1, len(m.Calls)) + + m.AssertExpectations(t) +} + +func Test_Resolve_AAAA_WithCachingAndMinTtl(t *testing.T) { + sut := NewCachingResolver() + m := &resolverMock{} + + mockResp, err := util.NewMsgWithAnswer("example.com. 123 IN AAAA 2001:0db8:85a3:08d3:1319:8a2e:0370:7344") + + if err != nil { + t.Error(err) + } + + m.On("Resolve", mock.Anything).Return(&Response{Res: mockResp}, nil) + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeAAAA), + Log: logrus.NewEntry(logrus.New()), + } + + // first request + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "example.com. 250 IN AAAA 2001:db8:85a3:8d3:1319:8a2e:370:7344", resp.Res.Answer[0].String()) + assert.Equal(t, 1, len(m.Calls)) + + time.Sleep(500 * time.Millisecond) + + // second request + resp, err = sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + + // ttl is smaler + assert.Equal(t, "example.com. 249 IN AAAA 2001:db8:85a3:8d3:1319:8a2e:370:7344", resp.Res.Answer[0].String()) + + // still one call to resolver + assert.Equal(t, 1, len(m.Calls)) + + m.AssertExpectations(t) +} + +func Test_Resolve_MX(t *testing.T) { + sut := NewCachingResolver() + m := &resolverMock{} + mockResp, err := util.NewMsgWithAnswer("google.de.\t180\tIN\tMX\t20\talt1.aspmx.l.google.com.") + + if err != nil { + t.Error(err) + } + + m.On("Resolve", mock.Anything).Return(&Response{Res: mockResp}, nil) + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("google.de.", dns.TypeMX), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "google.de.\t180\tIN\tMX\t20 alt1.aspmx.l.google.com.", resp.Res.Answer[0].String()) + assert.Equal(t, 1, len(m.Calls)) +} diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go new file mode 100644 index 00000000..ed335e9f --- /dev/null +++ b/resolver/client_names_resolver.go @@ -0,0 +1,135 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "fmt" + "net" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" +) + +// ClientNamesResolver tries to determine client name by asking responsible DNS server vie rDNS (reverse lookup) +type ClientNamesResolver struct { + cache *cache.Cache + externalResolver Resolver + singleNameOrder []uint + NextResolver +} + +func NewClientNamesResolver(cfg config.ClientLookupConfig) ChainedResolver { + var r Resolver + if (config.Upstream{}) != cfg.Upstream { + r = NewUpstreamResolver(cfg.Upstream) + } + + return &ClientNamesResolver{ + cache: cache.New(1*time.Hour, 1*time.Hour), + externalResolver: r, + singleNameOrder: cfg.SingleNameOrder, + } +} + +func (r *ClientNamesResolver) Configuration() (result []string) { + if r.externalResolver != nil { + result = append(result, fmt.Sprintf("singleNameOrder = \"%v\"", r.singleNameOrder)) + result = append(result, fmt.Sprintf("externalResolver = \"%s\"", r.externalResolver)) + result = append(result, fmt.Sprintf("cache item count = %d", r.cache.ItemCount())) + } else { + result = []string{"deactivated, use only IP address"} + } + + return +} + +func (r *ClientNamesResolver) Resolve(request *Request) (*Response, error) { + clientNames := r.getClientNames(request) + + request.ClientNames = clientNames + request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; ")) + + return r.next.Resolve(request) +} + +// returns names of client +func (r *ClientNamesResolver) getClientNames(request *Request) []string { + ip := request.ClientIP + c, found := r.cache.Get(ip.String()) + + if found { + if t, ok := c.([]string); ok { + return t + } + } + + names := r.resolveClientNames(ip, withPrefix(request.Log, "client_names_resolver")) + r.cache.Set(ip.String(), names, cache.DefaultExpiration) + + return names +} + +// performs reverse DNS lookup +func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) { + if r.externalResolver != nil { + reverse, err := dns.ReverseAddr(ip.String()) + + if err != nil { + logger.Warnf("can't create reverse address for %s", ip.String()) + return + } + + resp, err := r.externalResolver.Resolve(&Request{ + Req: util.NewMsgWithQuestion(reverse, dns.TypePTR), + Log: logger, + }) + + if err != nil { + logger.Error("can't resolve client name", err) + return + } + + var clientNames []string + + for _, answer := range resp.Res.Answer { + if t, ok := answer.(*dns.PTR); ok { + hostName := strings.TrimSuffix(t.Ptr, ".") + clientNames = append(clientNames, hostName) + } + } + + if len(clientNames) == 0 { + clientNames = []string{ip.String()} + } + + // optional: if singleNameOrder is set, use only one name in the defined order + if len(r.singleNameOrder) > 0 { + for _, i := range r.singleNameOrder { + if i > 0 && int(i) <= len(clientNames) { + result = []string{clientNames[i-1]} + break + } + } + } else { + result = clientNames + } + + logger.WithField("client_names", strings.Join(result, "; ")).Debug("resolved client name(s)") + } else { + result = []string{ip.String()} + } + + return result +} + +func (r ClientNamesResolver) String() string { + return fmt.Sprintf("client names resolver") +} + +// reset client name cache +func (r *ClientNamesResolver) FlushCache() { + r.cache.Flush() +} diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go new file mode 100644 index 00000000..9b777047 --- /dev/null +++ b/resolver/client_names_resolver_test.go @@ -0,0 +1,208 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "fmt" + "net" + "testing" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestClientNamesFromUpstream(t *testing.T) { + callCount := 0 + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + callCount++ + r, err := dns.ReverseAddr("192.168.178.25") + assert.NoError(t, err) + + response, err := util.NewMsgWithAnswer(fmt.Sprintf("%s 300 IN PTR myhost", r)) + + assert.NoError(t, err) + return response + }) + + sut := NewClientNamesResolver(config.ClientLookupConfig{Upstream: upstream}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + // first request + request := &Request{ + ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + assert.Equal(t, 1, callCount) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Equal(t, "myhost", request.ClientNames[0]) + + // second request + request = &Request{ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err = sut.Resolve(request) + + // use cache -> call count 1 + assert.Equal(t, 1, callCount) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 1) + assert.Equal(t, "myhost", request.ClientNames[0]) +} + +func TestClientInfoFromUpstreamSingleNameWithOrder(t *testing.T) { + callCount := 0 + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + callCount++ + r, err := dns.ReverseAddr("192.168.178.25") + assert.NoError(t, err) + + response, err := util.NewMsgWithAnswer(fmt.Sprintf("%s 300 IN PTR myhost", r)) + + assert.NoError(t, err) + return response + }) + + sut := NewClientNamesResolver(config.ClientLookupConfig{ + Upstream: upstream, + SingleNameOrder: []uint{2, 1}}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + // first request + request := &Request{ + ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + assert.Equal(t, 1, callCount) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Equal(t, "myhost", request.ClientNames[0]) + + // second request + request = &Request{ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err = sut.Resolve(request) + + // use cache -> call count 1 + assert.Equal(t, 1, callCount) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 1) + assert.Equal(t, "myhost", request.ClientNames[0]) +} + +func TestClientInfoFromUpstreamMultipleNames(t *testing.T) { + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + r, err := dns.ReverseAddr("192.168.178.25") + assert.NoError(t, err) + + rr1, err := dns.NewRR(fmt.Sprintf("%s 300 IN PTR myhost1", r)) + assert.NoError(t, err) + rr2, err := dns.NewRR(fmt.Sprintf("%s 300 IN PTR myhost2", r)) + assert.NoError(t, err) + + msg := new(dns.Msg) + msg.Answer = []dns.RR{rr1, rr2} + + return msg + }) + + sut := NewClientNamesResolver(config.ClientLookupConfig{Upstream: upstream}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ + ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 2) + assert.Equal(t, "myhost1", request.ClientNames[0]) + assert.Equal(t, "myhost2", request.ClientNames[1]) +} + +func TestClientInfoFromUpstreamMultipleNamesSingleNameOrder(t *testing.T) { + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + r, err := dns.ReverseAddr("192.168.178.25") + assert.NoError(t, err) + + rr1, err := dns.NewRR(fmt.Sprintf("%s 300 IN PTR myhost1", r)) + assert.NoError(t, err) + rr2, err := dns.NewRR(fmt.Sprintf("%s 300 IN PTR myhost2", r)) + assert.NoError(t, err) + + msg := new(dns.Msg) + msg.Answer = []dns.RR{rr1, rr2} + + return msg + }) + + sut := NewClientNamesResolver(config.ClientLookupConfig{ + Upstream: upstream, + SingleNameOrder: []uint{2, 1}}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ + ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + m.AssertExpectations(t) + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 1) + assert.Equal(t, "myhost2", request.ClientNames[0]) +} + +func TestClientInfoFromUpstreamNotFound(t *testing.T) { + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + msg := new(dns.Msg) + msg.SetRcode(request, dns.RcodeNameError) + + return msg + }) + + sut := NewClientNamesResolver(config.ClientLookupConfig{Upstream: upstream}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 1) + assert.Equal(t, "192.168.178.25", request.ClientNames[0]) +} + +func TestClientInfoWithoutUpstream(t *testing.T) { + sut := NewClientNamesResolver(config.ClientLookupConfig{}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ClientIP: net.ParseIP("192.168.178.25"), + Log: logrus.NewEntry(logrus.New())} + _, err := sut.Resolve(request) + + assert.NoError(t, err) + assert.Len(t, request.ClientNames, 1) + assert.Equal(t, "192.168.178.25", request.ClientNames[0]) +} diff --git a/resolver/conditionall_upstream_resolver.go b/resolver/conditionall_upstream_resolver.go new file mode 100644 index 00000000..6345ca87 --- /dev/null +++ b/resolver/conditionall_upstream_resolver.go @@ -0,0 +1,80 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +// ConditionalUpstreamResolver delegates DNS question to other DNS resolver dependent on domain name in question +type ConditionalUpstreamResolver struct { + NextResolver + mapping map[string]Resolver +} + +func NewConditionalUpstreamResolver(cfg config.ConditionalUpstreamConfig) ChainedResolver { + m := make(map[string]Resolver) + for domain, upstream := range cfg.Mapping { + m[strings.ToLower(domain)] = NewUpstreamResolver(upstream) + } + + return &ConditionalUpstreamResolver{mapping: m} +} + +func (r *ConditionalUpstreamResolver) Configuration() (result []string) { + if len(r.mapping) > 0 { + for key, val := range r.mapping { + result = append(result, fmt.Sprintf("%s = \"%s\"", key, val)) + } + } else { + result = []string{"deactivated"} + } + + return +} + +func (r *ConditionalUpstreamResolver) Resolve(request *Request) (*Response, error) { + logger := withPrefix(request.Log, "conditional_resolver") + + if len(r.mapping) > 0 { + for _, question := range request.Req.Question { + domain := util.ExtractDomain(question) + + // try with domain with and without sub-domains + for len(domain) > 0 { + r, found := r.mapping[domain] + if found { + response, err := r.Resolve(request) + if err == nil { + response.Reason = "CONDITIONAL" + } + + logger.WithFields(logrus.Fields{ + "answer": util.AnswerToString(response.Res.Answer), + "domain": domain, + "upstream": r, + }).Debugf("received response from conditional upstream") + + return response, err + } + + if i := strings.Index(domain, "."); i >= 0 { + domain = domain[i+1:] + } else { + break + } + } + } + } + + logger.WithField("next_resolver", r.next).Trace("go to next resolver") + + return r.next.Resolve(request) +} + +func (r ConditionalUpstreamResolver) String() string { + return fmt.Sprintf("conditional resolver") +} diff --git a/resolver/conditionall_upstream_resolver_test.go b/resolver/conditionall_upstream_resolver_test.go new file mode 100644 index 00000000..e95d22e7 --- /dev/null +++ b/resolver/conditionall_upstream_resolver_test.go @@ -0,0 +1,86 @@ +package resolver + +import ( + "blocky/util" + "testing" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func setup() (sut *ConditionalUpstreamResolver, cond *resolverMock, next *resolverMock) { + cond = &resolverMock{} + next = &resolverMock{} + sut = &ConditionalUpstreamResolver{ + mapping: map[string]Resolver{ + "fritz.box": cond, + "other.box": cond, + }, + } + + cond.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + next.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(next) + + return +} + +func Test_Resolve_Conditional_Exact(t *testing.T) { + sut, conditionalResolver, nextResolver := setup() + + request := &Request{ + Req: util.NewMsgWithQuestion("fritz.box.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, "CONDITIONAL", resp.Reason) + conditionalResolver.AssertExpectations(t) + nextResolver.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Conditional_ExactLast(t *testing.T) { + sut, conditionalResolver, nextResolver := setup() + + request := &Request{ + Req: util.NewMsgWithQuestion("other.box.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, "CONDITIONAL", resp.Reason) + conditionalResolver.AssertExpectations(t) + nextResolver.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Conditional_Subdomain(t *testing.T) { + sut, conditionalResolver, nextResolver := setup() + + request := &Request{ + Req: util.NewMsgWithQuestion("test.fritz.box.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + _, err := sut.Resolve(request) + assert.NoError(t, err) + conditionalResolver.AssertExpectations(t) + nextResolver.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Conditional_Not_Match(t *testing.T) { + sut, conditionalResolver, nextResolver := setup() + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + _, err := sut.Resolve(request) + assert.NoError(t, err) + nextResolver.AssertExpectations(t) + conditionalResolver.AssertNotCalled(t, "Resolve", mock.Anything) +} diff --git a/resolver/custom_dns_resolver.go b/resolver/custom_dns_resolver.go new file mode 100644 index 00000000..e07aa905 --- /dev/null +++ b/resolver/custom_dns_resolver.go @@ -0,0 +1,94 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "fmt" + "net" + "strings" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" +) + +const customDNSTTL = 60 * 60 + +// CustomDNSResolver resolves passed domain name to ip address defined in domain-IP map +type CustomDNSResolver struct { + NextResolver + mapping map[string]net.IP +} + +func NewCustomDNSResolver(cfg config.CustomDNSConfig) ChainedResolver { + m := make(map[string]net.IP) + for url, ip := range cfg.Mapping { + m[strings.ToLower(url)] = ip + } + + return &CustomDNSResolver{mapping: m} +} + +func (r *CustomDNSResolver) Configuration() (result []string) { + if len(r.mapping) > 0 { + for key, val := range r.mapping { + result = append(result, fmt.Sprintf("%s = \"%s\"", key, val)) + } + } else { + result = []string{"deactivated"} + } + + return +} + +func (r *CustomDNSResolver) Resolve(request *Request) (*Response, error) { + logger := withPrefix(request.Log, "custom_dns_resolver") + + if len(r.mapping) > 0 { + for _, question := range request.Req.Question { + domain := util.ExtractDomain(question) + for len(domain) > 0 { + ip, found := r.mapping[domain] + if found { + response := new(dns.Msg) + response.SetReply(request.Req) + + if (ip.To4() != nil && question.Qtype == dns.TypeA) || + (strings.Contains(ip.String(), ":") && question.Qtype == dns.TypeAAAA) { + rr, err := util.CreateAnswerFromQuestion(question, ip, customDNSTTL) + + if err == nil { + response.Answer = append(response.Answer, rr) + + logger.WithFields(logrus.Fields{ + "answer": util.AnswerToString(response.Answer), + "domain": domain, + }).Debugf("returning custom dns entry") + + return &Response{Res: response, Reason: "CUSTOM DNS"}, nil + } + + return nil, err + } + + response.Rcode = dns.RcodeNameError + + return &Response{Res: response, Reason: "CUSTOM DNS"}, nil + } + + if i := strings.Index(domain, "."); i >= 0 { + domain = domain[i+1:] + } else { + break + } + } + } + } + + logger.WithField("resolver", r.next).Trace("go to next resolver") + + return r.next.Resolve(request) +} + +func (r CustomDNSResolver) String() string { + return fmt.Sprintf("custom resolver") +} diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go new file mode 100644 index 00000000..46edabdc --- /dev/null +++ b/resolver/custom_dns_resolver_test.go @@ -0,0 +1,101 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "net" + "testing" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Resolve_Custom_Name_Ip4_A(t *testing.T) { + sut := NewCustomDNSResolver(config.CustomDNSConfig{ + Mapping: map[string]net.IP{"custom.domain": net.ParseIP("192.168.143.123")}}) + m := &resolverMock{} + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("custom.domain.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "custom.domain. 3600 IN A 192.168.143.123", resp.Res.Answer[0].String()) + m.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Custom_Name_Ip4_AAAA(t *testing.T) { + sut := NewCustomDNSResolver(config.CustomDNSConfig{ + Mapping: map[string]net.IP{"custom.domain": net.ParseIP("192.168.143.123")}}) + m := &resolverMock{} + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("custom.domain.", dns.TypeAAAA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeNameError, resp.Res.Rcode) + m.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Custom_Name_Ip6_AAAA(t *testing.T) { + sut := NewCustomDNSResolver(config.CustomDNSConfig{ + Mapping: map[string]net.IP{"custom.domain": net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}}) + m := &resolverMock{} + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("custom.domain.", dns.TypeAAAA), + Log: logrus.NewEntry(logrus.New()), + } + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "custom.domain. 3600 IN AAAA 2001:db8:85a3::8a2e:370:7334", resp.Res.Answer[0].String()) + m.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Custom_Name_Subdomain(t *testing.T) { + sut := NewCustomDNSResolver(config.CustomDNSConfig{ + Mapping: map[string]net.IP{"custom.domain": net.ParseIP("192.168.143.123")}}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("ABC.CUSTOM.DOMAIN.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "ABC.CUSTOM.DOMAIN. 3600 IN A 192.168.143.123", resp.Res.Answer[0].String()) + m.AssertNotCalled(t, "Resolve", mock.Anything) +} + +func Test_Resolve_Delegate_Next(t *testing.T) { + sut := NewCustomDNSResolver(config.CustomDNSConfig{ + Mapping: map[string]net.IP{"custom.domain": net.ParseIP("192.168.143.123")}}) + m := &resolverMock{} + m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) + sut.Next(m) + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + _, _ = sut.Resolve(request) + + m.AssertExpectations(t) +} diff --git a/resolver/mocks.go b/resolver/mocks.go new file mode 100644 index 00000000..ba986fe8 --- /dev/null +++ b/resolver/mocks.go @@ -0,0 +1,85 @@ +package resolver + +import ( + "blocky/config" + "net" + "strconv" + "strings" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" +) + +type resolverMock struct { + mock.Mock + NextResolver +} + +func (r *resolverMock) Configuration() (result []string) { + return +} + +func (r *resolverMock) Resolve(req *Request) (*Response, error) { + args := r.Called(req) + return args.Get(0).(*Response), args.Error(1) +} + +func TestUDPUpstreamWithResponse(response *dns.Msg) config.Upstream { + return TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + return response + }) +} + +func TestUDPUpstream(fn func(request *dns.Msg) (response *dns.Msg)) config.Upstream { + a, err := net.ResolveUDPAddr("udp4", ":0") + if err != nil { + log.Fatal("can't resolve address: ", err) + } + + ln, err := net.ListenUDP("udp4", a) + if err != nil { + log.Fatal("can't create connection: ", err) + } + + ladr := ln.LocalAddr().String() + host := strings.Split(ladr, ":")[0] + p, err := strconv.Atoi(strings.Split(ladr, ":")[1]) + + if err != nil { + log.Fatal("can't convert port: ", err) + } + + port := uint16(p) + + go func() { + for { + buffer := make([]byte, 1024) + n, addr, err := ln.ReadFromUDP(buffer) + if err != nil { + log.Fatal("error on reading from udp: ", err) + } + + msg := new(dns.Msg) + err = msg.Unpack(buffer[0 : n-1]) + if err != nil { + log.Fatal("can't deserialize message: ", err) + } + + response := fn(msg) + response.SetReply(msg) + + b, err := response.Pack() + if err != nil { + log.Fatal("can't serialize message: ", err) + } + + _, err = ln.WriteToUDP(b, addr) + if err != nil { + log.Fatal("can't write to UDP: ", err) + } + } + }() + + return config.Upstream{Net: "udp", Host: host, Port: port} +} diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go new file mode 100644 index 00000000..3b766695 --- /dev/null +++ b/resolver/parallel_best_resolver.go @@ -0,0 +1,113 @@ +package resolver + +import ( + "blocky/util" + "fmt" + "math/rand" + + "github.com/sirupsen/logrus" +) + +// ParallelBestResolver delegates the DNS message to 2 upstream resolvers and returns the fastest answer +type ParallelBestResolver struct { + resolvers []Resolver +} + +func NewParallelBestResolver(resolvers []Resolver) Resolver { + return &ParallelBestResolver{resolvers: resolvers} +} + +func (r *ParallelBestResolver) Configuration() (result []string) { + result = append(result, "upstream resolvers:") + for _, res := range r.resolvers { + result = append(result, fmt.Sprintf("- %s", res)) + } + + return +} + +func (r *ParallelBestResolver) Resolve(request *Request) (*Response, error) { + logger := request.Log.WithField("prefix", "parallel_best_resolver") + + r1, r2 := r.pickRandom() + logger.Debugf("using %s and %s as resolver", r1, r2) + + ch1 := make(chan struct { + *Response + error + }) + ch2 := make(chan struct { + *Response + error + }) + + var err1, err2 error + + logger.WithField("resolver", r1).Debug("delegating to resolver") + + go resolve(request, r1, ch1) + + logger.WithField("resolver", r2).Debug("delegating to resolver") + + go resolve(request, r2, ch2) + + for err1 == nil || err2 == nil { + select { + case msg1 := <-ch1: + if msg1.error != nil { + err1 = msg1.error + ch1 = nil + + logger.WithField("resolver", r1).Debug("resolution failed from resolver, cause: ", msg1.error) + } else { + logger.WithFields(logrus.Fields{ + "resolver": r1, + "answer": util.AnswerToString(msg1.Response.Res.Answer), + }).Debug("using response from resolver") + return msg1.Response, nil + } + case msg2 := <-ch2: + if msg2.error != nil { + err2 = msg2.error + ch2 = nil + + logger.WithField("resolver", r2).Debug("resolution failed from resolver, cause: ", msg2.error) + } else { + logger.WithFields(logrus.Fields{ + "resolver": r2, + "answer": util.AnswerToString(msg2.Response.Res.Answer), + }).Debug("using response from resolver") + return msg2.Response, nil + } + } + } + + return nil, fmt.Errorf("resolution was not successful, errors: '%v', '%v'", err1, err2) +} + +// pick 2 different random resolvers from the resolver pool +func (r *ParallelBestResolver) pickRandom() (resolver1, resolver2 Resolver) { + resolver1 = r.resolvers[rand.Intn(len(r.resolvers))] + for resolver2 == resolver1 || resolver2 == nil { + resolver2 = r.resolvers[rand.Intn(len(r.resolvers))] + } + + return +} + +func resolve(req *Request, resolver Resolver, ch chan struct { + *Response + error +}) { + defer close(ch) + + resp, err := resolver.Resolve(req) + ch <- struct { + *Response + error + }{resp, err} +} + +func (r ParallelBestResolver) String() string { + return fmt.Sprintf("parallel best resolver '%s'", r.resolvers) +} diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go new file mode 100644 index 00000000..5215b13c --- /dev/null +++ b/resolver/parallel_best_resolver_test.go @@ -0,0 +1,39 @@ +package resolver + +import ( + "blocky/util" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Resolve_Best_Result(t *testing.T) { + fast := &resolverMock{} + + mockResp, err := util.NewMsgWithAnswer("example.com. 123 IN A 192.168.178.44") + if err != nil { + t.Error(err) + } + + fast.On("Resolve", mock.Anything).Return(&Response{Res: mockResp}, nil) + + slow := &resolverMock{} + slow.On("Resolve", mock.Anything).WaitUntil(time.After(50*time.Millisecond)).Return(&Response{Res: new(dns.Msg)}, nil) + + sut := NewParallelBestResolver([]Resolver{slow, fast}) + + resp, err := sut.Resolve(&Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + }) + + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "example.com. 123 IN A 192.168.178.44", resp.Res.Answer[0].String()) + fast.AssertExpectations(t) + slow.AssertExpectations(t) +} diff --git a/resolver/query_logging_resolver.go b/resolver/query_logging_resolver.go new file mode 100644 index 00000000..f91ab311 --- /dev/null +++ b/resolver/query_logging_resolver.go @@ -0,0 +1,229 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "encoding/csv" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +const ( + cleanUpRunPeriod = 12 * time.Hour + queryLoggingResolverPrefix = "query_logging_resolver" + logChanCap = 1000 +) + +// QueryLoggingResolver writes query information (question, answer, duration, ...) into +// log file or as log entry (if log directory is not configured) +type QueryLoggingResolver struct { + NextResolver + logDir string + perClient bool + logRetentionDays uint64 + logChan chan *queryLogEntry +} + +type queryLogEntry struct { + request *Request + response *Response + start time.Time + durationMs int64 + logger *logrus.Entry +} + +func NewQueryLoggingResolver(cfg config.QueryLogConfig) ChainedResolver { + if cfg.Dir != "" && unix.Access(cfg.Dir, unix.W_OK) != nil { + logger(queryLoggingResolverPrefix).Fatalf("query log directory '%s' does not exist or is not writable", cfg.Dir) + } + + logChan := make(chan *queryLogEntry, logChanCap) + + resolver := QueryLoggingResolver{ + logDir: cfg.Dir, + perClient: cfg.PerClient, + logRetentionDays: cfg.LogRetentionDays, + logChan: logChan, + } + + go resolver.writeLog() + + if cfg.LogRetentionDays > 0 { + go resolver.periodicCleanUp() + } + + return &resolver +} + +// triggers periodically cleanup of old log files +func (r *QueryLoggingResolver) periodicCleanUp() { + ticker := time.NewTicker(cleanUpRunPeriod) + defer ticker.Stop() + + for { + <-ticker.C + r.doCleanUp() + } +} + +// deletes old log files +func (r *QueryLoggingResolver) doCleanUp() { + logger := logger(queryLoggingResolverPrefix) + + logger.Trace("starting clean up") + + files, err := ioutil.ReadDir(r.logDir) + if err != nil { + logger.WithField("log_dir", r.logDir).Error("can't list log directory: ", err) + } + + // search for log files, which names starts with date + for _, f := range files { + if strings.HasSuffix(f.Name(), ".log") && len(f.Name()) > 10 { + t, err := time.Parse("2006-01-02", f.Name()[:10]) + if err == nil { + differenceDays := uint64(time.Since(t).Hours() / 24) + if r.logRetentionDays > 0 && differenceDays > r.logRetentionDays { + logger.WithFields(logrus.Fields{ + "file": f.Name(), + "ageInDays": differenceDays, + "logRetentionDays": r.logRetentionDays, + }).Info("existing log file is older than retention time and will be deleted") + + err := os.Remove(filepath.Join(r.logDir, f.Name())) + if err != nil { + logger.WithField("file", f.Name()).Error("can't remove file: ", err) + } + } + } + } + } +} + +func (r *QueryLoggingResolver) Resolve(request *Request) (*Response, error) { + logger := withPrefix(request.Log, queryLoggingResolverPrefix) + + start := time.Now() + + resp, err := r.next.Resolve(request) + + duration := time.Since(start).Milliseconds() + + if err == nil { + select { + case r.logChan <- &queryLogEntry{ + request: request, + response: resp, + start: start, + durationMs: duration, + logger: logger}: + default: + logger.Error("query log writer is too slow, log entry will be dropped") + } + } + + return resp, err +} + +// write entry: if log directory is configured, write to log file +func (r *QueryLoggingResolver) writeLog() { + for logEntry := range r.logChan { + if r.logDir != "" { + var clientPrefix string + + start := time.Now() + + dateString := logEntry.start.Format("2006-01-02") + + if r.perClient { + clientPrefix = strings.Join(logEntry.request.ClientNames, "-") + } else { + clientPrefix = "ALL" + } + + writePath := filepath.Join(r.logDir, fmt.Sprintf("%s_%s.log", dateString, clientPrefix)) + + file, err := os.OpenFile(writePath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666) + + if err != nil { + logEntry.logger.WithField("file_name", writePath).Error("can't create/open file", err) + } else { + writer := createCsvWriter(file) + + err := writer.Write(createQueryLogRow(logEntry)) + if err != nil { + logEntry.logger.WithField("file_name", writePath).Error("can't write to file", err) + } + writer.Flush() + } + + halfCap := cap(r.logChan) / 2 + + // if log channel is > 50% full, this could be a problem with slow writer (external storage over network etc.) + if len(r.logChan) > halfCap { + logEntry.logger.WithField("channel_len", + len(r.logChan)).Warnf("query log writer is too slow, write duration: %d ms", time.Since(start).Milliseconds()) + } + } else { + logEntry.logger.WithFields( + logrus.Fields{ + "response_reason": logEntry.response.Reason, + "answer": util.AnswerToString(logEntry.response.Res.Answer), + "duration_ms": logEntry.durationMs, + }, + ).Infof("query resolved") + } + } +} + +func createCsvWriter(file io.Writer) *csv.Writer { + writer := csv.NewWriter(file) + writer.Comma = '\t' + + return writer +} + +func createQueryLogRow(logEntry *queryLogEntry) []string { + request := logEntry.request + response := logEntry.response + + return []string{ + logEntry.start.Format("2006-01-02 15:04:05"), + request.ClientIP.String(), + strings.Join(request.ClientNames, "; "), + fmt.Sprintf("%d", logEntry.durationMs), + response.Reason, + util.QuestionToString(request.Req.Question), + util.AnswerToString(response.Res.Answer), + dns.RcodeToString[response.Res.Rcode], + } +} + +func (r *QueryLoggingResolver) Configuration() (result []string) { + if r.logDir != "" { + result = append(result, fmt.Sprintf("logDir= \"%s\"", r.logDir)) + result = append(result, fmt.Sprintf("perClient = %t", r.perClient)) + result = append(result, fmt.Sprintf("logRetentionDays= %d", r.logRetentionDays)) + + if r.logRetentionDays == 0 { + result = append(result, "log cleanup deactivated") + } + } else { + result = []string{"deactivated"} + } + + return +} + +func (r QueryLoggingResolver) String() string { + return fmt.Sprintf("query logging resolver") +} diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go new file mode 100644 index 00000000..54dd3b1e --- /dev/null +++ b/resolver/query_logging_resolver_test.go @@ -0,0 +1,211 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "bufio" + "encoding/csv" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_doCleanUp(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "queryLoggingResolver") + defer os.RemoveAll(tmpDir) + assert.NoError(t, err) + + // create 2 files, 7 and 8 days old + dateBefore7Days := time.Now().AddDate(0, 0, -7) + dateBefore8Days := time.Now().AddDate(0, 0, -8) + + f1, err := os.Create(filepath.Join(tmpDir, fmt.Sprintf("%s-test.log", dateBefore7Days.Format("2006-01-02")))) + assert.NoError(t, err) + + f2, err := os.Create(filepath.Join(tmpDir, fmt.Sprintf("%s-test.log", dateBefore8Days.Format("2006-01-02")))) + assert.NoError(t, err) + + sut := NewQueryLoggingResolver(config.QueryLogConfig{ + Dir: tmpDir, + LogRetentionDays: 7, + }) + + sut.(*QueryLoggingResolver).doCleanUp() + + // file 1 exist + _, err = os.Stat(f1.Name()) + assert.NoError(t, err) + + // file 2 was deleted + _, err = os.Stat(f2.Name()) + assert.Error(t, err) + assert.True(t, os.IsNotExist(err)) +} + +func Test_Resolve_WithEmptyConfig(t *testing.T) { + sut := NewQueryLoggingResolver(config.QueryLogConfig{}) + m := &resolverMock{} + resp, err := util.NewMsgWithAnswer("example.com. 300 IN A 123.122.121.120") + assert.NoError(t, err) + + m.On("Resolve", mock.Anything).Return(&Response{Res: resp, Reason: "reason"}, nil) + sut.Next(m) + + _, err = sut.Resolve(&Request{ + ClientIP: net.ParseIP("192.168.178.25"), + ClientNames: []string{"client1"}, + Req: util.NewMsgWithQuestion("google.de.", dns.TypeA), + Log: logrus.NewEntry(logrus.New())}) + assert.NoError(t, err) + m.AssertExpectations(t) +} +func Test_Resolve_WithLoggingPerClient(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "queryLoggingResolver") + assert.NoError(t, err) + + defer os.RemoveAll(tmpDir) + + sut := NewQueryLoggingResolver(config.QueryLogConfig{ + Dir: tmpDir, + PerClient: true, + }) + + m := &resolverMock{} + resp, err := util.NewMsgWithAnswer("example.com. 300 IN A 123.122.121.120") + assert.NoError(t, err) + + m.On("Resolve", mock.Anything).Return(&Response{Res: resp, Reason: "reason"}, nil) + sut.Next(m) + + // request client1 + _, err = sut.Resolve(&Request{ + ClientIP: net.ParseIP("192.168.178.25"), + ClientNames: []string{"client1"}, + Req: util.NewMsgWithQuestion("google.de.", dns.TypeA), + Log: logrus.NewEntry(logrus.New())}) + assert.NoError(t, err) + + // request client2 + _, err = sut.Resolve(&Request{ + ClientIP: net.ParseIP("192.168.178.26"), + ClientNames: []string{"client2"}, + Req: util.NewMsgWithQuestion("google.de.", dns.TypeA), + Log: logrus.NewEntry(logrus.New())}) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + m.AssertExpectations(t) + + // client1 + csvLines := readCsv(filepath.Join(tmpDir, fmt.Sprintf("%s_client1.log", time.Now().Format("2006-01-02")))) + + assert.Len(t, csvLines, 1) + assert.Equal(t, "192.168.178.25", csvLines[0][1]) + assert.Equal(t, "client1", csvLines[0][2]) + assert.Equal(t, "reason", csvLines[0][4]) + assert.Equal(t, "A (google.de.)", csvLines[0][5]) + assert.Equal(t, "A (123.122.121.120)", csvLines[0][6]) + + // client2 + csvLines = readCsv(filepath.Join(tmpDir, fmt.Sprintf("%s_client2.log", time.Now().Format("2006-01-02")))) + + assert.Len(t, csvLines, 1) + assert.Equal(t, "192.168.178.26", csvLines[0][1]) + assert.Equal(t, "client2", csvLines[0][2]) + assert.Equal(t, "reason", csvLines[0][4]) + assert.Equal(t, "A (google.de.)", csvLines[0][5]) + assert.Equal(t, "A (123.122.121.120)", csvLines[0][6]) +} + +func Test_Resolve_WithLoggingAll(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "queryLoggingResolver") + assert.NoError(t, err) + + defer os.RemoveAll(tmpDir) + + sut := NewQueryLoggingResolver(config.QueryLogConfig{ + Dir: tmpDir, + PerClient: false, + }) + + m := &resolverMock{} + resp, err := util.NewMsgWithAnswer("example.com. 300 IN A 123.122.121.120") + assert.NoError(t, err) + + m.On("Resolve", mock.Anything).Return(&Response{Res: resp, Reason: "reason"}, nil) + sut.Next(m) + + // request client1 + _, err = sut.Resolve(&Request{ + ClientIP: net.ParseIP("192.168.178.25"), + ClientNames: []string{"client1"}, + Req: util.NewMsgWithQuestion("google.de.", dns.TypeA), + Log: logrus.NewEntry(logrus.New())}) + assert.NoError(t, err) + + // request client2 + _, err = sut.Resolve(&Request{ + ClientIP: net.ParseIP("192.168.178.26"), + ClientNames: []string{"client2"}, + Req: util.NewMsgWithQuestion("google.de.", dns.TypeA), + Log: logrus.NewEntry(logrus.New())}) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + m.AssertExpectations(t) + + csvLines := readCsv(filepath.Join(tmpDir, fmt.Sprintf("%s_ALL.log", time.Now().Format("2006-01-02")))) + assert.Len(t, csvLines, 2) + + // client1 -> first line + assert.Equal(t, "192.168.178.25", csvLines[0][1]) + assert.Equal(t, "client1", csvLines[0][2]) + assert.Equal(t, "reason", csvLines[0][4]) + assert.Equal(t, "A (google.de.)", csvLines[0][5]) + assert.Equal(t, "A (123.122.121.120)", csvLines[0][6]) + + // client2 -> second line + assert.Equal(t, "192.168.178.26", csvLines[1][1]) + assert.Equal(t, "client2", csvLines[1][2]) + assert.Equal(t, "reason", csvLines[1][4]) + assert.Equal(t, "A (google.de.)", csvLines[1][5]) + assert.Equal(t, "A (123.122.121.120)", csvLines[1][6]) +} + +func readCsv(file string) [][]string { + var result [][]string + + csvFile, err := os.Open(file) + if err != nil { + log.Fatal("can't open file", err) + } + + reader := csv.NewReader(bufio.NewReader(csvFile)) + reader.Comma = '\t' + + for { + line, err := reader.Read() + if err == io.EOF { + break + } else if err != nil { + log.Fatal("can't read line", err) + } + + result = append(result, line) + } + + return result +} diff --git a/resolver/resolver.go b/resolver/resolver.go new file mode 100644 index 00000000..0cb73f70 --- /dev/null +++ b/resolver/resolver.go @@ -0,0 +1,57 @@ +package resolver + +import ( + "net" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" +) + +type Request struct { + ClientIP net.IP + ClientNames []string + Req *dns.Msg + Log *logrus.Entry +} + +type ResponseType uint8 + +const ( + Resolved ResponseType = iota + Blocked +) + +type Response struct { + Res *dns.Msg + Reason string +} +type Resolver interface { + Resolve(req *Request) (*Response, error) + Configuration() []string +} + +type ChainedResolver interface { + Resolver + Next(n Resolver) + GetNext() Resolver +} + +type NextResolver struct { + next Resolver +} + +func (r *NextResolver) Next(n Resolver) { + r.next = n +} + +func (r *NextResolver) GetNext() Resolver { + return r.next +} + +func logger(prefix string) *logrus.Entry { + return logrus.WithField("prefix", prefix) +} + +func withPrefix(logger *logrus.Entry, prefix string) *logrus.Entry { + return logger.WithField("prefix", prefix) +} diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go new file mode 100644 index 00000000..9791da53 --- /dev/null +++ b/resolver/upstream_resolver.go @@ -0,0 +1,69 @@ +package resolver + +import ( + "blocky/config" + "blocky/util" + "fmt" + "net" + "strconv" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" +) + +// UpstreamResolver sends request to external DNS server +type UpstreamResolver struct { + NextResolver + client *dns.Client + upstream string +} + +func NewUpstreamResolver(upstream config.Upstream) Resolver { + client := new(dns.Client) + client.Net = upstream.Net + + return &UpstreamResolver{ + client: client, + upstream: net.JoinHostPort(upstream.Host, strconv.Itoa(int(upstream.Port)))} +} + +func (r *UpstreamResolver) Configuration() (result []string) { + return +} + +func (r *UpstreamResolver) Resolve(request *Request) (response *Response, err error) { + logger := withPrefix(request.Log, "upstream_resolver") + + attempt := 1 + + var rtt time.Duration + + var resp *dns.Msg + + for attempt <= 3 { + if resp, rtt, err = r.client.Exchange(request.Req, r.upstream); err == nil { + logger.WithFields(logrus.Fields{ + "answer": util.AnswerToString(resp.Answer), + "return_code": dns.RcodeToString[resp.Rcode], + "upstream": r.upstream, + "response_time_ms": rtt.Milliseconds(), + }).Debugf("received response from upstream") + + return &Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s) in %d ms", r.upstream, rtt.Milliseconds())}, err + } + + if errNet, ok := err.(net.Error); ok && (errNet.Timeout() || errNet.Temporary()) { + logger.WithField("attempt", attempt).Debugf("Temporary network error / Timeout occurred, retrying...") + attempt++ + } else { + return nil, err + } + } + + return +} + +func (r UpstreamResolver) String() string { + return fmt.Sprintf("upstream '%s'", r.upstream) +} diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go new file mode 100644 index 00000000..74ab8967 --- /dev/null +++ b/resolver/upstream_resolver_test.go @@ -0,0 +1,79 @@ +package resolver + +import ( + "blocky/util" + "fmt" + "strings" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func Test_Resolve_Upstream(t *testing.T) { + upstream := TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + response, err := util.NewMsgWithAnswer("example.com 123 IN A 123.124.122.122") + + assert.NoError(t, err) + return response + }) + + sut := NewUpstreamResolver(upstream) + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + resp, err := sut.Resolve(request) + assert.NoError(t, err) + assert.Equal(t, dns.RcodeSuccess, resp.Res.Rcode) + assert.Equal(t, "example.com. 123 IN A 123.124.122.122", resp.Res.Answer[0].String()) +} + +func TestUpstreamTimeout(t *testing.T) { + counter := 0 + attemptsWithTimeout := 2 + + upstream := TestUDPUpstream(func(request *dns.Msg) (response *dns.Msg) { + counter++ + // timeout on first x attempts + if counter <= attemptsWithTimeout { + fmt.Print("timeout") + time.Sleep(110 * time.Millisecond) + } + response, err := util.NewMsgWithAnswer("example.com 123 IN A 123.124.122.122") + assert.NoError(t, err) + + return response + }) + + sut := NewUpstreamResolver(upstream).(*UpstreamResolver) + sut.client.ReadTimeout = 100 * time.Millisecond + + request := &Request{ + Req: util.NewMsgWithQuestion("example.com.", dns.TypeA), + Log: logrus.NewEntry(logrus.New()), + } + + // first request -> after 2 timeouts success + response, err := sut.Resolve(request) + assert.NoError(t, err) + + if response != nil { + assert.Equal(t, dns.RcodeSuccess, response.Res.Rcode) + assert.Equal(t, "example.com.\t123\tIN\tA\t123.124.122.122", response.Res.Answer[0].String()) + } + + attemptsWithTimeout = 3 + counter = 0 + + // second request + // all 3 attempts with timeout + response, err = sut.Resolve(request) + assert.Error(t, err) + assert.True(t, strings.HasSuffix(err.Error(), "i/o timeout")) + assert.Nil(t, response) +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 00000000..af787bf7 --- /dev/null +++ b/server/server.go @@ -0,0 +1,189 @@ +package server + +import ( + "blocky/config" + "blocky/resolver" + "os" + "os/signal" + "syscall" + + "blocky/util" + "fmt" + "net" + + "github.com/miekg/dns" + "github.com/sirupsen/logrus" +) + +type Server struct { + udpServer *dns.Server + tcpServer *dns.Server + queryResolver resolver.Resolver +} + +func logger() *logrus.Entry { + return logrus.WithField("prefix", "server") +} + +func NewServer(cfg *config.Config) (*Server, error) { + udpHandler := dns.NewServeMux() + tcpHandler := dns.NewServeMux() + udpServer := &dns.Server{ + Addr: fmt.Sprintf(":%d", cfg.Port), + Net: "udp", + Handler: udpHandler, + NotifyStartedFunc: func() { + logger().Infof("udp server is up and running") + }, + UDPSize: 65535} + tcpServer := &dns.Server{ + Addr: fmt.Sprintf(":%d", cfg.Port), + Net: "tcp", + Handler: tcpHandler, + NotifyStartedFunc: func() { + logger().Infof("tcp server is up and running") + }, + } + + var queryResolver resolver.Resolver + + clientNamesResolver := resolver.NewClientNamesResolver(cfg.ClientLookup) + queryLoggingResolver := resolver.NewQueryLoggingResolver(cfg.QueryLog) + conditionalUpstreamResolver := resolver.NewConditionalUpstreamResolver(cfg.Conditional) + customDNSResolver := resolver.NewCustomDNSResolver(cfg.CustomDNS) + blacklistResolver := resolver.NewBlockingResolver(cfg.Blocking) + + cachingResolver := resolver.NewCachingResolver() + parallelUpstreamResolver := createParallelUpstreamResolver(cfg.Upstream.ExternalResolvers) + + clientNamesResolver.Next(queryLoggingResolver) + queryLoggingResolver.Next(conditionalUpstreamResolver) + conditionalUpstreamResolver.Next(customDNSResolver) + customDNSResolver.Next(blacklistResolver) + blacklistResolver.Next(cachingResolver) + cachingResolver.Next(parallelUpstreamResolver) + + queryResolver = clientNamesResolver + + server := Server{ + udpServer: udpServer, + tcpServer: tcpServer, + queryResolver: queryResolver, + } + + server.printConfiguration() + + udpHandler.HandleFunc(".", server.OnRequest) + tcpHandler.HandleFunc(".", server.OnRequest) + + return &server, nil +} + +func (s *Server) printConfiguration() { + logger().Info("current configuration:") + + res := s.queryResolver + for res != nil { + logger().Infof("-> resolver: '%s'", res) + + for _, c := range res.Configuration() { + logger().Infof(" %s", c) + } + + if c, ok := res.(resolver.ChainedResolver); ok { + res = c.GetNext() + } else { + break + } + } +} + +func createParallelUpstreamResolver(upstream []config.Upstream) resolver.Resolver { + if len(upstream) == 1 { + return resolver.NewUpstreamResolver(upstream[0]) + } + + resolvers := make([]resolver.Resolver, len(upstream)) + + for i, u := range upstream { + resolvers[i] = resolver.NewUpstreamResolver(u) + } + + return resolver.NewParallelBestResolver(resolvers) +} + +func (s *Server) Start() { + logger().Info("Starting server") + + go func() { + if err := s.udpServer.ListenAndServe(); err != nil { + logger().Fatalf("start %s listener failed: %v", s.udpServer.Net, err) + } + }() + + go func() { + if err := s.tcpServer.ListenAndServe(); err != nil { + logger().Fatalf("start %s listener failed: %v", s.tcpServer.Net, err) + } + }() + + signals := make(chan os.Signal) + signal.Notify(signals, syscall.SIGUSR1) + + go func() { + for { + <-signals + s.printConfiguration() + } + }() +} + +func (s *Server) Stop() { + logger().Info("Stopping server") + + if err := s.udpServer.Shutdown(); err != nil { + logger().Fatalf("stop %s listener failed: %v", s.udpServer.Net, err) + } + + if err := s.tcpServer.Shutdown(); err != nil { + logger().Fatalf("stop %s listener failed: %v", s.tcpServer.Net, err) + } +} + +func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) { + logger().Debug("new request") + + clientIP := resolveClientIP(w.RemoteAddr()) + r := &resolver.Request{ + ClientIP: clientIP, + Req: request, + Log: logrus.WithFields(logrus.Fields{ + "question": util.QuestionToString(request.Question), + "client_ip": clientIP, + }), + } + + response, err := s.queryResolver.Resolve(r) + + if err != nil { + logger().Errorf("error on processing request: %v", err) + dns.HandleFailed(w, request) + } else { + response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired + + if err := w.WriteMsg(response.Res); err != nil { + logger().Error("can't write message: ", err) + } + } +} + +func resolveClientIP(addr net.Addr) net.IP { + var clientIP net.IP + if t, ok := addr.(*net.UDPAddr); ok { + clientIP = t.IP + } else if t, ok := addr.(*net.TCPAddr); ok { + clientIP = t.IP + } + + return clientIP +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 00000000..055871eb --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,323 @@ +package server + +import ( + "blocky/config" + "blocky/resolver" + "blocky/util" + "fmt" + "log" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +var mockClientName string + +// test case definition +var tests = []struct { + name string + request *dns.Msg + mockClientName string + respValidator func(*testing.T, *dns.Msg) +}{ + { + // resolve query via external dns + name: "resolveWithUpstream", + request: util.NewMsgWithQuestion("google.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "google.de.\t250\tIN\tA\t123.124.122.122", resp.Answer[0].String()) + }, + }, + { + // custom dnd entry with exact match + name: "customDns", + request: util.NewMsgWithQuestion("custom.lan.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "custom.lan.\t3600\tIN\tA\t192.168.178.55", resp.Answer[0].String()) + }, + }, + { + // sub domain custom dns + name: "customDnsWithSubdomain", + request: util.NewMsgWithQuestion("host.lan.home.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "host.lan.home.\t3600\tIN\tA\t192.168.178.56", resp.Answer[0].String()) + }, + }, + { + // delegate to special dns upstream + name: "conditional", + request: util.NewMsgWithQuestion("host.fritz.box.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "host.fritz.box.\t3600\tIN\tA\t192.168.178.2", resp.Answer[0].String()) + }, + }, + { + // blocking default group + name: "blockDefault", + request: util.NewMsgWithQuestion("doubleclick.net.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "doubleclick.net.\t21600\tIN\tA\t0.0.0.0", resp.Answer[0].String()) + }, + }, + { + // blocking default group with sub domain + name: "blockDefaultWithSubdomain", + request: util.NewMsgWithQuestion("www.bild.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "www.bild.de.\t21600\tIN\tA\t0.0.0.0", resp.Answer[0].String()) + }, + }, + { + // no blocking default group with sub domain + name: "noBlockDefaultWithSubdomain", + request: util.NewMsgWithQuestion("bild.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "bild.de.\t250\tIN\tA\t123.124.122.122", resp.Answer[0].String()) + }, + }, + { + // white and block default group + name: "whiteBlackDefault", + request: util.NewMsgWithQuestion("heise.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "heise.de.\t250\tIN\tA\t123.124.122.122", resp.Answer[0].String()) + }, + }, + { + // no block client whitelist only + name: "noBlockWhitelistOnly", + mockClientName: "clWhitelistOnly", + request: util.NewMsgWithQuestion("heise.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "123.124.122.122", resp.Answer[0].(*dns.A).A.String()) + }, + }, + { + // block client whitelist only + name: "blockWhitelistOnly", + mockClientName: "clWhitelistOnly", + request: util.NewMsgWithQuestion("google.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "0.0.0.0", resp.Answer[0].(*dns.A).A.String()) + }, + }, + { + // block client with 2 groups + name: "block2groups1", + mockClientName: "clAdsAndYoutube", + request: util.NewMsgWithQuestion("www.bild.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "0.0.0.0", resp.Answer[0].(*dns.A).A.String()) + }, + }, + { + // block client with 2 groups + name: "block2groups2", + mockClientName: "clAdsAndYoutube", + request: util.NewMsgWithQuestion("youtube.com.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "0.0.0.0", resp.Answer[0].(*dns.A).A.String()) + }, + }, + { + // lient with 1 group: no block if domain in other group + name: "noBlockBlacklistOtherGroup", + mockClientName: "clYoutubeOnly", + request: util.NewMsgWithQuestion("www.bild.de.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "123.124.122.122", resp.Answer[0].(*dns.A).A.String()) + }, + }, + { + // block client with 1 group + name: "blockBlacklist", + mockClientName: "clYoutubeOnly", + request: util.NewMsgWithQuestion("youtube.com.", dns.TypeA), + respValidator: func(t *testing.T, resp *dns.Msg) { + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + assert.Equal(t, "0.0.0.0", resp.Answer[0].(*dns.A).A.String()) + }, + }, +} + +//nolint:funlen +func TestDnsRequest(t *testing.T) { + upstreamGoogle := resolver.TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + response, err := util.NewMsgWithAnswer(fmt.Sprintf("%s %d %s %s %s", + util.ExtractDomain(request.Question[0]), 123, "IN", "A", "123.124.122.122")) + + assert.NoError(t, err) + return response + }) + upstreamFritzbox := resolver.TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + response, err := util.NewMsgWithAnswer(fmt.Sprintf("%s %d %s %s %s", + util.ExtractDomain(request.Question[0]), 3600, "IN", "A", "192.168.178.2")) + + assert.NoError(t, err) + return response + }) + + upstreamClient := resolver.TestUDPUpstream(func(request *dns.Msg) *dns.Msg { + response, err := util.NewMsgWithAnswer(fmt.Sprintf("%s %d %s %s %s", + util.ExtractDomain(request.Question[0]), 3600, "IN", "PTR", mockClientName)) + + assert.NoError(t, err) + return response + }) + + // create server + server, err := NewServer(&config.Config{ + CustomDNS: config.CustomDNSConfig{ + Mapping: map[string]net.IP{ + "custom.lan": net.ParseIP("192.168.178.55"), + "lan.home": net.ParseIP("192.168.178.56"), + }, + }, + Conditional: config.ConditionalUpstreamConfig{ + Mapping: map[string]config.Upstream{"fritz.box": upstreamFritzbox}, + }, + Blocking: config.BlockingConfig{ + BlackLists: map[string][]string{ + "ads": { + "../testdata/doubleclick.net.txt", + "../testdata/www.bild.de.txt", + "../testdata/heise.de.txt"}, + "youtube": {"../testdata/youtube.com.txt"}}, + WhiteLists: map[string][]string{ + "ads": {"../testdata/heise.de.txt"}, + "whitelist": {"../testdata/heise.de.txt"}, + }, + ClientGroupsBlock: map[string][]string{ + "default": {"ads"}, + "clWhitelistOnly": {"whitelist"}, + "clAdsAndYoutube": {"ads", "youtube"}, + "clYoutubeOnly": {"youtube"}, + }, + }, + Upstream: config.UpstreamConfig{ + ExternalResolvers: []config.Upstream{upstreamGoogle}, + }, + ClientLookup: config.ClientLookupConfig{ + Upstream: upstreamClient, + }, + + Port: 55555, + }) + + assert.NoError(t, err) + + // start server + go func() { + server.Start() + }() + + defer server.Stop() + + time.Sleep(100 * time.Millisecond) + + for _, tt := range tests { + tst := tt + t.Run(tt.name, func(t *testing.T) { + res := server.queryResolver + for res != nil { + if t, ok := res.(*resolver.ClientNamesResolver); ok { + t.FlushCache() + break + } + if c, ok := res.(resolver.ChainedResolver); ok { + res = c.GetNext() + } else { + break + } + } + + mockClientName = tst.mockClientName + response := requestServer(tst.request) + + tst.respValidator(t, response) + }) + } +} + +// 26106 43273 ns/op 12518 B/op 137 allocs/op +func BenchmarkServerExternalResolver(b *testing.B) { + msg, _ := util.NewMsgWithAnswer(fmt.Sprintf("example.com IN A 123.124.122.122")) + upstreamExternal := resolver.TestUDPUpstreamWithResponse(msg) + + // create server + server, err := NewServer(&config.Config{ + Upstream: config.UpstreamConfig{ + ExternalResolvers: []config.Upstream{upstreamExternal}, + }, + Port: 55555, + }) + + assert.NoError(b, err) + + // start server + go func() { + server.Start() + }() + + defer server.Stop() + + time.Sleep(100 * time.Millisecond) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = requestServer(util.NewMsgWithQuestion("google.de.", dns.TypeA)) + } + }) +} + +func requestServer(request *dns.Msg) *dns.Msg { + conn, err := net.Dial("udp", ":55555") + if err != nil { + log.Fatal("could not connect to server: ", err) + } + defer conn.Close() + + msg, err := request.Pack() + if err != nil { + log.Fatal("can't pack request: ", err) + } + + _, err = conn.Write(msg) + if err != nil { + log.Fatal("can't send request to server: ", err) + } + + out := make([]byte, 1024) + + if _, err := conn.Read(out); err == nil { + response := new(dns.Msg) + err := response.Unpack(out) + + if err != nil { + log.Fatal("can't unpack response: ", err) + } + + return response + } + + log.Fatal("could not read from connection", err) + + return nil +} diff --git a/testdata/config.yml b/testdata/config.yml new file mode 100644 index 00000000..b24fbd81 --- /dev/null +++ b/testdata/config.yml @@ -0,0 +1,44 @@ +upstream: + externalResolvers: + - udp:8.8.8.8 + - udp:8.8.4.4 + - udp:1.1.1.1 +customDNS: + mapping: + my.duckdns.org: 192.168.178.3 +conditional: + mapping: + fritz.box: udp:192.168.178.1 +blocking: + blackLists: + ads: + - https://s3.amazonaws.com/lists.disconnect.me/simple_ad.txt + - https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts + - https://mirror1.malwaredomains.com/files/justdomains + - http://sysctl.org/cameleon/hosts + - https://zeustracker.abuse.ch/blocklist.php?download=domainblocklist + - https://s3.amazonaws.com/lists.disconnect.me/simple_tracking.txt + special: + - https://hosts-file.net/ad_servers.txt + whiteLists: + ads: + - whitelist.txt + clientGroupsBlock: + default: + - ads + - special + Laptop-D.fritz.box: + - ads + #blockMode: zeroIP +clientLookup: + upstream: udp:192.168.178.1 + singleNameOrder: + - 2 + - 1 + +queryLog: + dir: /opt/log + perClient: true + +port: 55555 +logLevel: debug \ No newline at end of file diff --git a/testdata/doubleclick.net.txt b/testdata/doubleclick.net.txt new file mode 100644 index 00000000..b96a408e --- /dev/null +++ b/testdata/doubleclick.net.txt @@ -0,0 +1 @@ +doubleclick.net \ No newline at end of file diff --git a/testdata/heise.de.txt b/testdata/heise.de.txt new file mode 100644 index 00000000..5eab4544 --- /dev/null +++ b/testdata/heise.de.txt @@ -0,0 +1 @@ +heise.de \ No newline at end of file diff --git a/testdata/www.bild.de.txt b/testdata/www.bild.de.txt new file mode 100644 index 00000000..eaddd548 --- /dev/null +++ b/testdata/www.bild.de.txt @@ -0,0 +1 @@ +www.bild.de \ No newline at end of file diff --git a/testdata/youtube.com.txt b/testdata/youtube.com.txt new file mode 100644 index 00000000..7bfa9c11 --- /dev/null +++ b/testdata/youtube.com.txt @@ -0,0 +1 @@ +youtube.com \ No newline at end of file diff --git a/util/common.go b/util/common.go new file mode 100644 index 00000000..15be4e44 --- /dev/null +++ b/util/common.go @@ -0,0 +1,80 @@ +package util + +import ( + "fmt" + "net" + "strings" + + "github.com/miekg/dns" +) + +func qTypeToString() func(uint16) string { + innerMap := map[uint16]string{ + dns.TypeA: "A", + dns.TypeAAAA: "AAAA", + dns.TypeCNAME: "CNAME", + dns.TypePTR: "PTR", + dns.TypeMX: "MX", + } + + return func(key uint16) string { + return innerMap[key] + } +} + +func AnswerToString(answer []dns.RR) string { + answers := make([]string, len(answer)) + + for i, record := range answer { + switch v := record.(type) { + case *dns.A: + answers[i] = fmt.Sprintf("A (%s)", v.A) + case *dns.AAAA: + answers[i] = fmt.Sprintf("AAAA (%s)", v.AAAA) + case *dns.CNAME: + answers[i] = fmt.Sprintf("CNAME (%s)", v.Target) + case *dns.PTR: + answers[i] = fmt.Sprintf("PTR (%s)", v.Ptr) + default: + answers[i] = fmt.Sprint(record) + } + } + + return strings.Join(answers, ", ") +} + +func QuestionToString(questions []dns.Question) string { + result := make([]string, len(questions)) + for i, question := range questions { + result[i] = fmt.Sprintf("%s (%s)", qTypeToString()(question.Qtype), question.Name) + } + + return strings.Join(result, ", ") +} + +func CreateAnswerFromQuestion(question dns.Question, ip net.IP, remainingTTL uint32) (dns.RR, error) { + return dns.NewRR(fmt.Sprintf("%s %d %s %s %s", question.Name, remainingTTL, "IN", qTypeToString()(question.Qtype), ip)) +} + +func ExtractDomain(question dns.Question) string { + return strings.TrimSuffix(strings.ToLower(question.Name), ".") +} + +func NewMsgWithQuestion(question string, mType uint16) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(question, mType) + + return msg +} + +func NewMsgWithAnswer(answer string) (*dns.Msg, error) { + rr, err := dns.NewRR(answer) + if err != nil { + return nil, err + } + + msg := new(dns.Msg) + msg.Answer = []dns.RR{rr} + + return msg, nil +}