blocky/util/common.go

222 lines
5.6 KiB
Go

package util
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"path/filepath"
"regexp"
"sort"
"strings"
"sync/atomic"
"github.com/0xERR0R/blocky/log"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
//nolint:gochecknoglobals
var (
	// LogPrivacy switches Obfuscate between pass-through and masking mode.
	//
	// To avoid making this package depend on config, we use a global
	// that is set at config load.
	// Ideally we'd move the obfuscate code somewhere else (maybe into `log`),
	// but that would require also moving all its dependencies.
	// This is good enough for now.
	LogPrivacy atomic.Bool

	// alphanumeric matches a single Unicode letter or digit. Unicode classes are
	// used instead of ASCII ranges so that non-ASCII names are masked as well
	// and do not leak into logs when LogPrivacy is enabled.
	alphanumeric = regexp.MustCompile(`[\p{L}\p{N}]`)
)

// Obfuscate replaces every letter and digit in `in` with "*" when LogPrivacy is
// enabled, hiding user-sensitive data from logs. Separators such as '.', '-',
// ':' and spaces are kept so the shape of the value stays recognizable.
// When LogPrivacy is disabled, `in` is returned unchanged.
func Obfuscate(in string) string {
	if LogPrivacy.Load() {
		return alphanumeric.ReplaceAllString(in, "*")
	}

	return in
}
// AnswerToString creates a user-friendly representation of an answer.
//
// A, AAAA, CNAME and PTR records are rendered as "TYPE (value)"; any other
// record falls back to its own textual form (record.String()). The entries are
// joined with ", " and passed through Obfuscate, so the result is masked when
// LogPrivacy is enabled.
func AnswerToString(answer []dns.RR) string {
	answers := make([]string, len(answer))

	for i, record := range answer {
		switch v := record.(type) {
		case *dns.A:
			answers[i] = fmt.Sprintf("A (%s)", v.A)
		case *dns.AAAA:
			answers[i] = fmt.Sprintf("AAAA (%s)", v.AAAA)
		case *dns.CNAME:
			answers[i] = fmt.Sprintf("CNAME (%s)", v.Target)
		case *dns.PTR:
			answers[i] = fmt.Sprintf("PTR (%s)", v.Ptr)
		default:
			// String() already yields a string; no fmt wrapping needed.
			answers[i] = record.String()
		}
	}

	return Obfuscate(strings.Join(answers, ", "))
}
// QuestionToString creates a user-friendly representation of a question.
//
// Each question is rendered as "TYPE (name)". The entries are joined with ", "
// and passed through Obfuscate, so the result is masked when LogPrivacy is
// enabled.
func QuestionToString(questions []dns.Question) string {
	result := make([]string, len(questions))

	for i, question := range questions {
		// Formatting dns.Type via its String method gives a generic "TYPE<n>"
		// name for codes missing from dns.TypeToString. A direct map lookup
		// would print an empty type name instead.
		result[i] = fmt.Sprintf("%s (%s)", dns.Type(question.Qtype), question.Name)
	}

	return Obfuscate(strings.Join(result, ", "))
}
// CreateAnswerFromQuestion creates a new answer record for question that points
// to ip with a TTL of remainingTTL.
//
// A and AAAA questions are answered with typed records built directly (nil
// error). Any other query type is logged as an error and handled by a fallback:
// the record is written in text form ("<name> <ttl> IN <type> <ip>") and parsed
// with dns.NewRR. In that case the returned record and error come from that
// parse.
func CreateAnswerFromQuestion(question dns.Question, ip net.IP, remainingTTL uint32) (dns.RR, error) {
	h := CreateHeader(question, remainingTTL)

	switch question.Qtype {
	case dns.TypeA:
		a := new(dns.A)
		a.A = ip
		a.Hdr = h

		return a, nil
	case dns.TypeAAAA:
		a := new(dns.AAAA)
		a.AAAA = ip
		a.Hdr = h

		return a, nil
	}

	// dns.Type.String falls back to a generic "TYPE<n>" name for codes that are
	// not in dns.TypeToString. Neither the log line nor the RR text will then
	// contain an empty type.
	qType := dns.Type(question.Qtype).String()

	log.Log().Errorf("Using fallback for unsupported query type %s", qType)

	return dns.NewRR(fmt.Sprintf("%s %d %s %s %s",
		question.Name, remainingTTL, "IN", qType, ip))
}
// CreateHeader builds the resource-record header for an answer to question.
// The header copies the question's name and type, uses class IN, and sets the
// TTL to remainingTTL.
func CreateHeader(question dns.Question, remainingTTL uint32) dns.RR_Header {
	var hdr dns.RR_Header

	hdr.Name = question.Name
	hdr.Rrtype = question.Qtype
	hdr.Class = dns.ClassINET
	hdr.Ttl = remainingTTL

	return hdr
}
// ExtractDomain returns the normalized domain of question. The question's Name
// is passed through ExtractDomainOnly.
func ExtractDomain(question dns.Question) string {
	name := question.Name

	return ExtractDomainOnly(name)
}
// ExtractDomainOnly normalizes a domain name from a DNS query. It lower-cases
// the name and strips a single trailing "." (the FQDN root label), if present.
func ExtractDomainOnly(in string) string {
	lowered := strings.ToLower(in)

	return strings.TrimSuffix(lowered, ".")
}
// NewMsgWithQuestion returns a fresh DNS message carrying a single question.
// The question is for the fully-qualified form of the name `question`, with
// type qType.
func NewMsgWithQuestion(question string, qType dns.Type) *dns.Msg {
	return new(dns.Msg).SetQuestion(dns.Fqdn(question), uint16(qType))
}
// NewMsgWithAnswer returns a DNS message whose answer section holds one record.
// The record is built by parsing the text form
// "<domain>\t<ttl>\tIN\t<dnsType>\t<address>" with dns.NewRR. If that parse
// fails, its error is returned and the message is nil.
func NewMsgWithAnswer(domain string, ttl uint, dnsType dns.Type, address string) (*dns.Msg, error) {
	record := fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType, address)

	rr, err := dns.NewRR(record)
	if err != nil {
		return nil, err
	}

	return &dns.Msg{Answer: []dns.RR{rr}}, nil
}
// kv is one key/value pair taken from a map[string]int so the pairs can be
// sorted.
type kv struct {
	key   string
	value int
}

