Convert CustomDNSMapping from a struct to a map[string]CustomDNSEntries

This commit is contained in:
Ben McHone 2024-01-26 20:37:10 -06:00
parent ffa3418afc
commit 4dbb650c5d
2 changed files with 66 additions and 74 deletions

View File

@ -17,13 +17,72 @@ type CustomDNS struct {
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
}
type CustomDNSMapping struct {
Entries map[string][]dns.RR
type (
CustomDNSMapping map[string]CustomDNSEntries
CustomDNSEntries []dns.RR
)
func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}
parts := strings.Split(input, ",")
result := make(CustomDNSEntries, len(parts))
removePrefixSuffix := func(in, prefix string) string {
in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix))
in = strings.TrimSuffix(in, ")")
return strings.TrimSpace(in)
}
for _, part := range parts {
if strings.HasPrefix(part, "CNAME(") {
domain := removePrefixSuffix(part, "CNAME")
domain = dns.Fqdn(domain)
cname := &dns.CNAME{Target: domain}
result = append(result, cname)
} else {
// Fall back to A/AAAA records to maintain backwards compatibility in config.yml
// We will still remove the A() or AAAA() if it exists
if strings.Contains(part, ".") { // IPV4 address
ipStr := removePrefixSuffix(part, "A")
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
a := new(dns.A)
a.A = ip
result = append(result, a)
} else { // IPV6 address
ipStr := removePrefixSuffix(part, "AAAA")
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
aaaa := new(dns.AAAA)
aaaa.AAAA = ip
result = append(result, aaaa)
}
}
}
*c = result
return nil
}
// IsEnabled implements `config.Configurable`.
func (c *CustomDNS) IsEnabled() bool {
return len(c.Mapping.Entries) != 0
return len(c.Mapping) != 0
}
// LogConfig implements `config.Configurable`.
@ -33,74 +92,7 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
logger.Info("mapping:")
for key, val := range c.Mapping.Entries {
for key, val := range c.Mapping {
logger.Infof(" %s = %s", key, val)
}
}
// UnmarshalYAML implements `yaml.Unmarshaler`.
func (c *CustomDNSMapping) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input map[string]string
if err := unmarshal(&input); err != nil {
return err
}
result := make(map[string][]dns.RR, len(input))
removePrefixSuffix := func(in, prefix string) string {
in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix))
in = strings.TrimSuffix(in, ")")
return strings.TrimSpace(in)
}
addMapping := func(domain string, rr dns.RR) {
if _, ok := result[domain]; !ok {
result[domain] = []dns.RR{rr}
} else {
result[domain] = append(result[domain], rr)
}
}
for k, v := range input {
for _, part := range strings.Split(v, ",") {
if strings.HasPrefix(part, "CNAME(") {
domain := removePrefixSuffix(part, "CNAME")
domain = dns.Fqdn(domain)
cname := &dns.CNAME{Target: domain}
addMapping(k, cname)
} else {
// Fall back to A/AAAA records to maintain backwards compatibility in config.yml
// We will still remove the A() or AAAA() if it exists
if strings.Contains(part, ".") { // IPV4 address
ipStr := removePrefixSuffix(part, "A")
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", part)
}
a := new(dns.A)
a.A = ip
addMapping(k, a)
} else { // IPV6 address
ipStr := removePrefixSuffix(part, "AAAA")
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("inpartalid IP address '%s'", part)
}
aaaa := new(dns.AAAA)
aaaa.AAAA = ip
addMapping(k, aaaa)
}
}
}
}
c.Entries = result
return nil
}

View File

@ -26,10 +26,10 @@ type CustomDNSResolver struct {
// NewCustomDNSResolver creates new resolver instance
func NewCustomDNSResolver(cfg config.CustomDNS) *CustomDNSResolver {
m := make(map[string][]dns.RR, len(cfg.Mapping.Entries))
reverse := make(map[string][]string, len(cfg.Mapping.Entries))
m := make(map[string][]dns.RR, len(cfg.Mapping))
reverse := make(map[string][]string, len(cfg.Mapping))
for url, entries := range cfg.Mapping.Entries {
for url, entries := range cfg.Mapping {
m[strings.ToLower(url)] = entries
for _, entry := range entries {