EDNS: Client Subnet (#1007)

* added util for handling EDNS0 options

* disable caching if the request contains a netmask size greater than 1

* added config section for ECS handling and validation for it

*added ecs_resolver for enhancing and cleaning subnet and client IP information
This commit is contained in:
Kwitsch 2023-11-20 16:56:56 +01:00 committed by GitHub
parent d52c598546
commit d37d18348f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1116 additions and 66 deletions

View File

@ -31,9 +31,7 @@
"GitHub.vscode-github-actions"
],
"settings": {
"go.lintFlags": [
"--config=${containerWorkspaceFolder}/.golangci.yml"
],
"go.lintFlags": ["--config=${containerWorkspaceFolder}/.golangci.yml"],
"go.alternateTools": {
"go-langserver": "gopls"
},

View File

@ -2,17 +2,17 @@
FOLDER_PATH=$1
if [ -z "${FOLDER_PATH}" ]; then
FOLDER_PATH=$PWD
FOLDER_PATH=$PWD
fi
BASE_PATH=$2
if [ -z "${BASE_PATH}" ]; then
BASE_PATH=$WORKSPACE_FOLDER
BASE_PATH=$WORKSPACE_FOLDER
fi
if [ "$FOLDER_PATH" = "$BASE_PATH" ]; then
echo "Skipping lcov creation for base path"
exit 1
echo "Skipping lcov creation for base path"
exit 1
fi
FOLDER_NAME=${FOLDER_PATH#"$BASE_PATH/"}

View File

@ -20,7 +20,7 @@ type CachingConfig struct {
// IsEnabled implements `config.Configurable`.
func (c *CachingConfig) IsEnabled() bool {
return c.MaxCachingTime.IsAboveZero()
return c.MaxCachingTime.IsAtLeastZero()
}
// LogConfig implements `config.Configurable`.

View File

@ -20,11 +20,22 @@ var _ = Describe("CachingConfig", func() {
})
Describe("IsEnabled", func() {
It("should be false by default", func() {
It("should be true by default", func() {
cfg := CachingConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Expect(cfg.IsEnabled()).Should(BeFalse())
Expect(cfg.IsEnabled()).Should(BeTrue())
})
When("the config is disabled", func() {
BeforeEach(func() {
cfg = CachingConfig{
MaxCachingTime: Duration(time.Hour * -1),
}
})
It("should be false", func() {
Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
When("the config is enabled", func() {
@ -72,7 +83,7 @@ var _ = Describe("CachingConfig", func() {
Expect(cfg.Prefetching).Should(BeTrue())
Expect(cfg.PrefetchThreshold).Should(Equal(0))
Expect(cfg.MaxCachingTime).ShouldNot(BeZero())
Expect(cfg.MaxCachingTime).Should(BeZero())
})
})
})

View File

@ -211,6 +211,7 @@ type Config struct {
FqdnOnly FqdnOnlyConfig `yaml:"fqdnOnly"`
Filtering FilteringConfig `yaml:"filtering"`
Ede EdeConfig `yaml:"ede"`
ECS ECSConfig `yaml:"ecs"`
SUDN SUDNConfig `yaml:"specialUseDomains"`
// Deprecated options

View File

@ -8,24 +8,35 @@ import (
"github.com/hako/durafmt"
)
// Duration is a wrapper for time.Duration to support yaml unmarshalling
type Duration time.Duration
// ToDuration converts Duration to time.Duration
func (c Duration) ToDuration() time.Duration {
return time.Duration(c)
}
// IsAboveZero returns true if duration is above zero
func (c Duration) IsAboveZero() bool {
return c.ToDuration() > 0
}
// IsAtLeastZero returns true if duration is at least zero
func (c Duration) IsAtLeastZero() bool {
return c.ToDuration() >= 0
}
// Seconds returns duration in seconds
func (c Duration) Seconds() float64 {
return c.ToDuration().Seconds()
}
// SecondsU32 returns duration in seconds as uint32
func (c Duration) SecondsU32() uint32 {
return uint32(c.Seconds())
}
// String implements `fmt.Stringer`
func (c Duration) String() string {
return durafmt.Parse(c.ToDuration()).String()
}

81
config/ecs.go Normal file
View File

@ -0,0 +1,81 @@
package config
import (
"fmt"
"strconv"
"github.com/sirupsen/logrus"
)
const (
ecsIpv4MaskMax = uint8(32)
ecsIpv6MaskMax = uint8(128)
)
// ECSv4Mask is the subnet mask to be added as EDNS0 option for IPv4
type ECSv4Mask uint8
// UnmarshalText implements the encoding.TextUnmarshaler interface
func (x *ECSv4Mask) UnmarshalText(text []byte) error {
res, err := unmarshalInternal(text, ecsIpv4MaskMax, "IPv4")
if err != nil {
return err
}
*x = ECSv4Mask(res)
return nil
}
// ECSv6Mask is the subnet mask to be added as EDNS0 option for IPv6
type ECSv6Mask uint8
// UnmarshalText implements the encoding.TextUnmarshaler interface
func (x *ECSv6Mask) UnmarshalText(text []byte) error {
res, err := unmarshalInternal(text, ecsIpv6MaskMax, "IPv6")
if err != nil {
return err
}
*x = ECSv6Mask(res)
return nil
}
// ECSConfig is the configuration of the ECS resolver
type ECSConfig struct {
UseAsClient bool `yaml:"useAsClient" default:"false"`
Forward bool `yaml:"forward" default:"false"`
IPv4Mask ECSv4Mask `yaml:"ipv4Mask" default:"0"`
IPv6Mask ECSv6Mask `yaml:"ipv6Mask" default:"0"`
}
// IsEnabled returns true if the ECS resolver is enabled
func (c *ECSConfig) IsEnabled() bool {
return c.UseAsClient || c.Forward || c.IPv4Mask > 0 || c.IPv6Mask > 0
}
// LogConfig logs the configuration
func (c *ECSConfig) LogConfig(logger *logrus.Entry) {
logger.Infof("Use as client = %t", c.UseAsClient)
logger.Infof("Forward = %t", c.Forward)
logger.Infof("IPv4 netmask = %d", c.IPv4Mask)
logger.Infof("IPv6 netmask = %d", c.IPv6Mask)
}
// unmarshalInternal unmarshals the subnet mask from the given text and checks if the value is valid
// it is used by the UnmarshalText methods of ECSv4Mask and ECSv6Mask
func unmarshalInternal(text []byte, maxvalue uint8, name string) (uint8, error) {
strVal := string(text)
uiVal, err := strconv.ParseUint(strVal, 10, 8)
if err != nil {
return 0, err
}
if uiVal > uint64(maxvalue) {
return 0, fmt.Errorf("mask value (%s) is too large for %s(max: %d)", strVal, name, maxvalue)
}
return uint8(uiVal), nil
}

167
config/ecs_test.go Normal file
View File

@ -0,0 +1,167 @@
package config
import (
"github.com/0xERR0R/blocky/log"
"github.com/creasty/defaults"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"gopkg.in/yaml.v2"
)
var _ = Describe("ECSConfig", func() {
var (
c ECSConfig
err error
)
BeforeEach(func() {
err = defaults.Set(&c)
Expect(err).Should(Succeed())
})
Describe("IsEnabled", func() {
When("all fields are default", func() {
It("should be disabled", func() {
Expect(c.IsEnabled()).Should(BeFalse())
})
})
When("UseEcsAsClient is true", func() {
BeforeEach(func() {
c.UseAsClient = true
})
It("should be enabled", func() {
Expect(c.IsEnabled()).Should(BeTrue())
})
})
When("ForwardEcs is true", func() {
BeforeEach(func() {
c.Forward = true
})
It("should be enabled", func() {
Expect(c.IsEnabled()).Should(BeTrue())
})
})
When("IPv4Mask is set", func() {
BeforeEach(func() {
c.IPv4Mask = 24
})
It("should be enabled", func() {
Expect(c.IsEnabled()).Should(BeTrue())
})
})
When("IPv6Mask is set", func() {
BeforeEach(func() {
c.IPv6Mask = 64
})
It("should be enabled", func() {
Expect(c.IsEnabled()).Should(BeTrue())
})
})
})
Describe("LogConfig", func() {
BeforeEach(func() {
logger, hook = log.NewMockEntry()
})
It("should log configuration", func() {
c.LogConfig(logger)
Expect(hook.Calls).Should(HaveLen(4))
Expect(hook.Messages).Should(SatisfyAll(
ContainElement(ContainSubstring("Use as client")),
ContainElement(ContainSubstring("Forward")),
ContainElement(ContainSubstring("IPv4 netmask")),
ContainElement(ContainSubstring("IPv6 netmask")),
))
})
})
Describe("Parse", func() {
var data []byte
Context("IPv4Mask", func() {
var ipmask ECSv4Mask
When("Parse correct value", func() {
BeforeEach(func() {
data = []byte("24")
err = yaml.Unmarshal(data, &ipmask)
Expect(err).Should(Succeed())
})
It("should be value", func() {
Expect(ipmask).Should(Equal(ECSv4Mask(24)))
})
})
When("Parse NaN value", func() {
BeforeEach(func() {
data = []byte("FALSE")
err = yaml.Unmarshal(data, &ipmask)
})
It("should be error", func() {
Expect(err).Should(HaveOccurred())
})
})
When("Parse incorrect value", func() {
BeforeEach(func() {
data = []byte("35")
err = yaml.Unmarshal(data, &ipmask)
})
It("should be error", func() {
Expect(err).Should(HaveOccurred())
})
})
})
Context("IPv6Mask", func() {
var ipmask ECSv6Mask
When("Parse correct value", func() {
BeforeEach(func() {
data = []byte("64")
err = yaml.Unmarshal(data, &ipmask)
Expect(err).Should(Succeed())
})
It("should be value", func() {
Expect(ipmask).Should(Equal(ECSv6Mask(64)))
})
})
When("Parse NaN value", func() {
BeforeEach(func() {
data = []byte("FALSE")
err = yaml.Unmarshal(data, &ipmask)
})
It("should be error", func() {
Expect(err).Should(HaveOccurred())
})
})
When("Parse incorrect value", func() {
BeforeEach(func() {
data = []byte("130")
err = yaml.Unmarshal(data, &ipmask)
})
It("should be error", func() {
Expect(err).Should(HaveOccurred())
})
})
})
})
})

View File

@ -12,7 +12,7 @@ configuration properties as [JSON](config.yml).
## Basic configuration
| Parameter | Type | Mandatory | Default value | Description |
|---------------------|---------------------|-----------|---------------|------------------------------------------------------------------------------------------------------------|
| ------------------- | ------------------- | --------- | ------------- | ---------------------------------------------------------------------------------------------------------- |
| certFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated |
| keyFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated |
| dohUserAgent | string | no | | HTTP User Agent for DoH upstreams |
@ -31,8 +31,8 @@ configuration properties as [JSON](config.yml).
All logging port are optional.
| Parameter | Type | Default value | Description |
|------------|------------------------|---------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Parameter | Type | Default value | Description |
| ----------- | ---------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ports.dns | [IP]:port[,[IP]:port]* | 53 | Port(s) and optional bind ip address(es) to serve DNS endpoint (TCP and UDP). If you wish to specify a specific IP, you can do so such as `192.168.0.1:53`. Example: `53`, `:53`, `127.0.0.1:53,[::1]:53` |
| ports.tls | [IP]:port[,[IP]:port]* | | Port(s) and optional bind ip address(es) to serve DoT DNS endpoint (DNS-over-TLS). If you wish to specify a specific IP, you can do so such as `192.168.0.1:853`. Example: `83`, `:853`, `127.0.0.1:853,[::1]:853` |
| ports.http | [IP]:port[,[IP]:port]* | | Port(s) and optional bind ip address(es) to serve HTTP used for prometheus metrics, pprof, REST API, DoH... If you wish to specify a specific IP, you can do so such as `192.168.0.1:4000`. Example: `4000`, `:4000`, `127.0.0.1:4000,[::1]:4000` |
@ -52,7 +52,7 @@ All logging port are optional.
All logging options are optional.
| Parameter | Type | Default value | Description |
|---------------|---------------------------------|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| ------------- | ------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ |
| log.level | enum (debug, info, warn, error) | info | Log level |
| log.format | enum (text, json) | text | Log format (text or json). |
| log.timestamp | bool | true | Log time stamps (true or false). |
@ -86,7 +86,7 @@ following network protocols (net part of the resolver URL):
Each resolver must be defined as a string in following format: `[net:]host:[port][/path][#commonName]`.
| Parameter | Type | Mandatory | Default value |
|------------|----------------------------------|-----------|---------------------------------------------------|
| ---------- | -------------------------------- | --------- | ------------------------------------------------- |
| net | enum (tcp+udp, tcp-tls or https) | no | tcp+udp |
| host | IP or hostname | yes | |
| port | int (1 - 65535) | no | 53 for udp/tcp, 853 for tcp-tls and 443 for https |
@ -182,7 +182,7 @@ These DNS servers are used to resolve upstream DoH and DoT servers that are spec
It is useful if no system DNS resolver is configured, and/or to encrypt the bootstrap queries.
| Parameter | Type | Mandatory | Default value | Description |
|-----------|----------------------|-----------------------------|---------------|--------------------------------------|
| --------- | -------------------- | --------------------------- | ------------- | ------------------------------------ |
| upstream | Upstream (see above) | no | | |
| ips | List of IPs | yes, if upstream is DoT/DoH | | Only valid if upstream is DoH or DoT |
@ -237,7 +237,7 @@ or define a domain name for your local device on order to use the HTTPS certific
domain must be separated by a comma.
| Parameter | Type | Mandatory | Default value |
|---------------------|-----------------------------------------|-----------|---------------|
| ------------------- | --------------------------------------- | --------- | ------------- |
| customTTL | duration (no unit is minutes) | no | 1h |
| rewrite | string: string (domain: domain) | no | |
| mapping | string: string (hostname: address list) | no | |
@ -481,7 +481,7 @@ You can configure, which response should be sent to the client, if a requested q
queries, NXDOMAIN for other types):
| blockType | Example | Description |
|------------|---------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| ---------- | ------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| zeroIP | zeroIP | This is the default block type. Server returns 0.0.0.0 (or :: for IPv6) as result for A and AAAA queries |
| nxDomain | nxDomain | return NXDOMAIN as return code |
| custom IPs | 192.100.100.15, 2001:0db8:85a3:08d3:1319:8a2e:0370:7344 | comma separated list of destination IP addresses. Should contain ipv4 and ipv6 to cover all query types. Useful with running web server on this address to display the "blocked" page. |
@ -526,7 +526,7 @@ With following parameters you can tune the caching behavior:
Wrong values can significantly increase external DNS traffic or memory consumption.
| Parameter | Type | Mandatory | Default value | Description |
|-------------------------------|-----------------|-----------|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| ----------------------------- | --------------- | --------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| caching.minTime | duration format | no | 0 (use TTL) | How long a response must be cached (min value). If <=0, use response's TTL, if >0 use this value, if TTL is smaller |
| caching.maxTime | duration format | no | 0 (use TTL) | How long a response must be cached (max value). If <0, do not cache responses. If 0, use TTL. If > 0, use this value, if TTL is greater |
| caching.maxItemsCount | int | no | 0 (unlimited) | Max number of cache entries (responses) to be kept in cache (soft limit). Default (0): unlimited. Useful on systems with limited amount of RAM. |
@ -551,7 +551,7 @@ Blocky can synchronize its cache and blocking state between multiple instances t
Synchronization is disabled if no address is configured.
| Parameter | Type | Mandatory | Default value | Description |
|--------------------------|-----------------|-----------|---------------|---------------------------------------------------------------------|
| ------------------------ | --------------- | --------- | ------------- | ------------------------------------------------------------------- |
| redis.address | string | no | | Server address and port or master name if sentinel is used |
| redis.username | string | no | | Username if necessary |
| redis.password | string | no | | Password if necessary |
@ -588,7 +588,7 @@ Blocky can expose various metrics for prometheus. To use the prometheus feature,
see [Basic Configuration](#basic-configuration)).
| Parameter | Mandatory | Default value | Description |
|-------------------|-----------|---------------|-------------------------------------|
| ----------------- | --------- | ------------- | ----------------------------------- |
| prometheus.enable | no | false | If true, enables prometheus metrics |
| prometheus.path | no | /metrics | URL path to the metrics endpoint |
@ -637,7 +637,7 @@ You can choose which information from processed DNS request and response should
Configuration parameters:
| Parameter | Type | Mandatory | Default value | Description |
|---------------------------|--------------------------------------------------------------------------------------|-----------|---------------|------------------------------------------------------------------------------------|
| ------------------------- | ------------------------------------------------------------------------------------ | --------- | ------------- | ---------------------------------------------------------------------------------- |
| queryLog.type | enum (mysql, postgresql, csv, csv-client, console, none (see above)) | no | | Type of logging target. Console if empty |
| queryLog.target | string | no | | directory for writing the logs (for csv) or database url (for mysql or postgresql) |
| queryLog.logRetentionDays | int | no | 0 | if > 0, deletes log files/database entries which are older than ... days |
@ -682,7 +682,7 @@ You can enable resolving of entries, located in local hosts file.
Configuration parameters:
| Parameter | Type | Mandatory | Default value | Description |
|--------------------------|--------------------------------|-----------|---------------|-------------------------------------------------|
| ------------------------ | ------------------------------ | --------- | ------------- | ----------------------------------------------- |
| hostsFile.filePath | string | no | | Path to hosts file (e.g. /etc/hosts on Linux) |
| hostsFile.hostsTTL | duration (no units is minutes) | no | 1h | TTL |
| hostsFile.refreshPeriod | duration format | no | 1h | Time between hosts file refresh |
@ -704,7 +704,7 @@ DNS responses can be extended with EDE codes according to [RFC8914](https://data
Configuration parameters:
| Parameter | Type | Mandatory | Default value | Description |
|------------|------|-----------|---------------|----------------------------------------------------|
| ---------- | ---- | --------- | ------------- | -------------------------------------------------- |
| ede.enable | bool | no | false | If true, DNS responses are deliverd with EDE codes |
!!! example
@ -714,6 +714,25 @@ Configuration parameters:
enable: true
```
## EDNS Client Subnet options
EDNS Client Subnet (ECS) configuration parameters:
| Parameter | Type | Mandatory | Default value | Description |
| --------------- | ---- | --------- | ------------- | --------------------------------------------------------------------------------------------- |
| ecs.useAsClient | bool | no | false | Use ECS information if it is present with a netmask is 32 for IPv4 or 128 for IPv6 as CientIP |
| ecs.forward | bool | no | false | Forward ECS option to upstream |
| ecs.ipv4Mask | int | no | 0 | Add ECS option for IPv4 requests if mask is greater than zero (max value 32) |
| ecs.ipv6Mask | int | no | 0 | Add ECS option for IPv6 requests if mask is greater than zero (max value 128) |
!!! example
```yaml
ecs:
ipv4Mask: 32
ipv6Mask: 128
```
## Special Use Domain Names
SUDN (Special Use Domain Names) are always enabled as they are required by various RFCs.
@ -722,7 +741,7 @@ Some RFCs have optional recommendations, which are configurable as described bel
Configuration parameters:
| Parameter | Type | Mandatory | Default value | Description |
|-------------------------------------|------|-----------|---------------|-----------------------------------------------------------------------------------------------|
| ----------------------------------- | ---- | --------- | ------------- | --------------------------------------------------------------------------------------------- |
| specialUseDomains.rfc6762-appendixG | bool | no | true | Block TLDs listed in [RFC 6762 Appendix G](https://www.rfc-editor.org/rfc/rfc6762#appendix-G) |
!!! example
@ -801,7 +820,7 @@ A value of zero or less will disable this feature.
Configures how HTTP(S) sources are downloaded:
| Parameter | Type | Mandatory | Default value | Description |
|-----------|----------|-----------|---------------|------------------------------------------------|
| --------- | -------- | --------- | ------------- | ---------------------------------------------- |
| timeout | duration | no | 5s | Download attempt timeout |
| attempts | int | no | 3 | How many download attempts should be performed |
| cooldown | duration | no | 500ms | Time between the download attempts |
@ -822,7 +841,7 @@ This configures how Blocky startup works.
The default strategy is blocking.
| strategy | Description |
|-------------|------------------------------------------------------------------------------------------------------------------------------------------|
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| blocking | all sources are loaded before DNS resolution starts |
| failOnError | like blocking but blocky will shut down if any source fails to load |
| fast | blocky starts serving DNS immediately and sources are loaded asynchronously. The features requiring the sources should enable soon after |

View File

@ -120,6 +120,32 @@ func HaveReturnCode(code int) types.GomegaMatcher {
)
}
// HaveEdnsOption checks if the given message contains an EDNS0 record with the given option code.
func HaveEdnsOption(code uint16) types.GomegaMatcher {
return gcustom.MakeMatcher(func(actual any) (bool, error) {
var opt *dns.OPT
switch msg := actual.(type) {
case *model.Response:
opt = msg.Res.IsEdns0()
case *dns.Msg:
opt = msg.IsEdns0()
}
if opt != nil {
for _, o := range opt.Option {
if o.Option() == code {
return true, nil
}
}
}
return false, nil
}).WithTemplate(
"Expected:\n{{.Actual}}\n{{.To}} have EDNS option:\n{{format .Data 1}}",
code,
)
}
func toFirstRR(actual interface{}) (dns.RR, error) {
switch i := actual.(type) {
case *model.Response:

View File

@ -157,12 +157,12 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {
logger.Infof("cache entries = %d", r.resultCache.TotalCount())
}
// Resolve checks if the current query result is already in the cache and returns it
// or delegates to the next resolver
// Resolve checks if the current query should use the cache and if the result is already in
// the cache and returns it or delegates to the next resolver
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")
if r.cfg.MaxCachingTime < 0 {
if !r.IsEnabled() || !isRequestCacheable(request) {
logger.Debug("skip cache")
return r.next.Resolve(request)
@ -232,22 +232,21 @@ func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) {
}
}
// removes EDNS OPT records from message
func removeEdns0Extra(msg *dns.Msg) {
if len(msg.Extra) > 0 {
extra := make([]dns.RR, 0, len(msg.Extra))
for _, rr := range msg.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
extra = append(extra, rr)
}
// isRequestCacheable returns true if the request should be cached
func isRequestCacheable(request *model.Request) bool {
// don't cache responses with EDNS Client Subnet option with masks that include more than one client
if so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req); so != nil {
if (so.Family == ecsFamilyIPv4 && so.SourceNetmask != ecsMaskIPv4) ||
(so.Family == ecsFamilyIPv6 && so.SourceNetmask != ecsMaskIPv6) {
return false
}
msg.Extra = extra
}
return true
}
func shouldBeCached(msg *dns.Msg) bool {
// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set.
func isResponseCacheable(msg *dns.Msg) bool {
// we don't cache truncated responses and responses with CD flag
return !msg.Truncated && !msg.CheckingDisabled
}
@ -258,13 +257,13 @@ func (r *CachingResolver) putInCache(cacheKey string, response *model.Response,
respCopy := response.Res.Copy()
// don't cache any EDNS OPT records
removeEdns0Extra(respCopy)
util.RemoveEdns0Record(respCopy)
packed, err := respCopy.Pack()
util.LogOnError("error on packing", err)
if err == nil {
if response.Res.Rcode == dns.RcodeSuccess && shouldBeCached(response.Res) {
if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) {
// put value into cache
r.resultCache.Put(cacheKey, &packed, ttl)
} else if response.Res.Rcode == dns.RcodeNameError {

View File

@ -3,6 +3,7 @@ package resolver
import (
"context"
"fmt"
"net"
"time"
"github.com/0xERR0R/blocky/config"
@ -56,8 +57,19 @@ var _ = Describe("CachingResolver", func() {
})
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
When("max caching time is negative", func() {
BeforeEach(func() {
sutConfig = config.CachingConfig{
MaxCachingTime: config.Duration(time.Minute * -1),
}
})
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
})
@ -769,4 +781,39 @@ var _ = Describe("CachingResolver", func() {
})
})
})
Context("isRequestCacheable", func() {
var request *Request
When("request is not cacheable", func() {
BeforeEach(func() {
request = newRequest("example.com.", A)
e := new(dns.EDNS0_SUBNET)
e.SourceScope = 0
e.Address = net.ParseIP("192.168.0.0")
e.Family = 1
e.SourceNetmask = 24
util.SetEdns0Option(request.Req, e)
})
It("should return false", func() {
Expect(isRequestCacheable(request)).
Should(BeFalse())
})
})
When("request is cacheable", func() {
BeforeEach(func() {
request = newRequest("example.com.", A)
e := new(dns.EDNS0_SUBNET)
e.SourceScope = 0
e.Address = net.ParseIP("192.168.0.10")
e.Family = 1
e.SourceNetmask = 32
util.SetEdns0Option(request.Req, e)
})
It("should return true", func() {
Expect(isRequestCacheable(request)).
Should(BeTrue())
})
})
})
})

104
resolver/ecs_resolver.go Normal file
View File

@ -0,0 +1,104 @@
package resolver
import (
"fmt"
"net"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
)
// https://www.rfc-editor.org/rfc/rfc7871.html#section-6
const (
ecsSourceScope = uint8(0)
ecsMaskIPv4 = uint8(net.IPv4len * 8)
ecsMaskIPv6 = uint8(net.IPv6len * 8)
)
// https://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml
const (
ecsFamilyIPv4 = uint16(iota + 1)
ecsFamilyIPv6
)
// ECSMask is an interface for all ECS subnet masks as type constraint for generics
type ECSMask interface {
config.ECSv4Mask | config.ECSv6Mask
}
// ECSResolver is responsible for adding the EDNS Client Subnet information as EDNS0 option.
type ECSResolver struct {
configurable[*config.ECSConfig]
NextResolver
typed
}
// NewECSResolver creates new resolver instance which adds the subnet information as EDNS0 option
func NewECSResolver(cfg config.ECSConfig) ChainedResolver {
return &ECSResolver{
configurable: withConfig(&cfg),
typed: withType("extended_client_subnet"),
}
}
// Resolve adds the subnet information as EDNS0 option to the request of the next resolver
// and sets the client IP from the EDNS0 option to the request if this option is enabled
func (r *ECSResolver) Resolve(request *model.Request) (*model.Response, error) {
if r.cfg.IsEnabled() {
so := util.GetEdns0Option[*dns.EDNS0_SUBNET](request.Req)
// Set the client IP from the Edns0 subnet option if the option is enabled and the correct subnet mask is set
if r.cfg.UseAsClient && so != nil && ((so.Family == ecsFamilyIPv4 && so.SourceNetmask == ecsMaskIPv4) ||
(so.Family == ecsFamilyIPv6 && so.SourceNetmask == ecsMaskIPv6)) {
request.ClientIP = so.Address
}
// Set the Edns0 subnet option if the client IP is IPv4 or IPv6 and the masks are set in the configuration
if r.cfg.IPv4Mask > 0 || r.cfg.IPv6Mask > 0 {
r.setSubnet(request)
}
// Remove the Edns0 subnet option if the client IP is IPv4 or IPv6 and the corresponding mask is not set
// and the forwardEcs option is not enabled
if r.cfg.IPv4Mask == 0 && r.cfg.IPv6Mask == 0 && so != nil && !r.cfg.Forward {
util.RemoveEdns0Option[*dns.EDNS0_SUBNET](request.Req)
}
}
return r.next.Resolve(request)
}
// setSubnet appends the subnet information to the request as EDNS0 option
// if the client IP is IPv4 or IPv6 and the corresponding mask is set in the configuration
func (r *ECSResolver) setSubnet(request *model.Request) {
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
e.SourceScope = ecsSourceScope
if ip := request.ClientIP.To4(); ip != nil && r.cfg.IPv4Mask > 0 {
mip, err := maskIP(ip, r.cfg.IPv4Mask)
if err == nil {
e.Family = ecsFamilyIPv4
e.SourceNetmask = uint8(r.cfg.IPv4Mask)
e.Address = mip
util.SetEdns0Option(request.Req, e)
}
} else if ip := request.ClientIP.To16(); ip != nil && r.cfg.IPv6Mask > 0 {
mip, err := maskIP(ip, r.cfg.IPv6Mask)
if err == nil {
e.Family = ecsFamilyIPv6
e.SourceNetmask = uint8(r.cfg.IPv6Mask)
e.Address = mip
util.SetEdns0Option(request.Req, e)
}
}
}
// maskIP masks the IP with the given mask and return an error if the mask is invalid
func maskIP[maskType ECSMask](ip net.IP, mask maskType) (net.IP, error) {
_, mip, err := net.ParseCIDR(fmt.Sprintf("%s/%d", ip.String(), mask))
return mip.IP, err
}

View File

@ -0,0 +1,200 @@
package resolver
import (
"net"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/util"
"github.com/creasty/defaults"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/0xERR0R/blocky/model"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/mock"
)
var _ = Describe("EcsResolver", func() {
var (
sut *ECSResolver
sutConfig config.ECSConfig
m *mockResolver
mockAnswer *dns.Msg
err error
origIP net.IP
ecsIP net.IP
)
Describe("Type", func() {
It("follows conventions", func() {
expectValidResolverType(sut)
})
})
BeforeEach(func() {
err = defaults.Set(&sutConfig)
Expect(err).ShouldNot(HaveOccurred())
mockAnswer = new(dns.Msg)
origIP = net.ParseIP("1.2.3.4")
ecsIP = net.ParseIP("4.3.2.1")
})
JustBeforeEach(func() {
if m == nil {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{
Res: mockAnswer,
RType: ResponseTypeCUSTOMDNS,
Reason: "Test",
}, nil)
}
sut = NewECSResolver(sutConfig).(*ECSResolver)
sut.Next(m)
})
When("ecs is disabled", func() {
Describe("IsEnabled", func() {
It("is false", func() {
Expect(sut.IsEnabled()).Should(BeFalse())
})
})
})
When("ecs is enabled", func() {
BeforeEach(func() {
sutConfig.UseAsClient = true
})
Describe("IsEnabled", func() {
It("is true", func() {
Expect(sut.IsEnabled()).Should(BeTrue())
})
})
When("use ecs client ip is enabled", func() {
BeforeEach(func() {
sutConfig.UseAsClient = true
})
It("should change ClientIP with subnet 32", func() {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, ecsMaskIPv4)
m.ResolveFn = func(req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(ecsIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
It("shouldn't change ClientIP with subnet 24", func() {
request := newRequest("example.com.", A)
request.ClientIP = origIP
addEcsOption(request.Req, ecsIP, 24)
m.ResolveFn = func(req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
When("forward ecs is enabled", func() {
BeforeEach(func() {
sutConfig.Forward = true
sutConfig.IPv4Mask = 32
sutConfig.IPv6Mask = 128
})
It("should add Ecs information with subnet 32", func() {
request := newRequest("example.com.", A)
request.ClientIP = origIP
m.ResolveFn = func(req *Request) (*Response, error) {
Expect(req.ClientIP).Should(Equal(origIP))
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
It("should add Ecs information with subnet 128", func() {
request := newRequest("example.com.", AAAA)
request.ClientIP = net.ParseIP("2001:db8::68")
m.ResolveFn = func(req *Request) (*Response, error) {
Expect(req.Req).Should(HaveEdnsOption(dns.EDNS0SUBNET))
return respondWith(mockAnswer), nil
}
Expect(sut.Resolve(request)).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveResponseType(ResponseTypeRESOLVED),
HaveReturnCode(dns.RcodeSuccess),
HaveReason("Test")))
})
})
})
Context("maskIP", func() {
It("should mask IPv4", func() {
ip := net.ParseIP("192.168.10.123")
mask := config.ECSv4Mask(24)
mip, err := maskIP(ip, mask)
Expect(err).ShouldNot(HaveOccurred())
Expect(mip).Should(Equal(net.ParseIP("192.168.10.0").To4()))
})
})
})
// addEcsOption adds the subnet information to the request as EDNS0 option
func addEcsOption(req *dns.Msg, ip net.IP, netmask uint8) {
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
e.SourceScope = ecsSourceScope
e.Family = ecsFamilyIPv4
e.SourceNetmask = netmask
e.Address = ip
util.SetEdns0Option(req, e)
}
// respondWith creates a new Response with the given request and message
func respondWith(res *dns.Msg) *Response {
return &Response{Res: res, RType: ResponseTypeRESOLVED, Reason: "Test"}
}

View File

@ -3,15 +3,19 @@ package resolver
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
)
// A EdeResolver is responsible for adding the reason for the response as EDNS0 option
type EdeResolver struct {
configurable[*config.EdeConfig]
NextResolver
typed
}
// NewEdeResolver creates new resolver instance which adds the reason for
// the response as EDNS0 option to the response if it is enabled in the configuration
func NewEdeResolver(cfg config.EdeConfig) *EdeResolver {
return &EdeResolver{
configurable: withConfig(&cfg),
@ -19,6 +23,8 @@ func NewEdeResolver(cfg config.EdeConfig) *EdeResolver {
}
}
// Resolve adds the reason as EDNS0 option to the response of the next resolver
// if it is enabled in the configuration
func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) {
if !r.cfg.Enable {
return r.next.Resolve(request)
@ -34,6 +40,7 @@ func (r *EdeResolver) Resolve(request *model.Request) (*model.Response, error) {
return resp, nil
}
// addExtraReasoning adds the reason for the response as EDNS0 option
func (r *EdeResolver) addExtraReasoning(res *model.Response) {
infocode := res.RType.ToExtendedErrorCode()
@ -42,12 +49,9 @@ func (r *EdeResolver) addExtraReasoning(res *model.Response) {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.Option = append(opt.Option, &dns.EDNS0_EDE{
InfoCode: infocode,
ExtraText: res.Reason,
})
res.Res.Extra = append(res.Res.Extra, opt)
edeOption := new(dns.EDNS0_EDE)
edeOption.InfoCode = infocode
edeOption.ExtraText = res.Reason
util.SetEdns0Option(res.Res, edeOption)
}

View File

@ -1,11 +1,14 @@
// Description: Tests for ede_resolver.go
package resolver
import (
"errors"
"math"
"github.com/0xERR0R/blocky/config"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/util"
. "github.com/0xERR0R/blocky/model"
@ -60,7 +63,7 @@ var _ = Describe("EdeResolver", func() {
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToExtra, BeEmpty()),
Not(HaveEdnsOption(dns.EDNS0EDE)),
))
// delegated to next resolver
@ -81,8 +84,8 @@ var _ = Describe("EdeResolver", func() {
}
})
extractFirstOptRecord := func(e []dns.RR) []dns.EDNS0 {
return e[0].(*dns.OPT).Option
extractEdeOption := func(res *Response) dns.EDNS0_EDE {
return *util.GetEdns0Option[*dns.EDNS0_EDE](res.Res)
}
It("should add EDE information", func() {
@ -92,19 +95,39 @@ var _ = Describe("EdeResolver", func() {
HaveNoAnswer(),
HaveResponseType(ResponseTypeCUSTOMDNS),
HaveReturnCode(dns.RcodeSuccess),
// extra should contain one OPT record
WithTransform(ToExtra,
HaveEdnsOption(dns.EDNS0EDE),
WithTransform(extractEdeOption,
SatisfyAll(
HaveLen(1),
WithTransform(extractFirstOptRecord,
SatisfyAll(
ContainElement(HaveField("InfoCode", Equal(dns.ExtendedErrorCodeForgedAnswer))),
ContainElement(HaveField("ExtraText", Equal("Test"))),
)),
HaveField("InfoCode", Equal(dns.ExtendedErrorCodeForgedAnswer)),
HaveField("ExtraText", Equal("Test")),
)),
))
})
When("resolver returns other", func() {
BeforeEach(func() {
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{
Res: mockAnswer,
RType: ResponseType(math.MaxInt),
Reason: "Test",
}, nil)
})
It("shouldn't add EDE information", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
Should(
SatisfyAll(
HaveNoAnswer(),
HaveReturnCode(dns.RcodeSuccess),
Not(HaveEdnsOption(dns.EDNS0EDE)),
))
// delegated to next resolver
Expect(m.Calls).Should(HaveLen(1))
})
})
When("resolver returns an error", func() {
resolveErr := errors.New("test")

View File

@ -415,6 +415,7 @@ func createQueryResolver(
r = resolver.Chain(
resolver.NewFilteringResolver(cfg.Filtering),
resolver.NewFqdnOnlyResolver(cfg.FqdnOnly),
resolver.NewECSResolver(cfg.ECS),
clientNames,
resolver.NewEdeResolver(cfg.Ede),
resolver.NewQueryLoggingResolver(ctx, cfg.QueryLog),

128
util/edns0.go Normal file
View File

@ -0,0 +1,128 @@
package util
import (
"fmt"
"slices"
"github.com/miekg/dns"
)
// EDNS0Option is an interface for all EDNS0 options as type constraint for generics.
type EDNS0Option interface {
*dns.EDNS0_SUBNET | *dns.EDNS0_EDE | *dns.EDNS0_LOCAL | *dns.EDNS0_NSID | *dns.EDNS0_COOKIE | *dns.EDNS0_UL
Option() uint16
}
// RemoveEdns0Record removes the OPT record from the Extra section of the given message.
// If the OPT record is removed, true will be returned.
func RemoveEdns0Record(msg *dns.Msg) bool {
if msg == nil || msg.IsEdns0() == nil {
return false
}
for i, rr := range msg.Extra {
if rr.Header().Rrtype == dns.TypeOPT {
msg.Extra = slices.Delete(msg.Extra, i, i+1)
return true
}
}
return false
}
// GetEdns0Option returns the option with the given code from the OPT record in the
// Extra section of the given message.
// If the option is not found, nil will be returned.
func GetEdns0Option[T EDNS0Option](msg *dns.Msg) T {
if msg == nil {
return nil
}
opt := msg.IsEdns0()
if opt == nil {
return nil
}
var t T
for _, o := range opt.Option {
if o.Option() == t.Option() {
t, ok := o.(T)
if !ok {
panic(fmt.Errorf("dns option with code %d is not of type %T", t.Option(), t))
}
return t
}
}
return nil
}
// RemoveEdns0Option removes the option according to the given type from the OPT record
// in the Extra section of the given message.
// If there are no more options in the OPT record, the OPT record will be removed.
// If the option is successfully removed, true will be returned.
func RemoveEdns0Option[T EDNS0Option](msg *dns.Msg) bool {
if msg == nil {
return false
}
opt := msg.IsEdns0()
if opt == nil {
return false
}
res := false
var t T
for i, o := range opt.Option {
if o.Option() == t.Option() {
opt.Option = slices.Delete(opt.Option, i, i+1)
res = true
break
}
}
if len(opt.Option) == 0 {
RemoveEdns0Record(msg)
}
return res
}
// SetEdns0Option adds the given option to the OPT record in the Extra section of the
// given message.
// If the option already exists, it will be replaced.
// If the option is successfully set, true will be returned.
func SetEdns0Option(msg *dns.Msg, opt dns.EDNS0) bool {
if msg == nil || opt == nil {
return false
}
optRecord := msg.IsEdns0()
if optRecord == nil {
optRecord = new(dns.OPT)
optRecord.Hdr.Name = "."
optRecord.Hdr.Rrtype = dns.TypeOPT
msg.Extra = append(msg.Extra, optRecord)
}
newOpts := make([]dns.EDNS0, 0, len(optRecord.Option)+1)
for _, o := range optRecord.Option {
if o.Option() != opt.Option() {
newOpts = append(newOpts, o)
}
}
newOpts = append(newOpts, opt)
optRecord.Option = newOpts
return true
}

230
util/edns0_test.go Normal file
View File

@ -0,0 +1,230 @@
package util
import (
"net"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/miekg/dns"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
exampleDomain = "example.com"
testTxt = "test"
)
var _ = Describe("EDNS0 utils", func() {
var baseMsg *dns.Msg
BeforeEach(func() {
baseMsg = new(dns.Msg)
txt := new(dns.TXT)
txt.Hdr.Name = exampleDomain + "."
txt.Hdr.Rrtype = dns.TypeTXT
txt.Txt = []string{testTxt}
baseMsg.Extra = append(baseMsg.Extra, txt)
})
Describe("RemoveEdns0Record", func() {
When("OPT record is present", func() {
BeforeEach(func() {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
baseMsg.Extra = append(baseMsg.Extra, opt)
})
It("should remove it", func() {
Expect(RemoveEdns0Record(baseMsg)).Should(BeTrue())
Expect(baseMsg.IsEdns0()).Should(BeNil())
})
})
When("OPT record is not present", func() {
It("should do nothing ", func() {
Expect(RemoveEdns0Record(baseMsg)).Should(BeFalse())
})
})
When("Extra is nil", func() {
BeforeEach(func() {
baseMsg.Extra = nil
})
It("should do nothing", func() {
Expect(RemoveEdns0Record(baseMsg)).Should(BeFalse())
})
})
When("message is nil", func() {
It("should do nothing", func() {
Expect(RemoveEdns0Record(nil)).Should(BeFalse())
})
})
})
Describe("GetEdns0Option", func() {
When("Option is present", func() {
var eso *dns.EDNS0_SUBNET
BeforeEach(func() {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
eso = new(dns.EDNS0_SUBNET)
eso.Code = dns.EDNS0SUBNET
eso.Address = net.ParseIP("192.168.0.0")
eso.Family = 1
eso.SourceNetmask = 24
opt.Option = append(opt.Option, eso)
baseMsg.Extra = append(baseMsg.Extra, opt)
})
It("should return it", func() {
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso))
})
})
When("Option is not present", func() {
BeforeEach(func() {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.Option = append(opt.Option, new(dns.EDNS0_EDE))
baseMsg.Extra = append(baseMsg.Extra, opt)
})
It("should return nil", func() {
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeNil())
})
})
When("Extra is nil", func() {
BeforeEach(func() {
baseMsg.Extra = nil
})
It("should return nil", func() {
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeNil())
})
})
When("message is nil", func() {
It("should return nil", func() {
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](nil)).Should(BeNil())
})
})
})
Describe("RemoveEdns0Option", func() {
When("Option is present", func() {
BeforeEach(func() {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
eso := new(dns.EDNS0_SUBNET)
eso.Code = dns.EDNS0SUBNET
opt.Option = append(opt.Option, eso)
baseMsg.Extra = append(baseMsg.Extra, opt)
})
It("should remove it", func() {
Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeTrue())
Expect(baseMsg).ShouldNot(HaveEdnsOption(dns.EDNS0SUBNET))
})
})
When("Option is not present", func() {
BeforeEach(func() {
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
opt.Option = append(opt.Option, new(dns.EDNS0_EDE))
baseMsg.Extra = append(baseMsg.Extra, opt)
})
It("should return false", func() {
Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeFalse())
})
})
When("Extra is nil", func() {
BeforeEach(func() {
baseMsg.Extra = nil
})
It("should return false", func() {
Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(BeFalse())
})
})
When("message is nil", func() {
It("should return false", func() {
Expect(RemoveEdns0Option[*dns.EDNS0_SUBNET](nil)).Should(BeFalse())
})
})
})
Describe("SetEdns0Option", func() {
When("Option is not present", func() {
var eso *dns.EDNS0_SUBNET
BeforeEach(func() {
Expect(baseMsg).ShouldNot(HaveEdnsOption(dns.EDNS0SUBNET))
Expect(SetEdns0Option(baseMsg, new(dns.EDNS0_EDE))).Should(BeTrue())
eso = new(dns.EDNS0_SUBNET)
eso.Code = dns.EDNS0SUBNET
})
It("should add the option", func() {
Expect(SetEdns0Option(baseMsg, eso)).Should(BeTrue())
Expect(baseMsg).Should(HaveEdnsOption(dns.EDNS0SUBNET))
})
})
When("Option is present", func() {
var (
eso *dns.EDNS0_SUBNET
eso2 *dns.EDNS0_SUBNET
)
BeforeEach(func() {
eso = new(dns.EDNS0_SUBNET)
eso.Code = dns.EDNS0SUBNET
eso.Address = net.ParseIP("1.1.1.1")
eso.Family = 1
eso.SourceNetmask = 32
eso2 = new(dns.EDNS0_SUBNET)
eso2.Code = dns.EDNS0SUBNET
eso2.Address = net.ParseIP("2.2.2.2")
eso2.Family = 1
eso2.SourceNetmask = 32
Expect(SetEdns0Option(baseMsg, eso)).Should(BeTrue())
})
It("should replace it", func() {
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso))
Expect(SetEdns0Option(baseMsg, eso2)).Should(BeTrue())
Expect(GetEdns0Option[*dns.EDNS0_SUBNET](baseMsg)).Should(Equal(eso2))
})
})
When("message is nil", func() {
It("should return false", func() {
Expect(SetEdns0Option(nil, new(dns.EDNS0_SUBNET))).Should(BeFalse())
})
})
When("option is nil", func() {
It("should do nothing if option is nil", func() {
Expect(SetEdns0Option(baseMsg, nil)).Should(BeFalse())
})
})
})
})