// IterateValueSorted calls fn once for each entry of in, highest value first.
// Entries with equal values are visited in descending key order, which makes
// the iteration order fully deterministic. A nil or empty map results in no
// calls.
func IterateValueSorted(in map[string]int, fn func(string, int)) {
	// The final length is known, so allocate once.
	entries := make([]kv, 0, len(in))
	for k, v := range in {
		entries = append(entries, kv{key: k, value: v})
	}

	sort.Slice(entries, func(i, j int) bool {
		a, b := entries[i], entries[j]
		if a.value != b.value {
			return a.value > b.value
		}

		return a.key > b.key
	})

	// Named `e` rather than `kv` so the loop variable doesn't shadow the type.
	for _, e := range entries {
		fn(e.key, e.value)
	}
}
// LogOnError logs message together with err at error level, using the logger
// obtained from ctx via log.FromCtx. A nil err is a no-op.
//
// NOTE(review): message and err are passed as separate Error arguments.
// Assuming logrus' fmt.Sprint-style joining, no space is inserted after a
// string operand, so callers likely need their own separator (e.g. "...: ").
// Confirm.
func LogOnError(ctx context.Context, message string, err error) {
	if err == nil {
		return
	}

	log.FromCtx(ctx).Error(message, err)
}
// LogOnErrorWithEntry logs message together with err at error level on the
// given logrus entry. A nil err is a no-op.
func LogOnErrorWithEntry(logEntry *logrus.Entry, message string, err error) {
	if err == nil {
		return
	}

	logEntry.Error(message, err)
}
// FatalOnError logs message together with err at fatal level and exits the
// program execution. A nil err is a no-op.
func FatalOnError(message string, err error) {
	if err == nil {
		return
	}

	lg := log.Log()

	// The logger may have been silenced (output sent to io.Discard). Restore
	// the default configuration so the fatal error is printed anyway.
	if lg.Out == io.Discard {
		log.ConfigureLogger(lg, log.DefaultConfig())
	}

	lg.Fatal(message, err)
}
// GenerateCacheKey returns the cache key for a query type and domain.
//
// The key is the query type as 2 big-endian bytes, followed by the lower-cased
// domain name. ExtractCacheKey reverses this encoding.
func GenerateCacheKey(qType dns.Type, qName string) string {
	const qTypeLength = 2

	// Lower-case first and size the buffer from the result. strings.ToLower can
	// change the byte length of non-ASCII or invalid UTF-8 input. Sizing from
	// qName would then truncate the key or pad it with zero bytes.
	name := strings.ToLower(qName)

	b := make([]byte, qTypeLength+len(name))
	binary.BigEndian.PutUint16(b, uint16(qType))
	copy(b[qTypeLength:], name)

	return string(b)
}
// ExtractCacheKey splits a cache key produced by GenerateCacheKey back into its
// query type (leading 2 big-endian bytes) and domain name (the remainder).
//
// A malformed key shorter than 2 bytes yields (0, "") instead of panicking.
func ExtractCacheKey(key string) (qType dns.Type, qName string) {
	const qTypeLength = 2

	if len(key) < qTypeLength {
		return 0, ""
	}

	// Only the 2-byte prefix is converted. The name is a substring of key, which
	// avoids copying the whole key into a []byte and back.
	qType = dns.Type(binary.BigEndian.Uint16([]byte(key[:qTypeLength])))
	qName = key[qTypeLength:]

	return qType, qName
}
// CidrContainsIP reports whether ip lies inside the network described by cidr
// (e.g. "10.0.0.0/8"). An unparsable cidr never contains anything, so the
// result is false.
func CidrContainsIP(cidr string, ip net.IP) bool {
	if _, network, err := net.ParseCIDR(cidr); err == nil {
		return network.Contains(ip)
	}

	return false
}
// ClientNameMatchesGroupName reports whether clientName matches group.
// The group may contain filepath.Match wildcards ("*", "?", "[...]"), and the
// comparison ignores case.
func ClientNameMatchesGroupName(group, clientName string) bool {
	pattern, name := strings.ToLower(group), strings.ToLower(clientName)

	// A malformed pattern (filepath.ErrBadPattern) is deliberately treated as
	// "no match" rather than surfaced as an error.
	if matched, err := filepath.Match(pattern, name); err == nil {
		return matched
	}

	return false
}