diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ec941e14..653d4034 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -31,9 +31,7 @@ "GitHub.vscode-github-actions" ], "settings": { - "go.lintFlags": [ - "--config=${containerWorkspaceFolder}/.golangci.yml" - ], + "go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"], "go.alternateTools": { "go-langserver": "gopls" }, diff --git a/.devcontainer/scripts/runItOnGo.sh b/.devcontainer/scripts/runItOnGo.sh index ce57fe93..54bf1d81 100644 --- a/.devcontainer/scripts/runItOnGo.sh +++ b/.devcontainer/scripts/runItOnGo.sh @@ -2,17 +2,17 @@ FOLDER_PATH=$1 if [ -z "${FOLDER_PATH}" ]; then - FOLDER_PATH=$PWD + FOLDER_PATH=$PWD fi BASE_PATH=$2 if [ -z "${BASE_PATH}" ]; then - BASE_PATH=$WORKSPACE_FOLDER + BASE_PATH=$WORKSPACE_FOLDER fi if [ "$FOLDER_PATH" = "$BASE_PATH" ]; then - echo "Skipping lcov creation for base path" - exit 1 + echo "Skipping lcov creation for base path" + exit 1 fi FOLDER_NAME=${FOLDER_PATH#"$BASE_PATH/"} diff --git a/config/caching.go b/config/caching.go index c942398c..0d86a383 100644 --- a/config/caching.go +++ b/config/caching.go @@ -20,7 +20,7 @@ type CachingConfig struct { // IsEnabled implements `config.Configurable`. func (c *CachingConfig) IsEnabled() bool { - return c.MaxCachingTime.IsAboveZero() + return c.MaxCachingTime.IsAtLeastZero() } // LogConfig implements `config.Configurable`. diff --git a/config/caching_test.go b/config/caching_test.go index 1cc2c635..8be9dace 100644 --- a/config/caching_test.go +++ b/config/caching_test.go @@ -20,11 +20,22 @@ var _ = Describe("CachingConfig", func() { }) Describe("IsEnabled", func() { - It("should be false by default", func() { + It("should be true by default", func() { cfg := CachingConfig{} Expect(defaults.Set(&cfg)).Should(Succeed()) - Expect(cfg.IsEnabled()).Should(BeFalse()) + Expect(cfg.IsEnabled()).Should(BeTrue()) + }) + + When("the config is disabled", func() { + BeforeEach(func() { + cfg = CachingConfig{ + MaxCachingTime: Duration(time.Hour * -1), + } + }) + It("should be false", func() { + Expect(cfg.IsEnabled()).Should(BeFalse()) + }) }) When("the config is enabled", func() { @@ -72,7 +83,7 @@ var _ = Describe("CachingConfig", func() { Expect(cfg.Prefetching).Should(BeTrue()) Expect(cfg.PrefetchThreshold).Should(Equal(0)) - Expect(cfg.MaxCachingTime).ShouldNot(BeZero()) + Expect(cfg.MaxCachingTime).Should(BeZero()) }) }) }) diff --git a/config/config.go b/config/config.go index 6d249d70..0b885e65 100644 --- a/config/config.go +++ b/config/config.go @@ -211,6 +211,7 @@ type Config struct { FqdnOnly FqdnOnlyConfig `yaml:"fqdnOnly"` Filtering FilteringConfig `yaml:"filtering"` Ede EdeConfig `yaml:"ede"` + ECS ECSConfig `yaml:"ecs"` SUDN SUDNConfig `yaml:"specialUseDomains"` // Deprecated options diff --git a/config/duration.go b/config/duration.go index c8ef09a7..1a2c7a9b 100644 --- a/config/duration.go +++ b/config/duration.go @@ -8,24 +8,35 @@ import ( "github.com/hako/durafmt" ) +// Duration is a wrapper for time.Duration to support yaml unmarshalling type Duration time.Duration +// ToDuration converts Duration to time.Duration func (c Duration) ToDuration() time.Duration { return time.Duration(c) } +// IsAboveZero returns true if duration is above zero func (c Duration) IsAboveZero() bool { return c.ToDuration() > 0 } +// IsAtLeastZero returns true if duration is at least zero +func (c Duration) IsAtLeastZero() bool { + return c.ToDuration() >= 0 +} + +// Seconds returns duration in seconds func (c Duration) Seconds() float64 { return c.ToDuration().Seconds() } +// SecondsU32 returns duration in seconds as uint32 func (c Duration) SecondsU32() uint32 { return uint32(c.Seconds()) } +// String implements `fmt.Stringer` func (c Duration) String() string { return durafmt.Parse(c.ToDuration()).String() } diff --git a/config/ecs.go b/config/ecs.go new file mode 100644 index 00000000..042098d5 --- /dev/null +++ b/config/ecs.go @@ -0,0 +1,81 @@ +package config + +import ( + "fmt" + "strconv" + + "github.com/sirupsen/logrus" +) + +const ( + ecsIpv4MaskMax = uint8(32) + ecsIpv6MaskMax = uint8(128) +) + +// ECSv4Mask is the subnet mask to be added as EDNS0 option for IPv4 +type ECSv4Mask uint8 + +// UnmarshalText implements the encoding.TextUnmarshaler interface +func (x *ECSv4Mask) UnmarshalText(text []byte) error { + res, err := unmarshalInternal(text, ecsIpv4MaskMax, "IPv4") + if err != nil { + return err + } + + *x = ECSv4Mask(res) + + return nil +} + +// ECSv6Mask is the subnet mask to be added as EDNS0 option for IPv6 +type ECSv6Mask uint8 + +// UnmarshalText implements the encoding.TextUnmarshaler interface +func (x *ECSv6Mask) UnmarshalText(text []byte) error { + res, err := unmarshalInternal(text, ecsIpv6MaskMax, "IPv6") + if err != nil { + return err + } + + *x = ECSv6Mask(res) + + return nil +} + +// ECSConfig is the configuration of the ECS resolver +type ECSConfig struct { + UseAsClient bool `yaml:"useAsClient" default:"false"` + Forward bool `yaml:"forward" default:"false"` + IPv4Mask ECSv4Mask `yaml:"ipv4Mask" default:"0"` + IPv6Mask ECSv6Mask `yaml:"ipv6Mask" default:"0"` +} + +// IsEnabled returns true if the ECS resolver is enabled +func (c *ECSConfig) IsEnabled() bool { + return c.UseAsClient || c.Forward || c.IPv4Mask > 0 || c.IPv6Mask > 0 +} + +// LogConfig logs the configuration +func (c *ECSConfig) LogConfig(logger *logrus.Entry) { + logger.Infof("Use as client = %t", c.UseAsClient) + logger.Infof("Forward = %t", c.Forward) + logger.Infof("IPv4 netmask = %d", c.IPv4Mask) + logger.Infof("IPv6 netmask = %d", c.IPv6Mask) +} + +// unmarshalInternal unmarshals the subnet mask from the given text and checks if the value is valid +// it is used by the UnmarshalText methods of ECSv4Mask and ECSv6Mask +func unmarshalInternal(text []byte, maxvalue uint8, name string) (uint8, error) { + strVal := string(text) + + uiVal, err := strconv.ParseUint(strVal, 10, 8) + if err != nil { + return 0, err + } + + if uiVal > uint64(maxvalue) { + return 0, fmt.Errorf("mask value (%s) is too large for %s(max: %d)", strVal, name, maxvalue) + } + + return uint8(uiVal), nil +} diff --git a/config/ecs_test.go b/config/ecs_test.go new file mode 100644 index 00000000..372ccc05 --- /dev/null +++ b/config/ecs_test.go @@ -0,0 +1,167 @@ +package config + +import ( + "github.com/0xERR0R/blocky/log" + "github.com/creasty/defaults" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v2" +) + +var _ = Describe("ECSConfig", func() { + var ( + c ECSConfig + err error + ) + + BeforeEach(func() { + err = defaults.Set(&c) + Expect(err).Should(Succeed()) + }) + + Describe("IsEnabled", func() { + When("all fields are default", func() { + It("should be disabled", func() { + Expect(c.IsEnabled()).Should(BeFalse()) + }) + }) + + When("UseEcsAsClient is true", func() { + BeforeEach(func() { + c.UseAsClient = true + }) + + It("should be enabled", func() { + Expect(c.IsEnabled()).Should(BeTrue()) + }) + }) + + When("ForwardEcs is true", func() { + BeforeEach(func() { + c.Forward = true + }) + + It("should be enabled", func() { + Expect(c.IsEnabled()).Should(BeTrue()) + }) + }) + + When("IPv4Mask is set", func() { + BeforeEach(func() { + c.IPv4Mask = 24 + }) + + It("should be enabled", func() { + Expect(c.IsEnabled()).Should(BeTrue()) + }) + }) + + When("IPv6Mask is set", func() { + BeforeEach(func() { + c.IPv6Mask = 64 + }) + + It("should be enabled", func() { + Expect(c.IsEnabled()).Should(BeTrue()) + }) + }) + }) + + Describe("LogConfig", func() { + BeforeEach(func() { + logger, hook = log.NewMockEntry() + }) + + It("should log configuration", func() { + c.LogConfig(logger) + + Expect(hook.Calls).Should(HaveLen(4)) + Expect(hook.Messages).Should(SatisfyAll( + ContainElement(ContainSubstring("Use as client")), + ContainElement(ContainSubstring("Forward")), + ContainElement(ContainSubstring("IPv4 netmask")), + ContainElement(ContainSubstring("IPv6 netmask")), + )) + }) + }) + + Describe("Parse", func() { + var data []byte + + Context("IPv4Mask", func() { + var ipmask ECSv4Mask + + When("Parse correct value", func() { + BeforeEach(func() { + data = []byte("24") + err = yaml.Unmarshal(data, &ipmask) + Expect(err).Should(Succeed()) + }) + + It("should be value", func() { + Expect(ipmask).Should(Equal(ECSv4Mask(24))) + }) + }) + + When("Parse NaN value", func() { + BeforeEach(func() { + data = []byte("FALSE") + err = yaml.Unmarshal(data, &ipmask) + }) + + It("should be error", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + + When("Parse incorrect value", func() { + BeforeEach(func() { + data = []byte("35") + err = yaml.Unmarshal(data, &ipmask) + }) + + It("should be error", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + }) + + Context("IPv6Mask", func() { + var ipmask ECSv6Mask + + When("Parse correct value", func() { + BeforeEach(func() { + data = []byte("64") + err = yaml.Unmarshal(data, &ipmask) + Expect(err).Should(Succeed()) + }) + + It("should be value", func() { + Expect(ipmask).Should(Equal(ECSv6Mask(64))) + }) + }) + + When("Parse NaN value", func() { + BeforeEach(func() { + data = []byte("FALSE") + err = yaml.Unmarshal(data, &ipmask) + }) + + It("should be error", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + + When("Parse incorrect value", func() { + BeforeEach(func() { + data = []byte("130") + err = yaml.Unmarshal(data, &ipmask) + }) + + It("should be error", func() { + Expect(err).Should(HaveOccurred()) + }) + }) + }) + }) +}) diff --git a/docs/configuration.md b/docs/configuration.md index eca5bd0c..db717afe 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -12,7 +12,7 @@ configuration properties as [JSON](config.yml). ## Basic configuration | Parameter | Type | Mandatory | Default value | Description | -|---------------------|---------------------|-----------|---------------|------------------------------------------------------------------------------------------------------------| +| ------------------- | ------------------- | --------- | ------------- | ---------------------------------------------------------------------------------------------------------- | | certFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated | | keyFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated | | dohUserAgent | string | no | | HTTP User Agent for DoH upstreams | @@ -31,8 +31,8 @@ configuration properties as [JSON](config.yml). All logging port are optional. -| Parameter | Type | Default value | Description | -|------------|------------------------|---------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Parameter | Type | Default value | Description | +| ----------- | ---------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ports.dns | [IP]:port[,[IP]:port]* | 53 | Port(s) and optional bind ip address(es) to serve DNS endpoint (TCP and UDP). If you wish to specify a specific IP, you can do so such as `192.168.0.1:53`. Example: `53`, `:53`, `127.0.0.1:53,[::1]:53` | | ports.tls | [IP]:port[,[IP]:port]* | | Port(s) and optional bind ip address(es) to serve DoT DNS endpoint (DNS-over-TLS). If you wish to specify a specific IP, you can do so such as `192.168.0.1:853`. Example: `83`, `:853`, `127.0.0.1:853,[::1]:853` | | ports.http | [IP]:port[,[IP]:port]* | | Port(s) and optional bind ip address(es) to serve HTTP used for prometheus metrics, pprof, REST API, DoH... If you wish to specify a specific IP, you can do so such as `192.168.0.1:4000`. Example: `4000`, `:4000`, `127.0.0.1:4000,[::1]:4000` | @@ -52,7 +52,7 @@ All logging port are optional. All logging options are optional. | Parameter | Type | Default value | Description | -|---------------|---------------------------------|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------| +| ------------- | ------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ | | log.level | enum (debug, info, warn, error) | info | Log level | | log.format | enum (text, json) | text | Log format (text or json). | | log.timestamp | bool | true | Log time stamps (true or false). | @@ -86,7 +86,7 @@ following network protocols (net part of the resolver URL): Each resolver must be defined as a string in following format: `[net:]host:[port][/path][#commonName]`. | Parameter | Type | Mandatory | Default value | -|------------|----------------------------------|-----------|---------------------------------------------------| +| ---------- | -------------------------------- | --------- | ------------------------------------------------- | | net | enum (tcp+udp, tcp-tls or https) | no | tcp+udp | | host | IP or hostname | yes | | | port | int (1 - 65535) | no | 53 for udp/tcp, 853 for tcp-tls and 443 for https | @@ -182,7 +182,7 @@ These DNS servers are used to resolve upstream DoH and DoT servers that are spec It is useful if no system DNS resolver is configured, and/or to encrypt the bootstrap queries. | Parameter | Type | Mandatory | Default value | Description | -|-----------|----------------------|-----------------------------|---------------|--------------------------------------| +| --------- | -------------------- | --------------------------- | ------------- | ------------------------------------ | | upstream | Upstream (see above) | no | | | | ips | List of IPs | yes, if upstream is DoT/DoH | | Only valid if upstream is DoH or DoT | @@ -237,7 +237,7 @@ or define a domain name for your local device on order to use the HTTPS certific domain must be separated by a comma. | Parameter | Type | Mandatory | Default value | -|---------------------|-----------------------------------------|-----------|---------------| +| ------------------- | --------------------------------------- | --------- | ------------- | | customTTL | duration (no unit is minutes) | no | 1h | | rewrite | string: string (domain: domain) | no | | | mapping | string: string (hostname: address list) | no | | @@ -481,7 +481,7 @@ You can configure, which response should be sent to the client, if a requested q queries, NXDOMAIN for other types): | blockType | Example | Description | -|------------|---------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| ---------- | ------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | zeroIP | zeroIP | This is the default block type. Server returns 0.0.0.0 (or :: for IPv6) as result for A and AAAA queries | | nxDomain | nxDomain | return NXDOMAIN as return code | | custom IPs | 192.100.100.15, 2001:0db8:85a3:08d3:1319:8a2e:0370:7344 | comma separated list of destination IP addresses. Should contain ipv4 and ipv6 to cover all query types. Useful with running web server on this address to display the "blocked" page. | @@ -526,7 +526,7 @@ With following parameters you can tune the caching behavior: Wrong values can significantly increase external DNS traffic or memory consumption. | Parameter | Type | Mandatory | Default value | Description | -|-------------------------------|-----------------|-----------|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| ----------------------------- | --------------- | --------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | caching.minTime | duration format | no | 0 (use TTL) | How long a response must be cached (min value). If <=0, use response's TTL, if >0 use this value, if TTL is smaller | | caching.maxTime | duration format | no | 0 (use TTL) | How long a response must be cached (max value). If <0, do not cache responses. If 0, use TTL. If > 0, use this value, if TTL is greater | | caching.maxItemsCount | int | no | 0 (unlimited) | Max number of cache entries (responses) to be kept in cache (soft limit). Default (0): unlimited. Useful on systems with limited amount of RAM. | @@ -551,7 +551,7 @@ Blocky can synchronize its cache and blocking state between multiple instances t Synchronization is disabled if no address is configured. | Parameter | Type | Mandatory | Default value | Description | -|--------------------------|-----------------|-----------|---------------|---------------------------------------------------------------------| +| ------------------------ | --------------- | --------- | ------------- | ------------------------------------------------------------------- | | redis.address | string | no | | Server address and port or master name if sentinel is used | | redis.username | string | no | | Username if necessary | | redis.password | string | no | | Password if necessary | @@ -588,7 +588,7 @@ Blocky can expose various metrics for prometheus. To use the prometheus feature, see [Basic Configuration](#basic-configuration)). | Parameter | Mandatory | Default value | Description | -|-------------------|-----------|---------------|-------------------------------------| +| ----------------- | --------- | ------------- | ----------------------------------- | | prometheus.enable | no | false | If true, enables prometheus metrics | | prometheus.path | no | /metrics | URL path to the metrics endpoint | @@ -637,7 +637,7 @@ You can choose which information from processed DNS request and response should Configuration parameters: | Parameter | Type | Mandatory | Default value | Description | -|---------------------------|--------------------------------------------------------------------------------------|-----------|---------------|------------------------------------------------------------------------------------| +| ------------------------- | ------------------------------------------------------------------------------------ | --------- | ------------- | ---------------------------------------------------------------------------------- | | queryLog.type | enum (mysql, postgresql, csv, csv-client, console, none (see above)) | no | | Type of logging target. Console if empty | | queryLog.target | string | no | | directory for writing the logs (for csv) or database url (for mysql or postgresql) | | queryLog.logRetentionDays | int | no | 0 | if > 0, deletes log files/database entries which are older than ... days | @@ -682,7 +682,7 @@ You can enable resolving of entries, located in local hosts file. Configuration parameters: | Parameter | Type | Mandatory | Default value | Description | -|--------------------------|--------------------------------|-----------|---------------|-------------------------------------------------| +| ------------------------ | ------------------------------ | --------- | ------------- | ----------------------------------------------- | | hostsFile.filePath | string | no | | Path to hosts file (e.g. /etc/hosts on Linux) | | hostsFile.hostsTTL | duration (no units is minutes) | no | 1h | TTL | | hostsFile.refreshPeriod | duration format | no | 1h | Time between hosts file refresh | @@ -704,7 +704,7 @@ DNS responses can be extended with EDE codes according to [RFC8914](https://data Configuration parameters: | Parameter | Type | Mandatory | Default value | Description | -|------------|------|-----------|---------------|----------------------------------------------------| +| ---------- | ---- | --------- | ------------- | -------------------------------------------------- | | ede.enable | bool | no | false | If true, DNS responses are deliverd with EDE codes | !!! example @@ -714,6 +714,25 @@ Configuration parameters: enable: true ``` +## EDNS Client Subnet options + +EDNS Client Subnet (ECS) configuration parameters: + +| Parameter | Type | Mandatory | Default value | Description | +| --------------- | ---- | --------- | ------------- | --------------------------------------------------------------------------------------------- | +| ecs.useAsClient | bool | no | false | Use ECS information if it is present with a netmask is 32 for IPv4 or 128 for IPv6 as CientIP | +| ecs.forward | bool | no | false | Forward ECS option to upstream | +| ecs.ipv4Mask | int | no | 0 | Add ECS option for IPv4 requests if mask is greater than zero (max value 32) | +| ecs.ipv6Mask | int | no | 0 | Add ECS option for IPv6 requests if mask is greater than zero (max value 128) | + +!!! example + + ```yaml + ecs: + ipv4Mask: 32 + ipv6Mask: 128 + ``` + ## Special Use Domain Names SUDN (Special Use Domain Names) are always enabled as they are required by various RFCs. @@ -722,7 +741,7 @@ Some RFCs have optional recommendations, which are configurable as described bel Configuration parameters: | Parameter | Type | Mandatory | Default value | Description | -|-------------------------------------|------|-----------|---------------|-----------------------------------------------------------------------------------------------| +| ----------------------------------- | ---- | --------- | ------------- | --------------------------------------------------------------------------------------------- | | specialUseDomains.rfc6762-appendixG | bool | no | true | Block TLDs listed in [RFC 6762 Appendix G](https://www.rfc-editor.org/rfc/rfc6762#appendix-G) | !!! example @@ -801,7 +820,7 @@ A value of zero or less will disable this feature. Configures how HTTP(S) sources are downloaded: | Parameter | Type | Mandatory | Default value | Description | -|-----------|----------|-----------|---------------|------------------------------------------------| +| --------- | -------- | --------- | ------------- | ---------------------------------------------- | | timeout | duration | no | 5s | Download attempt timeout | | attempts | int | no | 3 | How many download attempts should be performed | | cooldown | duration | no | 500ms | Time between the download attempts | @@ -822,7 +841,7 @@ This configures how Blocky startup works. The default strategy is blocking. | strategy | Description | -|-------------|------------------------------------------------------------------------------------------------------------------------------------------| +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------- | | blocking | all sources are loaded before DNS resolution starts | | failOnError | like blocking but blocky will shut down if any source fails to load | | fast | blocky starts serving DNS immediately and sources are loaded asynchronously. The features requiring the sources should enable soon after | diff --git a/helpertest/helper.go b/helpertest/helper.go index 7a1c00e4..f1f709e6 100644 --- a/helpertest/helper.go +++ b/helpertest/helper.go @@ -120,6 +120,32 @@ func HaveReturnCode(code int) types.GomegaMatcher { ) } +// HaveEdnsOption checks if the given message contains an EDNS0 record with the given option code. +func HaveEdnsOption(code uint16) types.GomegaMatcher { + return gcustom.MakeMatcher(func(actual any) (bool, error) { + var opt *dns.OPT + switch msg := actual.(type) { + case *model.Response: + opt = msg.Res.IsEdns0() + case *dns.Msg: + opt = msg.IsEdns0() + } + + if opt != nil { + for _, o := range opt.Option { + if o.Option() == code { + return true, nil + } + } + } + + return false, nil + }).WithTemplate( + "Expected:\n{{.Actual}}\n{{.To}} have EDNS option:\n{{format .Data 1}}", + code, + ) +} + func toFirstRR(actual interface{}) (dns.RR, error) { switch i := actual.(type) { case *model.Response: diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index e818546e..739d3ed8 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -157,12 +157,12 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) { logger.Infof("cache entries = %d", r.resultCache.TotalCount()) } -// Resolve checks if the current query result is already in the cache and returns it -// or delegates to the next resolver +// Resolve checks if the current query should use the cache and if the result is already in +// the cache and returns it or delegates to the next resolver func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) { logger := log.WithPrefix(request.Log, "caching_resolver") - if r.cfg.MaxCachingTime < 0 { + if !r.IsEnabled() || !isRequestCacheable(request) { logger.Debug("skip cache") return r.next.Resolve(request) @@ -232,22 +232,21 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) { } } -// removes EDNS OPT records from message -func removeEdns0Extra(msg *dns.Msg) { - if len(msg.Extra) > 0 { - extra := make([]dns.RR, 0, len(msg.Extra)) - - for _, rr := range msg.Extra { - if rr.Header().Rrtype != dns.TypeOPT { - extra = append(extra, rr) - } +// isRequestCacheable returns true if the request should be cached +func isRequestCacheable(request *model.Request) bool { + // don't cache responses with EDNS Client Subnet option with masks that include more than one client + if so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req); so != nil { + if (so.Family == ecsFamilyIPv4 && so.SourceNetmask != ecsMaskIPv4) || + (so.Family == ecsFamilyIPv6 && so.SourceNetmask != ecsMaskIPv6) { + return false } - - msg.Extra = extra } + + return true } -func shouldBeCached(msg *dns.Msg) bool { +// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set. +func isResponseCacheable(msg *dns.Msg) bool { // we don't cache truncated responses and responses with CD flag return !msg.Truncated && !msg.CheckingDisabled } @@ -258,13 +257,13 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response, respCopy := response.Res.Copy() // don't cache any EDNS OPT records - removeEdns0Extra(respCopy) + util.RemoveEdns0Record(respCopy) packed, err := respCopy.Pack() util.LogOnError("error on packing", err) if err == nil { - if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) { + if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) { // put value into cache r.resultCache.Put(cacheKey, &packed, ttl) } else if response.Res.Rcode == dns.RcodeNameError { diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index acc7be64..7f28bc29 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -3,6 +3,7 @@ package resolver import ( "context" "fmt" + "net" "time" "github.com/0xERR0R/blocky/config" @@ -56,8 +57,19 @@ var _ = Describe("CachingResolver", func() { }) Describe("IsEnabled", func() { - It("is false", func() { - Expect(sut.IsEnabled()).Should(BeFalse()) + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + + When("max caching time is negative", func() { + BeforeEach(func() { + sutConfig = config.CachingConfig{ + MaxCachingTime: config.Duration(time.Minute * -1), + } + }) + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) }) }) @@ -769,4 +781,39 @@ var _ = Describe("CachingResolver", func() { }) }) }) + Context("isRequestCacheable", func() { + var request *Request + When("request is not cacheable", func() { + BeforeEach(func() { + request = newRequest("example.com.", A) + e := new(dns.EDNS0_SUBNET) + e.SourceScope = 0 + e.Address = net.ParseIP("192.168.0.0") + e.Family = 1 + e.SourceNetmask = 24 + util.SetEdns0Option(request.Req, e) + }) + + It("should return false", func() { + Expect(isRequestCacheable(request)). + Should(BeFalse()) + }) + }) + When("request is cacheable", func() { + BeforeEach(func() { + request = newRequest("example.com.", A) + e := new(dns.EDNS0_SUBNET) + e.SourceScope = 0 + e.Address = net.ParseIP("192.168.0.10") + e.Family = 1 + e.SourceNetmask = 32 + util.SetEdns0Option(request.Req, e) + }) + + It("should return true", func() { + Expect(isRequestCacheable(request)). + Should(BeTrue()) + }) + }) + }) }) diff --git a/resolver/ecs_resolver.go b/resolver/ecs_resolver.go new file mode 100644 index 00000000..9a241550 --- /dev/null +++ b/resolver/ecs_resolver.go @@ -0,0 +1,104 @@ +package resolver + +import ( + "fmt" + "net" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" + "github.com/miekg/dns" +) + +// https://www.rfc-editor.org/rfc/rfc7871.html#section-6 +const ( + ecsSourceScope = uint8(0) + + ecsMaskIPv4 = uint8(net.IPv4len * 8) + ecsMaskIPv6 = uint8(net.IPv6len * 8) +) + +// https://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml +const ( + ecsFamilyIPv4 = uint16(iota + 1) + ecsFamilyIPv6 +) + +// ECSMask is an interface for all ECS subnet masks as type constraint for generics +type ECSMask interface { + config.ECSv4Mask | config.ECSv6Mask +} + +// ECSResolver is responsible for adding the EDNS Client Subnet information as EDNS0 option. +type ECSResolver struct { + configurable[*config.ECSConfig] + NextResolver + typed +} + +// NewECSResolver creates new resolver instance which adds the subnet information as EDNS0 option +func NewECSResolver(cfg config.ECSConfig) ChainedResolver { + return &ECSResolver{ + configurable: withConfig(&cfg), + typed: withType("extended_client_subnet"), + } +} + +// Resolve adds the subnet information as EDNS0 option to the request of the next resolver +// and sets the client IP from the EDNS0 option to the request if this option is enabled +func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) { + if r.cfg.IsEnabled() { + so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req) + // Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set + if r.cfg.UseAsClient && so != nil && ((so.Family == ecsFamilyIPv4 && so.SourceNetmask == ecsMaskIPv4) || + (so.Family == ecsFamilyIPv6 && so.SourceNetmask == ecsMaskIPv6)) { + request.ClientIP = so.Address + } + + // Set the Edns0 subnet option if the client IP is IPv4 or IPv6 and the masks are set in the configuration + if r.cfg.IPv4Mask > 0 || r.cfg.IPv6Mask > 0 { + r.setSubnet(request) + } + + // Remove the Edns0 subnet option if the client IP is IPv4 or IPv6 and the corresponding mask is not set + // and the forwardEcs option is not enabled + if r.cfg.IPv4Mask == 0 && r.cfg.IPv6Mask == 0 && so != nil && !r.cfg.Forward { + util.RemoveEdns0Option[*dns.EDNS0_SUBNET](request.Req) + } + } + + return r.next.Resolve(request) +} + +// setSubnet appends the subnet information to the request as EDNS0 option +// if the client IP is IPv4 or IPv6 and the corresponding mask is set in the configuration +func (r *ECSResolver) setSubnet(request *model.Request) { + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + e.SourceScope = ecsSourceScope + + if ip := request.ClientIP.To4(); ip != nil && r.cfg.IPv4Mask > 0 { + mip, err := maskIP(ip, r.cfg.IPv4Mask) + if err == nil { + e.Family = ecsFamilyIPv4 + e.SourceNetmask = uint8(r.cfg.IPv4Mask) + e.Address = mip + util.SetEdns0Option(request.Req, e) + } + } else if ip := request.ClientIP.To16(); ip != nil && r.cfg.IPv6Mask > 0 { + mip, err := maskIP(ip, r.cfg.IPv6Mask) + if err == nil { + e.Family = ecsFamilyIPv6 + e.SourceNetmask = uint8(r.cfg.IPv6Mask) + e.Address = mip + util.SetEdns0Option(request.Req, e) + } + } +} + +// maskIP masks the IP with the given mask and return an error if the mask is invalid +func maskIP[maskType ECSMask](ip net.IP, mask maskType) (net.IP, error) { + _, mip, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), mask)) + + return mip.IP, err +} diff --git a/resolver/ecs_resolver_test.go b/resolver/ecs_resolver_test.go new file mode 100644 index 00000000..b10efb77 --- /dev/null +++ b/resolver/ecs_resolver_test.go @@ -0,0 +1,200 @@ +package resolver + +import ( + "net" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/util" + "github.com/creasty/defaults" + + . "github.com/0xERR0R/blocky/helpertest" + . "github.com/0xERR0R/blocky/model" + + "github.com/miekg/dns" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/stretchr/testify/mock" +) + +var _ = Describe("EcsResolver", func() { + var ( + sut *ECSResolver + sutConfig config.ECSConfig + m *mockResolver + mockAnswer *dns.Msg + err error + origIP net.IP + ecsIP net.IP + ) + + Describe("Type", func() { + It("follows conventions", func() { + expectValidResolverType(sut) + }) + }) + + BeforeEach(func() { + err = defaults.Set(&sutConfig) + Expect(err).ShouldNot(HaveOccurred()) + + mockAnswer = new(dns.Msg) + origIP = net.ParseIP("1.2.3.4") + ecsIP = net.ParseIP("4.3.2.1") + }) + + JustBeforeEach(func() { + if m == nil { + m = &mockResolver{} + m.On("Resolve", mock.Anything).Return(&Response{ + Res: mockAnswer, + RType: ResponseTypeCUSTOMDNS, + Reason: "Test", + }, nil) + } + + sut = NewECSResolver(sutConfig).(*ECSResolver) + sut.Next(m) + }) + + When("ecs is disabled", func() { + Describe("IsEnabled", func() { + It("is false", func() { + Expect(sut.IsEnabled()).Should(BeFalse()) + }) + }) + }) + + When("ecs is enabled", func() { + BeforeEach(func() { + sutConfig.UseAsClient = true + }) + + Describe("IsEnabled", func() { + It("is true", func() { + Expect(sut.IsEnabled()).Should(BeTrue()) + }) + }) + + When("use ecs client ip is enabled", func() { + BeforeEach(func() { + sutConfig.UseAsClient = true + }) + + It("should change ClientIP with subnet 32", func() { + request := newRequest("example.com.", A) + request.ClientIP = origIP + + addEcsOption(request.Req, ecsIP, ecsMaskIPv4) + + m.ResolveFn = func(req *Request) (*Response, error) { + Expect(req.ClientIP).Should(Equal(ecsIP)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + + It("shouldn't change ClientIP with subnet 24", func() { + request := newRequest("example.com.", A) + request.ClientIP = origIP + + addEcsOption(request.Req, ecsIP, 24) + + m.ResolveFn = func(req *Request) (*Response, error) { + Expect(req.ClientIP).Should(Equal(origIP)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + }) + + When("forward ecs is enabled", func() { + BeforeEach(func() { + sutConfig.Forward = true + sutConfig.IPv4Mask = 32 + sutConfig.IPv6Mask = 128 + }) + + It("should add Ecs information with subnet 32", func() { + request := newRequest("example.com.", A) + request.ClientIP = origIP + + m.ResolveFn = func(req *Request) (*Response, error) { + Expect(req.ClientIP).Should(Equal(origIP)) + Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + + It("should add Ecs information with subnet 128", func() { + request := newRequest("example.com.", AAAA) + request.ClientIP = net.ParseIP("2001:db8::68") + + m.ResolveFn = func(req *Request) (*Response, error) { + Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + + return respondWith(mockAnswer), nil + } + + Expect(sut.Resolve(request)). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveResponseType(ResponseTypeRESOLVED), + HaveReturnCode(dns.RcodeSuccess), + HaveReason("Test"))) + }) + }) + }) + + Context("maskIP", func() { + It("should mask IPv4", func() { + ip := net.ParseIP("192.168.10.123") + mask := config.ECSv4Mask(24) + + mip, err := maskIP(ip, mask) + Expect(err).ShouldNot(HaveOccurred()) + Expect(mip).Should(Equal(net.ParseIP("192.168.10.0").To4())) + }) + }) +}) + +// addEcsOption adds the subnet information to the request as EDNS0 option +func addEcsOption(req *dns.Msg, ip net.IP, netmask uint8) { + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + e.SourceScope = ecsSourceScope + e.Family = ecsFamilyIPv4 + e.SourceNetmask = netmask + e.Address = ip + util.SetEdns0Option(req, e) +} + +// respondWith creates a new Response with the given request and message +func respondWith(res *dns.Msg) *Response { + return &Response{Res: res, RType: ResponseTypeRESOLVED, Reason: "Test"} +} diff --git a/resolver/ede_resolver.go b/resolver/ede_resolver.go index 5e458d91..f6b507ff 100644 --- a/resolver/ede_resolver.go +++ b/resolver/ede_resolver.go @@ -3,15 +3,19 @@ package resolver import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/model" + "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" ) +// A EdeResolver is responsible for adding the reason for the response as EDNS0 option type EdeResolver struct { configurable[*config.EdeConfig] NextResolver typed } +// NewEdeResolver creates new resolver instance which adds the reason for +// the response as EDNS0 option to the response if it is enabled in the configuration func NewEdeResolver(cfg config.EdeConfig) *EdeResolver { return &EdeResolver{ configurable: withConfig(&cfg), @@ -19,6 +23,8 @@ func NewEdeResolver(cfg config.EdeConfig) *EdeResolver { } } +// Resolve adds the reason as EDNS0 option to the response of the next resolver +// if it is enabled in the configuration func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) { if !r.cfg.Enable { return r.next.Resolve(request) @@ -34,6 +40,7 @@ func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) { return resp, nil } +// addExtraReasoning adds the reason for the response as EDNS0 option func (r *EdeResolver) addExtraReasoning(res *model.Response) { infocode := res.RType.ToExtendedErrorCode() @@ -42,12 +49,9 @@ func (r *EdeResolver) addExtraReasoning(res *model.Response) { return } - opt := new(dns.OPT) - opt.Hdr.Name = "." - opt.Hdr.Rrtype = dns.TypeOPT - opt.Option = append(opt.Option, &dns.EDNS0_EDE{ - InfoCode: infocode, - ExtraText: res.Reason, - }) - res.Res.Extra = append(res.Res.Extra, opt) + edeOption := new(dns.EDNS0_EDE) + edeOption.InfoCode = infocode + edeOption.ExtraText = res.Reason + + util.SetEdns0Option(res.Res, edeOption) } diff --git a/resolver/ede_resolver_test.go b/resolver/ede_resolver_test.go index fc7e7ad5..cad90600 100644 --- a/resolver/ede_resolver_test.go +++ b/resolver/ede_resolver_test.go @@ -1,11 +1,14 @@ +// Description: Tests for ede_resolver.go package resolver import ( "errors" + "math" "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" + "github.com/0xERR0R/blocky/util" . "github.com/0xERR0R/blocky/model" @@ -60,7 +63,7 @@ var _ = Describe("EdeResolver", func() { HaveNoAnswer(), HaveResponseType(ResponseTypeCUSTOMDNS), HaveReturnCode(dns.RcodeSuccess), - WithTransform(ToExtra, BeEmpty()), + Not(HaveEdnsOption(dns.EDNS0EDE)), )) // delegated to next resolver @@ -81,8 +84,8 @@ var _ = Describe("EdeResolver", func() { } }) - extractFirstOptRecord := func(e []dns.RR) []dns.EDNS0 { - return e[0].(*dns.OPT).Option + extractEdeOption := func(res *Response) dns.EDNS0_EDE { + return *util.GetEdns0Option[*dns.EDNS0_EDE](res.Res) } It("should add EDE information", func() { @@ -92,19 +95,39 @@ var _ = Describe("EdeResolver", func() { HaveNoAnswer(), HaveResponseType(ResponseTypeCUSTOMDNS), HaveReturnCode(dns.RcodeSuccess), - // extra should contain one OPT record - WithTransform(ToExtra, + HaveEdnsOption(dns.EDNS0EDE), + WithTransform(extractEdeOption, SatisfyAll( - HaveLen(1), - WithTransform(extractFirstOptRecord, - SatisfyAll( - ContainElement(HaveField("InfoCode", Equal(dns.ExtendedErrorCodeForgedAnswer))), - ContainElement(HaveField("ExtraText", Equal("Test"))), - )), + HaveField("InfoCode", Equal(dns.ExtendedErrorCodeForgedAnswer)), + HaveField("ExtraText", Equal("Test")), )), )) }) + When("resolver returns other", func() { + BeforeEach(func() { + m = &mockResolver{} + m.On("Resolve", mock.Anything).Return(&Response{ + Res: mockAnswer, + RType: ResponseType(math.MaxInt), + Reason: "Test", + }, nil) + }) + + It("shouldn't add EDE information", func() { + Expect(sut.Resolve(newRequest("example.com.", A))). + Should( + SatisfyAll( + HaveNoAnswer(), + HaveReturnCode(dns.RcodeSuccess), + Not(HaveEdnsOption(dns.EDNS0EDE)), + )) + + // delegated to next resolver + Expect(m.Calls).Should(HaveLen(1)) + }) + }) + When("resolver returns an error", func() { resolveErr := errors.New("test") diff --git a/server/server.go b/server/server.go index 9105aec7..a091b6f3 100644 --- a/server/server.go +++ b/server/server.go @@ -415,6 +415,7 @@ func createQueryResolver( r = resolver.Chain( resolver.NewFilteringResolver(cfg.Filtering), resolver.NewFqdnOnlyResolver(cfg.FqdnOnly), + resolver.NewECSResolver(cfg.ECS), clientNames, resolver.NewEdeResolver(cfg.Ede), resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog), diff --git a/util/edns0.go b/util/edns0.go new file mode 100644 index 00000000..a4e86efd --- /dev/null +++ b/util/edns0.go @@ -0,0 +1,128 @@ +package util + +import ( + "fmt" + "slices" + + "github.com/miekg/dns" +) + +// EDNS0Option is an interface for all EDNS0 options as type constraint for generics. +type EDNS0Option interface { + *dns.EDNS0_SUBNET | *dns.EDNS0_EDE | *dns.EDNS0_LOCAL | *dns.EDNS0_NSID | *dns.EDNS0_COOKIE | *dns.EDNS0_UL + Option() uint16 +} + +// RemoveEdns0Record removes the OPT record from the Extra section of the given message. +// If the OPT record is removed, true will be returned. +func RemoveEdns0Record(msg *dns.Msg) bool { + if msg == nil || msg.IsEdns0() == nil { + return false + } + + for i, rr := range msg.Extra { + if rr.Header().Rrtype == dns.TypeOPT { + msg.Extra = slices.Delete(msg.Extra, i, i+1) + + return true + } + } + + return false +} + +// GetEdns0Option returns the option with the given code from the OPT record in the +// Extra section of the given message. +// If the option is not found, nil will be returned. +func GetEdns0Option[T EDNS0Option](msg *dns.Msg) T { + if msg == nil { + return nil + } + + opt := msg.IsEdns0() + if opt == nil { + return nil + } + + var t T + + for _, o := range opt.Option { + if o.Option() == t.Option() { + t, ok := o.(T) + if !ok { + panic(fmt.Errorf("dns option with code %d is not of type %T", t.Option(), t)) + } + + return t + } + } + + return nil +} + +// RemoveEdns0Option removes the option according to the given type from the OPT record +// in the Extra section of the given message. +// If there are no more options in the OPT record, the OPT record will be removed. +// If the option is successfully removed, true will be returned. +func RemoveEdns0Option[T EDNS0Option](msg *dns.Msg) bool { + if msg == nil { + return false + } + + opt := msg.IsEdns0() + if opt == nil { + return false + } + + res := false + + var t T + + for i, o := range opt.Option { + if o.Option() == t.Option() { + opt.Option = slices.Delete(opt.Option, i, i+1) + + res = true + + break + } + } + + if len(opt.Option) == 0 { + RemoveEdns0Record(msg) + } + + return res +} + +// SetEdns0Option adds the given option to the OPT record in the Extra section of the +// given message. +// If the option already exists, it will be replaced. +// If the option is successfully set, true will be returned. +func SetEdns0Option(msg *dns.Msg, opt dns.EDNS0) bool { + if msg == nil || opt == nil { + return false + } + + optRecord := msg.IsEdns0() + + if optRecord == nil { + optRecord = new(dns.OPT) + optRecord.Hdr.Name = "." + optRecord.Hdr.Rrtype = dns.TypeOPT + msg.Extra = append(msg.Extra, optRecord) + } + + newOpts := make([]dns.EDNS0, 0, len(optRecord.Option)+1) + + for _, o := range optRecord.Option { + if o.Option() != opt.Option() { + newOpts = append(newOpts, o) + } + } + + newOpts = append(newOpts, opt) + optRecord.Option = newOpts + + return true +} diff --git a/util/edns0_test.go b/util/edns0_test.go new file mode 100644 index 00000000..2129bb93 --- /dev/null +++ b/util/edns0_test.go @@ -0,0 +1,230 @@ +package util + +import ( + "net" + + . "github.com/0xERR0R/blocky/helpertest" + "github.com/miekg/dns" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const ( + exampleDomain = "example.com" + testTxt = "test" +) + +var _ = Describe("EDNS0 utils", func() { + var baseMsg *dns.Msg + + BeforeEach(func() { + baseMsg = new(dns.Msg) + txt := new(dns.TXT) + txt.Hdr.Name = exampleDomain + "." + txt.Hdr.Rrtype = dns.TypeTXT + txt.Txt = []string{testTxt} + baseMsg.Extra = append(baseMsg.Extra, txt) + }) + + Describe("RemoveEdns0Record", func() { + When("OPT record is present", func() { + BeforeEach(func() { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + baseMsg.Extra = append(baseMsg.Extra, opt) + }) + + It("should remove it", func() { + Expect(RemoveEdns0Record(baseMsg)).Should(BeTrue()) + + Expect(baseMsg.IsEdns0()).Should(BeNil()) + }) + }) + + When("OPT record is not present", func() { + It("should do nothing ", func() { + Expect(RemoveEdns0Record(baseMsg)).Should(BeFalse()) + }) + }) + + When("Extra is nil", func() { + BeforeEach(func() { + baseMsg.Extra = nil + }) + + It("should do nothing", func() { + Expect(RemoveEdns0Record(baseMsg)).Should(BeFalse()) + }) + }) + + When("message is nil", func() { + It("should do nothing", func() { + Expect(RemoveEdns0Record(nil)).Should(BeFalse()) + }) + }) + }) + + Describe("GetEdns0Option", func() { + When("Option is present", func() { + var eso *dns.EDNS0_SUBNET + + BeforeEach(func() { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + eso = new(dns.EDNS0_SUBNET) + eso.Code = dns.EDNS0SUBNET + eso.Address = net.ParseIP("192.168.0.0") + eso.Family = 1 + eso.SourceNetmask = 24 + opt.Option = append(opt.Option, eso) + baseMsg.Extra = append(baseMsg.Extra, opt) + }) + + It("should return it", func() { + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso)) + }) + }) + + When("Option is not present", func() { + BeforeEach(func() { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + opt.Option = append(opt.Option, new(dns.EDNS0_EDE)) + baseMsg.Extra = append(baseMsg.Extra, opt) + }) + + It("should return nil", func() { + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeNil()) + }) + }) + + When("Extra is nil", func() { + BeforeEach(func() { + baseMsg.Extra = nil + }) + + It("should return nil", func() { + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeNil()) + }) + }) + + When("message is nil", func() { + It("should return nil", func() { + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](nil)).Should(BeNil()) + }) + }) + }) + + Describe("RemoveEdns0Option", func() { + When("Option is present", func() { + BeforeEach(func() { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + eso := new(dns.EDNS0_SUBNET) + eso.Code = dns.EDNS0SUBNET + opt.Option = append(opt.Option, eso) + baseMsg.Extra = append(baseMsg.Extra, opt) + }) + + It("should remove it", func() { + Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeTrue()) + + Expect(baseMsg).ShouldNot(HaveEdnsOption(dns.EDNS0SUBNET)) + }) + }) + + When("Option is not present", func() { + BeforeEach(func() { + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + opt.Option = append(opt.Option, new(dns.EDNS0_EDE)) + baseMsg.Extra = append(baseMsg.Extra, opt) + }) + It("should return false", func() { + Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeFalse()) + }) + }) + + When("Extra is nil", func() { + BeforeEach(func() { + baseMsg.Extra = nil + }) + + It("should return false", func() { + Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeFalse()) + }) + }) + + When("message is nil", func() { + It("should return false", func() { + Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](nil)).Should(BeFalse()) + }) + }) + }) + + Describe("SetEdns0Option", func() { + When("Option is not present", func() { + var eso *dns.EDNS0_SUBNET + + BeforeEach(func() { + Expect(baseMsg).ShouldNot(HaveEdnsOption(dns.EDNS0SUBNET)) + Expect(SetEdns0Option(baseMsg, new(dns.EDNS0_EDE))).Should(BeTrue()) + + eso = new(dns.EDNS0_SUBNET) + eso.Code = dns.EDNS0SUBNET + }) + + It("should add the option", func() { + Expect(SetEdns0Option(baseMsg, eso)).Should(BeTrue()) + + Expect(baseMsg).Should(HaveEdnsOption(dns.EDNS0SUBNET)) + }) + }) + + When("Option is present", func() { + var ( + eso *dns.EDNS0_SUBNET + eso2 *dns.EDNS0_SUBNET + ) + + BeforeEach(func() { + eso = new(dns.EDNS0_SUBNET) + eso.Code = dns.EDNS0SUBNET + eso.Address = net.ParseIP("1.1.1.1") + eso.Family = 1 + eso.SourceNetmask = 32 + + eso2 = new(dns.EDNS0_SUBNET) + eso2.Code = dns.EDNS0SUBNET + eso2.Address = net.ParseIP("2.2.2.2") + eso2.Family = 1 + eso2.SourceNetmask = 32 + + Expect(SetEdns0Option(baseMsg, eso)).Should(BeTrue()) + }) + + It("should replace it", func() { + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso)) + Expect(SetEdns0Option(baseMsg, eso2)).Should(BeTrue()) + Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso2)) + }) + }) + + When("message is nil", func() { + It("should return false", func() { + Expect(SetEdns0Option(nil, new(dns.EDNS0_SUBNET))).Should(BeFalse()) + }) + }) + + When("option is nil", func() { + It("should do nothing if option is nil", func() { + Expect(SetEdns0Option(baseMsg, nil)).Should(BeFalse()) + }) + }) + }) +})