mirror of https://github.com/0xERR0R/blocky.git
239 lines
7.3 KiB
Go
239 lines
7.3 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/sirupsen/logrus/hooks/test"
|
|
|
|
. "github.com/0xERR0R/blocky/log"
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("Common function tests", func() {
|
|
Describe("Print DNS answer", func() {
|
|
When("different types of DNS answers", func() {
|
|
rr := make([]dns.RR, 0)
|
|
rr = append(rr, &dns.A{A: net.ParseIP("127.0.0.1")})
|
|
rr = append(rr, &dns.AAAA{AAAA: net.ParseIP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344")})
|
|
rr = append(rr, &dns.CNAME{Target: "cname"})
|
|
rr = append(rr, &dns.PTR{Ptr: "ptr"})
|
|
rr = append(rr, &dns.NS{Ns: "ns"})
|
|
It("should print the answers", func() {
|
|
answerToString := AnswerToString(rr)
|
|
Expect(answerToString).Should(Equal("A (127.0.0.1), " +
|
|
"AAAA (2001:db8:85a3:8d3:1319:8a2e:370:7344), CNAME (cname), PTR (ptr), \t0\tCLASS0\tNone\tns"))
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("print question", func() {
|
|
When("question is provided", func() {
|
|
question := dns.Question{
|
|
Name: "google.de",
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
It("should print the answers", func() {
|
|
questionToString := QuestionToString([]dns.Question{question})
|
|
Expect(questionToString).Should(Equal("A (google.de)"))
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Extract domain from query", func() {
|
|
When("Question is provided", func() {
|
|
question := dns.Question{
|
|
Name: "google.de.",
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
It("should extract only domain", func() {
|
|
domain := ExtractDomain(question)
|
|
Expect(domain).Should(Equal("google.de"))
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Create new DNS message", func() {
|
|
When("Question is provided", func() {
|
|
question := "google.com."
|
|
It("should create message", func() {
|
|
msg := NewMsgWithQuestion(question, dns.Type(dns.TypeA))
|
|
Expect(QuestionToString(msg.Question)).Should(Equal("A (google.com.)"))
|
|
})
|
|
})
|
|
When("Answer is provided", func() {
|
|
It("should create message", func() {
|
|
msg, err := NewMsgWithAnswer("google.com", 25, dns.Type(dns.TypeA), "192.168.178.1")
|
|
Expect(err).Should(Succeed())
|
|
Expect(AnswerToString(msg.Answer)).Should(Equal("A (192.168.178.1)"))
|
|
})
|
|
})
|
|
When("Answer is corrupt", func() {
|
|
It("should throw an error", func() {
|
|
_, err := NewMsgWithAnswer(strings.Repeat("a", 300), 25, dns.Type(dns.TypeA), "192.168.178.1")
|
|
Expect(err).Should(HaveOccurred())
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Create answer from question", func() {
|
|
When("type A is provided", func() {
|
|
question := dns.Question{
|
|
Name: "google.de",
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
answer, err := CreateAnswerFromQuestion(question, net.ParseIP("192.168.178.1"), 25)
|
|
Expect(err).Should(Succeed())
|
|
It("should return A record", func() {
|
|
Expect(answer.String()).Should(Equal("google.de 25 IN A 192.168.178.1"))
|
|
})
|
|
})
|
|
When("type AAAA is provided", func() {
|
|
question := dns.Question{
|
|
Name: "google.de",
|
|
Qtype: dns.TypeAAAA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
answer, err := CreateAnswerFromQuestion(question, net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), 25)
|
|
Expect(err).Should(Succeed())
|
|
It("should return AAAA record", func() {
|
|
Expect(answer.String()).Should(Equal("google.de 25 IN AAAA 2001:db8:85a3::8a2e:370:7334"))
|
|
})
|
|
})
|
|
When("type NS is provided", func() {
|
|
question := dns.Question{
|
|
Name: "google.de",
|
|
Qtype: dns.TypeNS,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
answer, err := CreateAnswerFromQuestion(question, net.ParseIP("192.168.178.1"), 25)
|
|
Expect(err).Should(Succeed())
|
|
It("should return generic record as fallback", func() {
|
|
Expect(answer.String()).Should(Equal("google.de. 25 IN NS 192.168.178.1."))
|
|
})
|
|
})
|
|
|
|
When("Invalid record is provided", func() {
|
|
question := dns.Question{
|
|
Name: strings.Repeat("k", 99),
|
|
Qtype: dns.TypeNS,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
_, err := CreateAnswerFromQuestion(question, net.ParseIP("192.168.178.1"), 25)
|
|
It("should fail", func() {
|
|
Expect(err).Should(HaveOccurred())
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Sorted iteration over map", func() {
|
|
When("Key-value map is provided", func() {
|
|
m := make(map[string]int)
|
|
m["x"] = 5
|
|
m["a"] = 1
|
|
m["m"] = 9
|
|
It("should iterate in sorted order", func() {
|
|
result := make([]string, 0)
|
|
IterateValueSorted(m, func(s string, i int) {
|
|
result = append(result, fmt.Sprintf("%s-%d", s, i))
|
|
})
|
|
Expect(strings.Join(result, ";")).Should(Equal("m-9;x-5;a-1"))
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Logging functions", func() {
|
|
When("LogOnError is called with error", func() {
|
|
err := errors.New("test")
|
|
It("should log", func(ctx context.Context) {
|
|
hook := test.NewGlobal()
|
|
Log().AddHook(hook)
|
|
defer hook.Reset()
|
|
LogOnError(ctx, "message ", err)
|
|
Expect(hook.LastEntry().Message).Should(Equal("message test"))
|
|
})
|
|
})
|
|
|
|
When("LogOnErrorWithEntry is called with error", func() {
|
|
err := errors.New("test")
|
|
It("should log", func() {
|
|
hook := test.NewGlobal()
|
|
Log().AddHook(hook)
|
|
defer hook.Reset()
|
|
logger, hook := test.NewNullLogger()
|
|
entry := logrus.NewEntry(logger)
|
|
LogOnErrorWithEntry(entry, "message ", err)
|
|
Expect(hook.LastEntry().Message).Should(Equal("message test"))
|
|
})
|
|
})
|
|
|
|
When("FatalOnError is called with error", func() {
|
|
err := errors.New("test")
|
|
It("should log and exit", func() {
|
|
hook := test.NewGlobal()
|
|
Log().AddHook(hook)
|
|
fatal := false
|
|
Log().ExitFunc = func(int) { fatal = true }
|
|
defer func() {
|
|
Log().ExitFunc = nil
|
|
}()
|
|
FatalOnError("message ", err)
|
|
Expect(hook.LastEntry().Message).Should(Equal("message test"))
|
|
Expect(fatal).Should(BeTrue())
|
|
})
|
|
})
|
|
})
|
|
|
|
Describe("Domain cache key generate/extract", func() {
|
|
It("should works", func() {
|
|
cacheKey := GenerateCacheKey(dns.Type(dns.TypeA), "example.com")
|
|
qType, qName := ExtractCacheKey(cacheKey)
|
|
Expect(qType).Should(Equal(dns.Type(dns.TypeA)))
|
|
Expect(qName).Should(Equal("example.com"))
|
|
})
|
|
})
|
|
|
|
Describe("CIDR contains IP", func() {
|
|
It("should return true if CIDR (10.43.8.64 - 10.43.8.79) contains the IP", func() {
|
|
c := CidrContainsIP("10.43.8.67/28", net.ParseIP("10.43.8.64"))
|
|
Expect(c).Should(BeTrue())
|
|
})
|
|
It("should return false if CIDR (10.43.8.64 - 10.43.8.79) doesn't contain the IP", func() {
|
|
c := CidrContainsIP("10.43.8.67/28", net.ParseIP("10.43.8.63"))
|
|
Expect(c).Should(BeFalse())
|
|
})
|
|
It("should return false if CIDR is wrong", func() {
|
|
c := CidrContainsIP("10.43.8.67", net.ParseIP("10.43.8.63"))
|
|
Expect(c).Should(BeFalse())
|
|
})
|
|
})
|
|
|
|
Describe("Client name matches group name", func() {
|
|
It("should return true if client name matches with wildcard", func() {
|
|
c := ClientNameMatchesGroupName("group*", "group-test")
|
|
Expect(c).Should(BeTrue())
|
|
})
|
|
It("should return false if client name doesn't match with wildcard", func() {
|
|
c := ClientNameMatchesGroupName("group*", "abc")
|
|
Expect(c).Should(BeFalse())
|
|
})
|
|
It("should return true if client name matches with range wildcard", func() {
|
|
c := ClientNameMatchesGroupName("group[1-3]", "group1")
|
|
Expect(c).Should(BeTrue())
|
|
})
|
|
It("should return false if client name doesn't match with range wildcard", func() {
|
|
c := ClientNameMatchesGroupName("group[1-3]", "group4")
|
|
Expect(c).Should(BeFalse())
|
|
})
|
|
})
|
|
})
|