feat: allow hosts file resolver to use an HTTP(S) link or inline block (#884)

Unify the hosts file parsing between the hosts resolver and lists so
the resolver supports more data sources than local files.

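For illustration, the kinds of sources this unification enables, sketched as config. The `blocking` group shape matches existing config; pointing the hosts file source at a URL or inline block is the new behavior. Exact key names are an assumption and may differ between blocky versions:

```yaml
blocking:
  blackLists:
    ads:
      - https://example.com/adlist.txt   # HTTP(S) link
      - /etc/blocky/local-denylist.txt   # local file
      - |                                # inline block (YAML literal block scalar)
        # inline definition
        blocked.example.com

hostsFile:
  filePath: https://example.com/hosts.txt
```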
Lists' group cache is now re-used if refresh fails.

Also improve lookups in hosts:
Instead of iterating through all hosts and aliases for each A/AAAA query,
we now do a single map lookup.
For PTR, we search only the hosts whose IP version matches the question,
and compare IPs directly instead of building the reverse DNS name for
each IP in the hosts database.
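A minimal, self-contained sketch of that lookup scheme with simplified types (illustrative only; the actual implementation in the hosts file resolver diff below uses splitHostsFileData and dns.Type):

```go
package main

import (
	"fmt"
	"net"
)

// hostsData maps a hostname to its IP, so an A/AAAA query is a single
// map lookup instead of a scan over all hosts and aliases.
type hostsData map[string]net.IP

// splitHosts stores IPv4 and IPv6 entries separately, so a PTR query
// only scans hosts whose IP version matches the question.
type splitHosts struct {
	v4, v6 hostsData
}

func (s splitHosts) lookup(wantV6 bool, domain string) net.IP {
	if wantV6 {
		return s.v6[domain]
	}
	return s.v4[domain]
}

// reverseLookup compares IPs directly instead of building the reverse
// DNS (arpa) name for every entry in the database.
func (s splitHosts) reverseLookup(ip net.IP) (string, bool) {
	data := s.v4
	if ip.To4() == nil {
		data = s.v6
	}
	for host, hostIP := range data {
		if hostIP.Equal(ip) {
			return host, true
		}
	}
	return "", false
}

func main() {
	hosts := splitHosts{
		v4: hostsData{"router.lan": net.ParseIP("192.168.0.1")},
		v6: hostsData{"router.lan": net.ParseIP("fe80::1")},
	}

	fmt.Println(hosts.lookup(false, "router.lan"))                // 192.168.0.1
	fmt.Println(hosts.reverseLookup(net.ParseIP("192.168.0.1"))) // router.lan true
}
```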
ThinkChaos 2023-03-06 19:32:41 -05:00 committed by GitHub
parent 9f58c4bf69
commit a2ab7c3ef1
21 changed files with 2309 additions and 268 deletions


@@ -51,7 +51,7 @@ var _ = Describe("External lists and query blocking", func() {
HaveTTL(BeNumerically("==", 123)),
))
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("error during file processing")))
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: ")))
})
})
Context("startStrategy = failOnError", func() {

go.mod

@@ -76,6 +76,7 @@ require (
github.com/Masterminds/semver v1.5.0 // indirect
github.com/Masterminds/sprig v2.22.0+incompatible // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect

go.sum

@@ -62,6 +62,8 @@ github.com/alicebob/miniredis/v2 v2.30.0/go.mod h1:84TWKZlxYkfgMucPBf5SOQBYJceZe
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef h1:2JGTg6JapxP9/R33ZaagQtAM4EkkSYnIAlOG5EI8gkM=
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef/go.mod h1:JS7hed4L1fj0hXcyEejnW57/7LCetXggd+vwrRnYeII=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/avast/retry-go/v4 v4.3.3 h1:G56Bp6mU0b5HE1SkaoVjscZjlQb0oy4mezwY/cGH19w=
github.com/avast/retry-go/v4 v4.3.3/go.mod h1:rg6XFaiuFYII0Xu3RDbZQkxCofFwruZKW8oEF1jpWiU=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=


@@ -2,7 +2,7 @@ package lists
//go:generate go run github.com/abice/go-enum -f=$GOFILE --marshal --names
import (
"bufio"
"context"
"errors"
"fmt"
"io"
@@ -12,20 +12,20 @@ import (
"sync"
"time"
"github.com/0xERR0R/blocky/cache/stringcache"
"github.com/hako/durafmt"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
"github.com/hako/durafmt"
"github.com/hashicorp/go-multierror"
"github.com/0xERR0R/blocky/cache/stringcache"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
)
const (
defaultProcessingConcurrency = 4
chanCap = 1000
maxErrorsPerFile = 5
)
// ListCacheType represents the type of cached list ENUM(
@@ -147,6 +147,8 @@ func logger() *logrus.Entry {
}
// downloads and reads files with domain names and creates cache for them
//
//nolint:funlen // will refactor in a later commit
func (b *ListCache) createCacheForGroup(links []string) (stringcache.StringCache, error) {
var err error
@@ -162,11 +164,14 @@ func (b *ListCache) createCacheForGroup(links []string) (stringcache.StringCache
processingLinkJobs := len(links)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// loop over links (http/local) or inline definitions
// start a new goroutine for each link, but limit to max. number (see processingConcurrency)
for _, link := range links {
go func(link string) {
// thy to write in this channel -> this will block if max amount of goroutines are being executed
for idx, link := range links {
go func(idx int, link string) {
// try to write in this channel -> this will block if max amount of goroutines are being executed
guard <- struct{}{}
defer func() {
@@ -174,8 +179,14 @@ func (b *ListCache) createCacheForGroup(links []string) (stringcache.StringCache
<-guard
workerDoneChan <- true
}()
b.processFile(link, fileLinesChan, errChan)
}(link)
name := linkName(idx, link)
err := b.parseFile(ctx, name, link, fileLinesChan)
if err != nil {
errChan <- err
}
}(idx, link)
}
Loop:
@@ -200,7 +211,12 @@
}
}
return factory.Create(), err
cache := factory.Create()
if cache.ElementCount() == 0 && err != nil {
cache = nil // don't replace existing cache
}
return cache, err
}
// Match matches passed domain name against cached list entries
@@ -222,7 +238,7 @@ func (b *ListCache) Refresh() {
_ = b.refresh(false)
}
func (b *ListCache) refresh(init bool) error {
func (b *ListCache) refresh(isInit bool) error {
var err error
for group, links := range b.groupToLinks {
@@ -231,35 +247,54 @@
err = multierror.Append(err, multierror.Prefix(e, fmt.Sprintf("can't create cache group '%s':", group)))
}
if cacheForGroup != nil {
b.lock.Lock()
b.groupCaches[group] = cacheForGroup
b.lock.Unlock()
} else {
if cacheForGroup == nil {
count := b.groupElementCount(group, isInit)
logger := logger().WithFields(logrus.Fields{
"group": group,
"group": group,
"total_count": count,
})
if init {
if count == 0 {
logger.Warn("Populating of group cache failed, cache will be empty until refresh succeeds")
} else {
logger.Warn("Populating of group cache failed, using existing cache, if any")
}
continue
}
if cacheForGroup != nil {
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, cacheForGroup.ElementCount())
b.lock.Lock()
b.groupCaches[group] = cacheForGroup
b.lock.Unlock()
logger().WithFields(logrus.Fields{
"group": group,
"total_count": cacheForGroup.ElementCount(),
}).Info("group import finished")
}
evt.Bus().Publish(evt.BlockingCacheGroupChanged, b.listType, group, cacheForGroup.ElementCount())
logger().WithFields(logrus.Fields{
"group": group,
"total_count": cacheForGroup.ElementCount(),
}).Info("group import finished")
}
return err
}
func (b *ListCache) groupElementCount(group string, isInit bool) int {
if isInit {
return 0
}
b.lock.RLock()
oldCache, ok := b.groupCaches[group]
b.lock.RUnlock()
if !ok {
return 0
}
return oldCache.ElementCount()
}
func readFile(file string) (io.ReadCloser, error) {
logger().WithField("file", file).Info("starting processing of file")
file = strings.TrimPrefix(file, "file://")
@@ -268,47 +303,75 @@ func readFile(file string) (io.ReadCloser, error) {
}
// downloads file (or reads local file) and writes each line in the file to the result channel
func (b *ListCache) processFile(link string, resultCh chan<- string, errCh chan<- error) {
var r io.ReadCloser
func (b *ListCache) parseFile(ctx context.Context, name, link string, resultCh chan<- string) error {
count := 0
var err error
r, err = b.getLinkReader(link)
logger := func() *logrus.Entry {
return logger().WithFields(logrus.Fields{
"source": name,
"count": count,
})
}
r, err := b.newLinkReader(link)
if err != nil {
logger().Warn("error during file processing: ", err)
errCh <- err
logger().Error("cannot open source: ", err)
return
return err
}
defer r.Close()
var count int
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// skip comments
if line := processLine(line); line != "" {
resultCh <- line
p := parsers.AllowErrors(parsers.Hosts(r), maxErrorsPerFile)
p.OnErr(func(err error) {
logger().Warnf("parse error: %s, trying to continue", err)
})
err = parsers.ForEach[*parsers.HostsIterator](ctx, p, func(hosts *parsers.HostsIterator) error {
return hosts.ForEach(func(host string) error {
count++
// For IPs, we want to ensure the string is the Go representation so that when
// we compare responses, a same IP matches, even if it was written differently
// in the list.
if ip := net.ParseIP(host); ip != nil {
host = ip.String()
}
resultCh <- host
return nil
})
})
if err != nil {
// Don't log cancelation: it was caused by another goroutine failing
if !errors.Is(err, context.Canceled) {
logger().Error("parse error: ", err)
}
// Only propagate the error if no entries were parsed
// If the file was partially parsed, we'll settle for that
if count == 0 {
return err
}
return nil
}
if err := scanner.Err(); err != nil {
// don't propagate error here. If some lines are not parsable (e.g. too long), it is ok
logger().Warn("can't parse file: ", err)
} else {
logger().WithFields(logrus.Fields{
"source": link,
"count": count,
}).Info("file imported")
}
logger().Info("import succeeded")
return nil
}
func (b *ListCache) getLinkReader(link string) (r io.ReadCloser, err error) {
func linkName(linkIdx int, link string) string {
if strings.ContainsAny(link, "\n") {
return fmt.Sprintf("inline block (item #%d in group)", linkIdx)
}
return link
}
func (b *ListCache) newLinkReader(link string) (r io.ReadCloser, err error) {
switch {
// link contains a line break -> this is inline list definition in YAML (with literal style Block Scalar)
case strings.ContainsAny(link, "\n"):
@@ -323,28 +386,3 @@ func (b *ListCache) getLinkReader(link string) (r io.ReadCloser, err error) {
return
}
// return only first column (see hosts format)
func processLine(line string) string {
if strings.HasPrefix(line, "#") {
return ""
}
// remove end of line comment
if idx := strings.IndexRune(line, '#'); idx != -1 {
line = line[:idx]
}
if parts := strings.Fields(line); len(parts) > 0 {
host := parts[len(parts)-1]
ip := net.ParseIP(host)
if ip != nil {
return ip.String()
}
return strings.TrimSpace(strings.ToLower(host))
}
return ""
}


@@ -5,9 +5,9 @@ import (
)
func BenchmarkRefresh(b *testing.B) {
file1 := createTestListFile(b.TempDir(), 100000)
file2 := createTestListFile(b.TempDir(), 150000)
file3 := createTestListFile(b.TempDir(), 130000)
file1, _ := createTestListFile(b.TempDir(), 100000)
file2, _ := createTestListFile(b.TempDir(), 150000)
file3, _ := createTestListFile(b.TempDir(), 130000)
lists := map[string][]string{
"gr1": {file1, file2, file3},
}


@@ -13,6 +13,8 @@ import (
"time"
. "github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/util"
. "github.com/0xERR0R/blocky/helpertest"
. "github.com/onsi/ginkgo/v2"
@@ -75,25 +77,74 @@ var _ = Describe("ListCache", func() {
Expect(group).Should(BeEmpty())
})
})
When("List becomes empty on refresh", func() {
It("should delete existing elements from group cache", func() {
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
res <- "# nothing"
})
lists := map[string][]string{
"gr1": {mockDownloader.ListSource()},
}
sut, err := NewListCache(
ListCacheTypeBlacklist, lists,
4*time.Hour,
mockDownloader,
defaultProcessingConcurrency,
false,
)
Expect(err).Should(Succeed())
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
err = sut.refresh(false)
Expect(err).Should(Succeed())
found, group = sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeFalse())
Expect(group).Should(BeEmpty())
})
})
When("List has invalid lines", func() {
It("should still other domains", func() {
lists := map[string][]string{
"gr1": {
inlineList(
"inlinedomain1.com",
"invaliddomain!",
"inlinedomain2.com",
),
},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
found, group := sut.Match("inlinedomain1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
found, group = sut.Match("inlinedomain2.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
})
When("a temporary/transient err occurs on download", func() {
It("should not delete existing elements from group cache", func() {
// should produce a transient error on second and third attempt
data := make(chan func() (io.ReadCloser, error), 3)
mockDownloader := &MockDownloader{data: data}
//nolint:unparam
data <- func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader("blocked1.com")), nil
}
//nolint:unparam
data <- func() (io.ReadCloser, error) {
return nil, &TransientError{inner: errors.New("boom")}
}
//nolint:unparam
data <- func() (io.ReadCloser, error) {
return nil, &TransientError{inner: errors.New("boom")}
}
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
err <- &TransientError{inner: errors.New("boom")}
err <- &TransientError{inner: errors.New("boom")}
})
lists := map[string][]string{
"gr1": {"http://dummy"},
"gr1": {mockDownloader.ListSource()},
}
sut, err := NewListCache(
@@ -113,7 +164,7 @@ var _ = Describe("ListCache", func() {
}, "1s").Should(Succeed())
})
Expect(sut.refresh(true)).Should(HaveOccurred())
Expect(sut.refresh(false)).Should(HaveOccurred())
By("List couldn't be loaded due to timeout", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
@@ -131,19 +182,15 @@ var _ = Describe("ListCache", func() {
})
})
When("non transient err occurs on download", func() {
It("should delete existing elements from group cache", func() {
// should produce a 404 err on second attempt
data := make(chan func() (io.ReadCloser, error), 2)
mockDownloader := &MockDownloader{data: data}
//nolint:unparam
data <- func() (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader("blocked1.com")), nil
}
data <- func() (io.ReadCloser, error) {
return nil, errors.New("boom")
}
It("should keep existing elements from group cache", func() {
// should produce a non transient error on second attempt
mockDownloader := newMockDownloader(func(res chan<- string, err chan<- error) {
res <- "blocked1.com"
err <- errors.New("boom")
})
lists := map[string][]string{
"gr1": {"http://dummy"},
"gr1": {mockDownloader.ListSource()},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, mockDownloader,
@@ -151,21 +198,17 @@ var _ = Describe("ListCache", func() {
Expect(err).Should(Succeed())
By("Lists loaded without err", func() {
Eventually(func(g Gomega) {
found, group := sut.Match("blocked1.com", []string{"gr1"})
g.Expect(found).Should(BeTrue())
g.Expect(group).Should(Equal("gr1"))
}, "1s").Should(Succeed())
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
Expect(sut.refresh(false)).Should(HaveOccurred())
By("List couldn't be loaded due to 404 err", func() {
Eventually(func() bool {
found, _ := sut.Match("blocked1.com", []string{"gr1"})
return found
}, "1s").Should(BeFalse())
By("Lists from first load is kept", func() {
found, group := sut.Match("blocked1.com", []string{"gr1"})
Expect(found).Should(BeTrue())
Expect(group).Should(Equal("gr1"))
})
})
})
@@ -262,9 +305,9 @@ var _ = Describe("ListCache", func() {
})
When("group with bigger files", func() {
It("should match", func() {
file1 := createTestListFile(GinkgoT().TempDir(), 10000)
file2 := createTestListFile(GinkgoT().TempDir(), 15000)
file3 := createTestListFile(GinkgoT().TempDir(), 13000)
file1, lines1 := createTestListFile(GinkgoT().TempDir(), 10000)
file2, lines2 := createTestListFile(GinkgoT().TempDir(), 15000)
file3, lines3 := createTestListFile(GinkgoT().TempDir(), 13000)
lists := map[string][]string{
"gr1": {file1, file2, file3},
}
@@ -273,13 +316,17 @@ var _ = Describe("ListCache", func() {
defaultProcessingConcurrency, false)
Expect(err).Should(Succeed())
Expect(sut.groupCaches["gr1"].ElementCount()).Should(Equal(38000))
Expect(sut.groupCaches["gr1"].ElementCount()).Should(Equal(lines1 + lines2 + lines3))
})
})
When("inline list content is defined", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {"inlinedomain1.com\n#some comment\ninlinedomain2.com"},
"gr1": {inlineList(
"inlinedomain1.com",
"#some comment",
"inlinedomain2.com",
)},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
@@ -298,9 +345,13 @@ var _ = Describe("ListCache", func() {
})
When("Text file can't be parsed", func() {
It("should still match already imported strings", func() {
// 2nd line is too long and will cause an error
lists := map[string][]string{
"gr1": {"inlinedomain1.com\n" + strings.Repeat("longString", 100000)},
"gr1": {
inlineList(
"inlinedomain1.com",
"lineTooLong"+strings.Repeat("x", bufio.MaxScanTokenSize), // too long
),
},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
@@ -312,10 +363,26 @@ var _ = Describe("ListCache", func() {
Expect(group).Should(Equal("gr1"))
})
})
When("Text file has too many errors", func() {
It("should fail parsing", func() {
lists := map[string][]string{
"gr1": {
inlineList(
strings.Repeat("invaliddomain!\n", maxErrorsPerFile+1), // too many errors
),
},
}
_, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
defaultProcessingConcurrency, false)
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(parsers.ErrTooManyErrors))
})
})
When("file has end of line comment", func() {
It("should still parse the domain", func() {
lists := map[string][]string{
"gr1": {"inlinedomain1.com#a comment\n"},
"gr1": {inlineList("inlinedomain1.com#a comment")},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
@@ -330,7 +397,7 @@ var _ = Describe("ListCache", func() {
When("inline regex content is defined", func() {
It("should match", func() {
lists := map[string][]string{
"gr1": {"/^apple\\.(de|com)$/\n"},
"gr1": {inlineList("/^apple\\.(de|com)$/")},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, 0, NewDownloader(),
@@ -352,7 +419,7 @@ var _ = Describe("ListCache", func() {
It("should print list configuration", func() {
lists := map[string][]string{
"gr1": {server1.URL, server2.URL},
"gr2": {"inline\ndefinition\n"},
"gr2": {inlineList("inline", "definition")},
}
sut, err := NewListCache(ListCacheTypeBlacklist, lists, time.Hour, NewDownloader(),
@@ -396,16 +463,27 @@ var _ = Describe("ListCache", func() {
})
type MockDownloader struct {
data chan func() (io.ReadCloser, error)
util.MockCallSequence[string]
}
func newMockDownloader(driver func(res chan<- string, err chan<- error)) *MockDownloader {
return &MockDownloader{util.NewMockCallSequence(driver)}
}
func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) {
fn := <-m.data
str, err := m.Call()
if err != nil {
return nil, err
}
return fn()
return io.NopCloser(strings.NewReader(str)), nil
}
func createTestListFile(dir string, totalLines int) string {
func (m *MockDownloader) ListSource() string {
return "http://mock"
}
func createTestListFile(dir string, totalLines int) (string, int) {
file, err := os.CreateTemp(dir, "blocky")
if err != nil {
log.Fatal(err)
@@ -417,16 +495,33 @@ func createTestListFile(dir string, totalLines int) string {
}
w.Flush()
return file.Name()
return file.Name(), totalLines
}
const charpool = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
const (
initCharpool = "abcdefghijklmnopqrstuvwxyz"
contCharpool = initCharpool + "0123456789-"
)
func RandStringBytes(n int) string {
b := make([]byte, n)
pool := initCharpool
for i := range b {
b[i] = charpool[rand.Intn(len(charpool))]
b[i] = pool[rand.Intn(len(pool))]
pool = contCharpool
}
return string(b)
}
func inlineList(lines ...string) string {
res := strings.Join(lines, "\n")
// ensure at least one line ending so it's parsed as an inline block
res += "\n"
return res
}

lists/parsers/adapt.go

@@ -0,0 +1,63 @@
package parsers
import "context"
// Adapt returns a parser that wraps `inner` converting each parsed value.
func Adapt[From, To any](inner SeriesParser[From], adapt func(From) To) SeriesParser[To] {
return TryAdapt(inner, func(from From) (To, error) {
return adapt(from), nil
})
}
// TryAdapt returns a parser that wraps `inner` and tries to convert each parsed value.
func TryAdapt[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] {
return newAdapter(inner, adapt)
}
// TryAdaptMethod returns a parser that wraps `inner` and tries to convert each parsed value
// using the given method with pointer receiver of `To`.
func TryAdaptMethod[ToPtr *To, From any, To any](
inner SeriesParser[From], method func(ToPtr, From) error,
) SeriesParser[*To] {
return TryAdapt(inner, func(from From) (*To, error) {
res := new(To)
err := method(res, from)
if err != nil {
return nil, err
}
return res, nil
})
}
type adapter[From, To any] struct {
inner SeriesParser[From]
adapt func(From) (To, error)
}
func newAdapter[From, To any](inner SeriesParser[From], adapt func(From) (To, error)) SeriesParser[To] {
return &adapter[From, To]{inner, adapt}
}
func (a *adapter[From, To]) Position() string {
return a.inner.Position()
}
func (a *adapter[From, To]) Next(ctx context.Context) (To, error) {
from, err := a.inner.Next(ctx)
if err != nil {
var zero To
return zero, err
}
res, err := a.adapt(from)
if err != nil {
var zero To
return zero, err
}
return res, nil
}

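A quick sketch of how these adapters compose (illustrative conversion only; `Lines` and `ForEach` are defined in lists/parsers/lines.go and lists/parsers/parser.go below):

```go
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/0xERR0R/blocky/lists/parsers"
)

func main() {
	// Adapt converts each value a parser yields; here, each line to its length.
	lengths := parsers.Adapt(
		parsers.Lines(strings.NewReader("a\nbb\n")),
		func(s string) int { return len(s) },
	)

	// ForEach drives the parser until io.EOF.
	_ = parsers.ForEach[int](context.Background(), lengths, func(n int) error {
		fmt.Println(n) // 1, then 2
		return nil
	})
}
```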

@@ -0,0 +1,92 @@
package parsers
import (
"context"
"errors"
)
// NoErrorLimit can be used to continue parsing until EOF.
const NoErrorLimit = -1
var ErrTooManyErrors = errors.New("too many parse errors")
type FilteredSeriesParser[T any] interface {
SeriesParser[T]
// OnErr registers a callback invoked for each error encountered.
OnErr(func(error))
}
// FilterErrors returns a parser that wraps `inner` and tries to continue parsing.
//
// For each error, `filter` decides whether parsing continues (nil) or aborts (non-nil).
func FilterErrors[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] {
return newErrorFilter(inner, filter)
}
// AllowErrors returns a parser that wraps `inner` and tries to continue parsing.
//
// After `n` errors, it fails with `ErrTooManyErrors`.
func AllowErrors[T any](inner SeriesParser[T], n int) FilteredSeriesParser[T] {
if n == NoErrorLimit {
return FilterErrors(inner, func(error) error { return nil })
}
count := 0
return FilterErrors(inner, func(err error) error {
count++
if count > n {
return ErrTooManyErrors
}
return nil
})
}
type errorFilter[T any] struct {
inner SeriesParser[T]
filter func(error) error
}
func newErrorFilter[T any](inner SeriesParser[T], filter func(error) error) FilteredSeriesParser[T] {
return &errorFilter[T]{inner, filter}
}
func (f *errorFilter[T]) OnErr(callback func(error)) {
filter := f.filter
f.filter = func(err error) error {
callback(ErrWithPosition(f.inner, err))
return filter(err)
}
}
func (f *errorFilter[T]) Position() string {
return f.inner.Position()
}
func (f *errorFilter[T]) Next(ctx context.Context) (T, error) {
var zero T
for {
res, err := f.inner.Next(ctx)
if err != nil {
if IsNonResumableErr(err) {
// bypass the filter, and just propagate the error
return zero, err
}
err = f.filter(err)
if err != nil {
return zero, err
}
continue
}
return res, nil
}
}


@@ -0,0 +1,118 @@
package parsers
import (
"context"
"errors"
"io"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("errorFilter", func() {
Describe("AllowErrors", func() {
var (
parser SeriesParser[struct{}]
)
BeforeEach(func() {
parser = newMockParser(func(res chan<- struct{}, err chan<- error) {
res <- struct{}{}
err <- errors.New("fail")
res <- struct{}{}
err <- errors.New("fail")
res <- struct{}{}
err <- errors.New("fail")
err <- NewNonResumableError(io.EOF)
})
})
When("0 errors are allowed", func() {
It("should fail on first error", func() {
parser = AllowErrors(parser, 0)
_, err := parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 1"))
_, err = parser.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(ErrTooManyErrors))
Expect(parser.Position()).Should(Equal("call 2"))
})
})
When("1 error is allowed", func() {
It("should fail on second error", func() {
parser = AllowErrors(parser, 1)
_, err := parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 1"))
_, err = parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 3"))
_, err = parser.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(ErrTooManyErrors))
Expect(parser.Position()).Should(Equal("call 4"))
})
})
When("using NoErrorLimit", func() {
It("should ignore all resumable errors", func() {
parser = AllowErrors(parser, NoErrorLimit)
_, err := parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 1"))
_, err = parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 3"))
_, err = parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(parser.Position()).Should(Equal("call 5"))
_, err = parser.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(parser.Position()).Should(Equal("call 7"))
})
})
})
Describe("OnErr", func() {
It("should be called for each error", func() {
inner := newMockParser(func(res chan<- string, err chan<- error) {
err <- errors.New("fail")
res <- "ok"
err <- errors.New("fail")
err <- NewNonResumableError(io.EOF)
})
parser := AllowErrors(inner, NoErrorLimit)
errors := 0
parser.OnErr(func(err error) {
errors++
})
res, err := parser.Next(context.Background())
Expect(err).Should(Succeed())
Expect(res).Should(Equal("ok"))
Expect(parser.Position()).Should(Equal("call 2"))
_, err = parser.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(errors).Should(Equal(2))
})
})
})

lists/parsers/hosts.go

@@ -0,0 +1,218 @@
package parsers
import (
"bufio"
"bytes"
"encoding"
"fmt"
"io"
"net"
"regexp"
"strings"
"github.com/asaskevich/govalidator"
"github.com/hashicorp/go-multierror"
)
const maxDomainNameLength = 255 // https://www.rfc-editor.org/rfc/rfc1034#section-3.1
var domainNameRegex = regexp.MustCompile(govalidator.DNSName)
// Hosts parses `r` as a series of `HostsIterator`.
// It supports both the hosts file and host list formats.
//
// Each item being an iterator was chosen to abstract the difference between the
// two formats where each host list entry is a single host, but a hosts file
// entry can be multiple due to aliases.
// It also avoids allocating intermediate lists.
func Hosts(r io.Reader) SeriesParser[*HostsIterator] {
return LinesAs[*HostsIterator](r)
}
type HostsIterator struct {
hostsIterator
}
type hostsIterator interface {
encoding.TextUnmarshaler
forEachHost(callback func(string) error) error
}
func (h *HostsIterator) ForEach(callback func(string) error) error {
return h.hostsIterator.forEachHost(callback)
}
func (h *HostsIterator) UnmarshalText(data []byte) error {
var mErr *multierror.Error
entries := []hostsIterator{
new(HostListEntry),
new(HostsFileEntry),
}
for _, entry := range entries {
err := entry.UnmarshalText(data)
if err != nil {
mErr = multierror.Append(mErr, err)
continue
}
h.hostsIterator = entry
return nil
}
return multierror.Flatten(mErr)
}
// HostList parses `r` as a series of `HostListEntry`.
//
// This is for the host list format commonly used by ad blockers.
func HostList(r io.Reader) SeriesParser[*HostListEntry] {
return LinesAs[*HostListEntry](r)
}
// HostListEntry is a single host.
type HostListEntry string
func (e HostListEntry) String() string {
return string(e)
}
// We assume this is used with `Lines`:
// - data will never be empty
// - comments are stripped
func (e *HostListEntry) UnmarshalText(data []byte) error {
scanner := bufio.NewScanner(bytes.NewReader(data))
scanner.Split(bufio.ScanWords)
_ = scanner.Scan() // data is not empty
host := scanner.Text()
if err := validateHostsListEntry(host); err != nil {
return err
}
if scanner.Scan() {
return fmt.Errorf("unexpected second column: %s", scanner.Text())
}
*e = HostListEntry(host)
return nil
}
func (e HostListEntry) forEachHost(callback func(string) error) error {
return callback(e.String())
}
// HostsFile parses `r` as a series of `HostsFileEntry`.
//
// This is for the hosts file format used by OSes, usually `/etc/hosts`.
func HostsFile(r io.Reader) SeriesParser[*HostsFileEntry] {
return LinesAs[*HostsFileEntry](r)
}
// HostsFileEntry is an entry from an OS hosts file.
type HostsFileEntry struct {
IP net.IP
Interface string
Name string
Aliases []string
}
// We assume this is used with `Lines`:
// - data will never be empty
// - comments are stripped
func (e *HostsFileEntry) UnmarshalText(data []byte) error {
scanner := bufio.NewScanner(bytes.NewReader(data))
scanner.Split(bufio.ScanWords)
_ = scanner.Scan() // data is not empty
ipStr := scanner.Text()
var netInterface string
// Remove interface part
if idx := strings.IndexRune(ipStr, '%'); idx != -1 {
// if `netInterface` is empty it's technically an invalid entry, but we'll ignore that here
netInterface = ipStr[idx+1:]
ipStr = ipStr[:idx]
}
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("invalid ip: %s", scanner.Text())
}
hosts := make([]string, 0, 1) // 1: there must be at least one for the line to be valid
for scanner.Scan() {
host := scanner.Text()
if err := validateDomainName(host); err != nil {
return err
}
hosts = append(hosts, host)
}
if len(hosts) == 0 {
return fmt.Errorf("expected at least one host following IP")
}
*e = HostsFileEntry{
IP: ip,
Interface: netInterface,
Name: hosts[0],
Aliases: hosts[1:],
}
return nil
}
func (e HostsFileEntry) forEachHost(callback func(string) error) error {
err := callback(e.Name)
if err != nil {
return err
}
for _, alias := range e.Aliases {
err := callback(alias)
if err != nil {
return err
}
}
return nil
}
func validateDomainName(host string) error {
if len(host) > maxDomainNameLength {
return fmt.Errorf("domain name is too long: %s", host)
}
if domainNameRegex.MatchString(host) {
return nil
}
return fmt.Errorf("invalid domain name: %s", host)
}
func validateHostsListEntry(host string) error {
if net.ParseIP(host) != nil {
return nil
}
if strings.HasPrefix(host, "/") && strings.HasSuffix(host, "/") {
_, err := regexp.Compile(host)
return err
}
return validateDomainName(host)
}

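Putting the pieces together, a sketch of consuming this parser the way list_cache.go above does, tolerating a few bad lines:

```go
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/0xERR0R/blocky/lists/parsers"
)

func main() {
	// Both supported formats in one source: a hosts file entry with an
	// alias, and a plain host list entry.
	r := strings.NewReader("127.0.0.1 localhost alias\nblocked.example.com\n")

	// Keep going past up to 5 parse errors, logging each with its position.
	p := parsers.AllowErrors(parsers.Hosts(r), 5)
	p.OnErr(func(err error) {
		fmt.Println("parse error:", err)
	})

	// Each item is an iterator: a hosts file entry yields its name and
	// all aliases; a host list entry yields a single host.
	_ = parsers.ForEach[*parsers.HostsIterator](context.Background(), p,
		func(hosts *parsers.HostsIterator) error {
			return hosts.ForEach(func(host string) error {
				fmt.Println(host) // localhost, alias, blocked.example.com
				return nil
			})
		})
}
```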
lists/parsers/hosts_test.go

@@ -0,0 +1,382 @@
package parsers
import (
"context"
"errors"
"io"
"net"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Hosts", func() {
var (
sutReader io.Reader
sut SeriesParser[*HostsIterator]
)
BeforeEach(func() {
sutReader = nil
})
JustBeforeEach(func() {
sut = Hosts(sutReader)
})
When("parsing valid lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"localhost",
"# comment",
" ",
"127.0.0.1 domain.tld # comment",
"::1 localhost alias",
`/domain\.(tld|local)/`,
)
})
It("succeeds", func() {
it, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{"localhost"}))
Expect(sut.Position()).Should(Equal("line 1"))
it, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{"domain.tld"}))
Expect(sut.Position()).Should(Equal("line 4"))
it, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{"localhost", "alias"}))
Expect(sut.Position()).Should(Equal("line 5"))
it, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(it.ForEach)).Should(Equal([]string{`/domain\.(tld|local)/`}))
Expect(sut.Position()).Should(Equal("line 6"))
_, err = sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 7"))
})
})
When("parsing invalid lines", func() {
It("fails", func() {
lines := []string{
"invalidIP localhost",
"!notadomain!",
`/invalid regex ??/`,
}
for _, line := range lines {
sut := Hosts(strings.NewReader(line))
_, err := sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(IsNonResumableErr(err)).ShouldNot(BeTrue())
Expect(sut.Position()).Should(Equal("line 1"))
}
})
})
Describe("HostsIterator.ForEachHost", func() {
var (
entry *HostsIterator
)
BeforeEach(func() {
sutReader = linesReader(
"domain.tld",
"127.0.0.1 domain.tld alias1 alias2",
)
})
JustBeforeEach(func() {
var err error
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(entry.forEachHost)).Should(Equal([]string{"domain.tld"}))
Expect(sut.Position()).Should(Equal("line 1"))
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(entry.forEachHost)).Should(Equal([]string{"domain.tld", "alias1", "alias2"}))
Expect(sut.Position()).Should(Equal("line 2"))
})
It("calls back with the hosts", func() {})
When("callback returns error", func() {
It("fails", func() {
expectedErr := errors.New("fail")
err := entry.forEachHost(func(host string) error {
return expectedErr
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(expectedErr))
})
})
})
})
var _ = Describe("HostsFile", func() {
var (
sutReader io.Reader
sut SeriesParser[*HostsFileEntry]
)
BeforeEach(func() {
sutReader = nil
})
JustBeforeEach(func() {
sut = HostsFile(sutReader)
})
When("parsing valid lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"127.0.0.1 localhost",
"# comment",
" ",
"::1 localhost # comment",
"0.0.0.0%lo0 ipWithInterface",
)
})
It("succeeds", func() {
entry, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.IP).Should(Equal(net.ParseIP("127.0.0.1")))
Expect(entry.Name).Should(Equal("localhost"))
Expect(entry.Aliases).Should(BeEmpty())
Expect(sut.Position()).Should(Equal("line 1"))
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.IP).Should(Equal(net.IPv6loopback))
Expect(entry.Name).Should(Equal("localhost"))
Expect(entry.Aliases).Should(BeEmpty())
Expect(sut.Position()).Should(Equal("line 4"))
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.IP).Should(Equal(net.IPv4zero))
Expect(entry.Name).Should(Equal("ipWithInterface"))
Expect(entry.Aliases).Should(BeEmpty())
Expect(sut.Position()).Should(Equal("line 5"))
_, err = sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 6"))
})
When("there are aliases", func() {
BeforeEach(func() {
sutReader = linesReader(
"127.0.0.1 localhost alias1 alias2 # comment",
)
})
It("parses them", func() {
entry, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.IP).Should(Equal(net.ParseIP("127.0.0.1")))
Expect(entry.Name).Should(Equal("localhost"))
Expect(entry.Aliases).Should(Equal([]string{"alias1", "alias2"}))
Expect(sut.Position()).Should(Equal("line 1"))
_, err = sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 2"))
})
})
})
When("parsing invalid lines", func() {
It("fails", func() {
lines := []string{
"127.0.0.1",
"localhost",
"localhost localhost",
"::1 # localhost # comment",
"::1 toolong" + strings.Repeat("a", maxDomainNameLength),
}
for _, line := range lines {
sut := HostsFile(strings.NewReader(line))
_, err := sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(IsNonResumableErr(err)).ShouldNot(BeTrue())
Expect(sut.Position()).Should(Equal("line 1"))
}
})
})
Describe("HostsFileEntry.forEachHost", func() {
var (
entry *HostsFileEntry
)
BeforeEach(func() {
sutReader = linesReader(
"127.0.0.1 domain.tld alias1 alias2",
)
})
JustBeforeEach(func() {
var err error
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(entry.forEachHost)).Should(Equal([]string{"domain.tld", "alias1", "alias2"}))
Expect(sut.Position()).Should(Equal("line 1"))
})
It("calls back with the host", func() {})
When("callback returns an error immediately", func() {
It("fails", func() {
expectedErr := errors.New("fail")
err := entry.forEachHost(func(host string) error {
return expectedErr
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(expectedErr))
})
})
When("callback returns an error on further calls", func() {
It("fails", func() {
expectedErr := errors.New("fail")
firstCall := true
err := entry.forEachHost(func(host string) error {
if firstCall {
firstCall = false
return nil
}
return expectedErr
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(expectedErr))
})
})
})
})
var _ = Describe("HostList", func() {
var (
sutReader io.Reader
sut SeriesParser[*HostListEntry]
)
BeforeEach(func() {
sutReader = nil
})
JustBeforeEach(func() {
sut = HostList(sutReader)
})
When("parsing valid lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"localhost",
"# comment",
" ",
"domain.tld # comment",
)
})
It("succeeds", func() {
entry, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.String()).Should(Equal("localhost"))
Expect(sut.Position()).Should(Equal("line 1"))
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(entry.String()).Should(Equal("domain.tld"))
Expect(sut.Position()).Should(Equal("line 4"))
_, err = sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 5"))
})
})
When("parsing invalid lines", func() {
It("fails", func() {
lines := []string{
"127.0.0.1 localhost",
"localhost localhost",
`/invalid regex ??/`,
"toolong" + strings.Repeat("a", maxDomainNameLength),
}
for _, line := range lines {
sut := HostList(strings.NewReader(line))
_, err := sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(IsNonResumableErr(err)).ShouldNot(BeTrue())
Expect(sut.Position()).Should(Equal("line 1"))
}
})
})
Describe("HostListEntry.forEachHost", func() {
var (
entry *HostListEntry
)
BeforeEach(func() {
sutReader = linesReader(
"domain.tld",
)
})
JustBeforeEach(func() {
var err error
entry, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(iteratorToList(entry.forEachHost)).Should(Equal([]string{"domain.tld"}))
Expect(sut.Position()).Should(Equal("line 1"))
})
It("calls back with the host", func() {})
When("callback returns error", func() {
It("fails", func() {
expectedErr := errors.New("fail")
err := entry.forEachHost(func(host string) error {
return expectedErr
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(expectedErr))
})
})
})
})

lists/parsers/lines.go

@@ -0,0 +1,93 @@
package parsers
import (
"bufio"
"context"
"encoding"
"fmt"
"io"
"strings"
"unicode"
)
// Lines splits `r` into a series of lines.
//
// Empty lines are skipped, and comments are stripped.
func Lines(r io.Reader) SeriesParser[string] {
return newLines(r)
}
// LinesAs returns a parser that parses each line of `r` as a `T`.
func LinesAs[TPtr TextUnmarshaler[T], T any](r io.Reader) SeriesParser[*T] {
return UnmarshalEach[TPtr](Lines(r))
}
// UnmarshalEach returns a parser that unmarshals each string of `inner` as a `T`.
func UnmarshalEach[TPtr TextUnmarshaler[T], T any](inner SeriesParser[string]) SeriesParser[*T] {
stringToBytes := func(s string) []byte {
return []byte(s)
}
return TryAdaptMethod(Adapt(inner, stringToBytes), TPtr.UnmarshalText)
}
type TextUnmarshaler[T any] interface {
encoding.TextUnmarshaler
*T
}
type lines struct {
scanner *bufio.Scanner
lineNo uint
}
func newLines(r io.Reader) SeriesParser[string] {
scanner := bufio.NewScanner(r)
scanner.Split(bufio.ScanLines)
return &lines{scanner: scanner}
}
func (l *lines) Position() string {
return fmt.Sprintf("line %d", l.lineNo)
}
func (l *lines) Next(ctx context.Context) (string, error) {
for {
l.lineNo++
if err := ctx.Err(); err != nil {
return "", NewNonResumableError(err)
}
if !l.scanner.Scan() {
break
}
text := strings.TrimSpace(l.scanner.Text())
if len(text) == 0 {
continue // empty line
}
if idx := strings.IndexRune(text, '#'); idx != -1 {
if idx == 0 {
continue // commented line
}
// end of line comment
text = text[:idx]
text = strings.TrimRightFunc(text, unicode.IsSpace)
}
return text, nil
}
err := l.scanner.Err()
if err != nil {
// bufio.Scanner does not support continuing after an error
return "", NewNonResumableError(err)
}
return "", NewNonResumableError(io.EOF)
}

lists/parsers/lines_test.go

@@ -0,0 +1,192 @@
package parsers
import (
"bufio"
"context"
"io"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Lines", func() {
var (
data string
sutReader io.Reader
sut SeriesParser[string]
)
BeforeEach(func() {
sutReader = nil
})
JustBeforeEach(func() {
if sutReader == nil {
sutReader = strings.NewReader(data)
}
sut = Lines(sutReader)
})
When("it has normal lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"first",
"second",
"third",
)
})
It("returns them all", func() {
str, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("first"))
Expect(sut.Position()).Should(Equal("line 1"))
str, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("second"))
Expect(sut.Position()).Should(Equal("line 2"))
str, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("third"))
Expect(sut.Position()).Should(Equal("line 3"))
_, err = sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 4"))
})
})
When("it has empty lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"",
" ",
"\t",
"\r",
)
})
It("skips them", func() {
_, err := sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(io.EOF))
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 5"))
})
})
When("it has commented lines", func() {
BeforeEach(func() {
sutReader = linesReader(
"first",
"# comment 1",
"# comment 2",
"second",
)
})
It("returns them all", func() {
str, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("first"))
Expect(sut.Position()).Should(Equal("line 1"))
str, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("second"))
Expect(sut.Position()).Should(Equal("line 4"))
})
})
When("it has end of line comments", func() {
BeforeEach(func() {
sutReader = linesReader(
"first# comment",
"second # other",
)
})
It("returns them all", func() {
str, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("first"))
Expect(sut.Position()).Should(Equal("line 1"))
str, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("second"))
Expect(sut.Position()).Should(Equal("line 2"))
})
})
When("there's a scan error", func() {
BeforeEach(func() {
sutReader = linesReader(
"too long " + strings.Repeat(".", bufio.MaxScanTokenSize),
)
})
It("fails", func() {
_, err := sut.Next(context.Background())
Expect(err).ShouldNot(Succeed())
Expect(sut.Position()).Should(Equal("line 1"))
})
})
When("context is cancelled", func() {
BeforeEach(func() {
sutReader = linesReader(
"first",
"second",
)
})
It("stops parsing", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
str, err := sut.Next(ctx)
Expect(err).Should(Succeed())
Expect(str).Should(Equal("first"))
Expect(sut.Position()).Should(Equal("line 1"))
cancel()
_, err = sut.Next(ctx)
Expect(err).ShouldNot(Succeed())
Expect(IsNonResumableErr(err)).Should(BeTrue())
Expect(sut.Position()).Should(Equal("line 2"))
})
})
When("last line has no newline", func() {
BeforeEach(func() {
data = "first\nlast"
})
It("still returns it", func() {
str, err := sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("first"))
Expect(sut.Position()).Should(Equal("line 1"))
str, err = sut.Next(context.Background())
Expect(err).Should(Succeed())
Expect(str).Should(Equal("last"))
Expect(sut.Position()).Should(Equal("line 2"))
})
})
})
func linesReader(lines ...string) io.Reader {
data := strings.Join(lines, "\n") + "\n"
return strings.NewReader(data)
}

lists/parsers/parser.go

@@ -0,0 +1,93 @@
package parsers
import (
"context"
"errors"
"fmt"
"io"
)
// SeriesParser parses a series of `T`.
type SeriesParser[T any] interface {
// Next advances the cursor in the underlying data source,
// and returns a `T`, or an error.
//
// Fatal parse errors, where no more calls to `Next` should
// be made are of type `NonResumableError`.
// Other errors apply to the item being parsed, and have no
// impact on the rest of the series.
Next(context.Context) (T, error)
// Position returns a string that gives a user-readable indication
// as to where in the parser's underlying data source the cursor is.
//
// The string should be understandable easily by the user.
Position() string
}
// ForEach is a helper for consuming a parser.
//
// It stops iteration at the first error encountered.
// If that error is `io.EOF`, `nil` is returned instead.
// Any other error is wrapped with the parser's position using `ErrWithPosition`.
//
// To continue iteration on resumable errors, use with `FilterErrors`.
func ForEach[T any](ctx context.Context, parser SeriesParser[T], callback func(T) error) (rerr error) {
defer func() {
rerr = ErrWithPosition(parser, rerr)
}()
for {
if err := ctx.Err(); err != nil {
return err
}
res, err := parser.Next(ctx)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
err = callback(res)
if err != nil {
return err
}
}
}
// ErrWithPosition adds the `parser`'s position to the given `err`.
func ErrWithPosition[T any](parser SeriesParser[T], err error) error {
if err == nil {
return nil
}
return fmt.Errorf("%s: %w", parser.Position(), err)
}
// IsNonResumableErr is a helper to check if an error returned by a parser is resumable.
func IsNonResumableErr(err error) bool {
var nonResumableError *NonResumableError
return errors.As(err, &nonResumableError)
}
// NonResumableError represents an error from which a parser cannot recover.
type NonResumableError struct {
inner error
}
// NewNonResumableError creates and returns a new `NonResumableError`.
func NewNonResumableError(inner error) error {
return &NonResumableError{inner}
}
func (e *NonResumableError) Error() string {
return fmt.Sprintf("non resumable parse error: %s", e.inner.Error())
}
func (e *NonResumableError) Unwrap() error {
return e.inner
}

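To show the `Next`/`Position` contract, a hypothetical SeriesParser over a fixed slice (not part of the package; as in `Lines`, io.EOF wrapped as non-resumable ends the series):

```go
package main

import (
	"context"
	"fmt"
	"io"

	"github.com/0xERR0R/blocky/lists/parsers"
)

// sliceParser yields the items of a slice one by one.
type sliceParser struct {
	items []string
	pos   int
}

func (p *sliceParser) Next(ctx context.Context) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", parsers.NewNonResumableError(err)
	}
	if p.pos == len(p.items) {
		// Signal the end of the series; ForEach hides this error.
		return "", parsers.NewNonResumableError(io.EOF)
	}
	p.pos++
	return p.items[p.pos-1], nil
}

func (p *sliceParser) Position() string {
	return fmt.Sprintf("item %d", p.pos)
}

func main() {
	p := &sliceParser{items: []string{"first", "second"}}

	// Any non-EOF error would come back prefixed with Position().
	_ = parsers.ForEach[string](context.Background(), p, func(s string) error {
		fmt.Println(s)
		return nil
	})
}
```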

@@ -0,0 +1,192 @@
package parsers
import (
"context"
"errors"
"fmt"
"github.com/0xERR0R/blocky/util"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ForEach", func() {
var (
lines SeriesParser[string]
)
BeforeEach(func() {
lines = Lines(linesReader(
"first",
"second",
"third",
))
})
It("should iterate and hide io.EOF", func() {
list := iteratorToList(func(cb func(string) error) error {
return ForEach(context.Background(), lines, cb)
})
Expect(list).Should(Equal([]string{"first", "second", "third"}))
})
It("should return callback errors", func() {
expectedErr := errors.New("fail")
err := ForEach(context.Background(), lines, func(line string) error {
return expectedErr
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(expectedErr))
Expect(err.Error()).Should(HavePrefix("line 1: "))
})
It("should return parser errors", func() {
lines := Hosts(linesReader(
"invalid line",
))
err := ForEach(context.Background(), lines, func(*HostsIterator) error {
Fail("callback should not be called")
return nil
})
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(HavePrefix("line 1: "))
})
It("should stop when context is done", func() {
ctx, cancel := context.WithCancel(context.Background())
err := ForEach(ctx, lines, func(line string) error {
if ctx.Err() != nil {
Fail("callback should not be called")
}
cancel()
return nil
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(context.Canceled))
Expect(err.Error()).Should(HavePrefix("line 1: "))
})
It("should not start if context is already done", func() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := ForEach(ctx, lines, func(line string) error {
Fail("callback should not be called")
return nil
})
Expect(err).ShouldNot(Succeed())
Expect(err).Should(MatchError(context.Canceled))
Expect(err.Error()).Should(HavePrefix("line 0: "))
})
})
var _ = Describe("ErrWithPosition", func() {
When("err is nil", func() {
It("returns nil", func() {
inner := errors.New("inner")
lines := Lines(linesReader(
"first",
"second",
))
_, err := lines.Next(context.Background())
Expect(err).Should(Succeed())
err = ErrWithPosition(lines, inner)
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(Equal("line 1: inner"))
_, err = lines.Next(context.Background())
Expect(err).Should(Succeed())
err = ErrWithPosition(lines, inner)
Expect(err).ShouldNot(Succeed())
Expect(err.Error()).Should(Equal("line 2: inner"))
})
})
When("err is nil", func() {
It("returns nil", func() {
err := ErrWithPosition[any](nil, nil)
Expect(err).Should(Succeed())
})
})
})
var _ = Describe("NonResumableError", func() {
Describe("IsNonResumableErr", func() {
It("should return the inner error", func() {
inner := errors.New("inner")
Expect(IsNonResumableErr(inner)).Should(BeFalse())
err := NewNonResumableError(inner)
Expect(IsNonResumableErr(err)).Should(BeTrue())
})
})
Describe("Error", func() {
It("should return error message", func() {
inner := errors.New("inner")
err := NewNonResumableError(inner)
Expect(err.Error()).Should(Equal("non resumable parse error: inner"))
})
})
Describe("Unwrap", func() {
It("should return the inner error", func() {
inner := errors.New("inner")
err := NewNonResumableError(inner)
Expect(errors.Unwrap(err)).Should(Equal(inner))
Expect(errors.Is(err, inner)).Should(BeTrue())
})
})
})
func iteratorToList[T any](forEach func(func(T) error) error) []T {
var res []T
err := forEach(func(t T) error {
res = append(res, t)
return nil
})
Expect(err).Should(Succeed())
return res
}
type mockParser[T any] struct{ util.MockCallSequence[T] }
func newMockParser[T any](driver func(chan<- T, chan<- error)) SeriesParser[T] {
return &mockParser[T]{util.NewMockCallSequence(driver)}
}
func (m *mockParser[T]) Next(ctx context.Context) (_ T, rerr error) {
defer func() {
if rerr != nil && IsNonResumableErr(rerr) {
m.Close()
}
}()
if err := ctx.Err(); err != nil {
var zero T
return zero, NewNonResumableError(err)
}
return m.Call()
}
func (m *mockParser[T]) Position() string {
return fmt.Sprintf("call %d", m.CallCount())
}


@@ -0,0 +1,16 @@
package parsers
import (
"testing"
"github.com/0xERR0R/blocky/log"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestLists(t *testing.T) {
log.Silence()
RegisterFailHandler(Fail)
RunSpecs(t, "Parsers Suite")
}


@@ -1,13 +1,14 @@ package resolver
package resolver
import (
"context"
"fmt"
"net"
"os"
"strings"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/lists/parsers"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
@@ -15,49 +16,65 @@ import (
"github.com/sirupsen/logrus"
)
//nolint:gochecknoglobals
var (
_, loopback4, _ = net.ParseCIDR("127.0.0.0/8")
loopback6 = net.ParseIP("::1")
)
const (
hostsFileResolverLogger = "hosts_file_resolver"
// reduce initial capacity so we don't waste memory if there are fewer entries than before
memReleaseFactor = 2
)
type HostsFileResolver struct {
NextResolver
HostsFilePath string
hosts []host
hosts splitHostsFileData
ttl uint32
refreshPeriod time.Duration
filterLoopback bool
}
type HostsFileEntry = parsers.HostsFileEntry
func (r *HostsFileResolver) handleReverseDNS(request *model.Request) *model.Response {
question := request.Req.Question[0]
if question.Qtype == dns.TypePTR {
response := new(dns.Msg)
response.SetReply(request.Req)
if question.Qtype != dns.TypePTR {
return nil
}
for _, host := range r.hosts {
raddr, _ := dns.ReverseAddr(host.IP.String())
questionIP, err := util.ParseIPFromArpaAddr(question.Name)
if err != nil {
// ignore the parse error, and pass the request down the chain
return nil
}
if raddr == question.Name {
ptr := new(dns.PTR)
ptr.Ptr = dns.Fqdn(host.Hostname)
ptr.Hdr = util.CreateHeader(question, r.ttl)
response.Answer = append(response.Answer, ptr)
if r.filterLoopback && questionIP.IsLoopback() {
// skip the search: we won't find anything
return nil
}
for _, alias := range host.Aliases {
ptrAlias := new(dns.PTR)
ptrAlias.Ptr = dns.Fqdn(alias)
ptrAlias.Hdr = util.CreateHeader(question, r.ttl)
response.Answer = append(response.Answer, ptrAlias)
}
// search only in the hosts with an IP version that matches the question
hostsData := r.hosts.v4
if questionIP.To4() == nil {
hostsData = r.hosts.v6
}
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}
for host, hostData := range hostsData.hosts {
if hostData.IP.Equal(questionIP) {
response := new(dns.Msg)
response.SetReply(request.Req)
ptr := new(dns.PTR)
ptr.Ptr = dns.Fqdn(host)
ptr.Hdr = util.CreateHeader(question, r.ttl)
response.Answer = append(response.Answer, ptr)
for _, alias := range hostData.Aliases {
ptrAlias := new(dns.PTR)
ptrAlias.Ptr = dns.Fqdn(alias)
ptrAlias.Hdr = util.CreateHeader(question, r.ttl)
response.Answer = append(response.Answer, ptrAlias)
}
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}
}
}
@@ -76,25 +93,17 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
return reverseResp, nil
}
if len(r.hosts) != 0 {
response := new(dns.Msg)
response.SetReply(request.Req)
question := request.Req.Question[0]
domain := util.ExtractDomain(question)
question := request.Req.Question[0]
domain := util.ExtractDomain(question)
response := r.resolve(request.Req, question, domain)
if response != nil {
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": domain,
}).Debugf("returning hosts file entry")
for _, host := range r.hosts {
response.Answer = append(response.Answer, r.processHostEntry(host, domain, question)...)
}
if len(response.Answer) > 0 {
logger.WithFields(logrus.Fields{
"answer": util.AnswerToString(response.Answer),
"domain": domain,
}).Debugf("returning hosts file entry")
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
}
return &model.Response{Res: response, RType: model.ResponseTypeHOSTSFILE, Reason: "HOSTS FILE"}, nil
}
logger.WithField("resolver", Name(r.next)).Trace("go to next resolver")
@@ -102,28 +111,23 @@ func (r *HostsFileResolver) Resolve(request *model.Request) (*model.Response, er
return r.next.Resolve(request)
}
func (r *HostsFileResolver) processHostEntry(host host, domain string, question dns.Question) (result []dns.RR) {
if host.Hostname == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
result = append(result, rr)
}
func (r *HostsFileResolver) resolve(req *dns.Msg, question dns.Question, domain string) *dns.Msg {
ip := r.hosts.getIP(dns.Type(question.Qtype), domain)
if ip == nil {
return nil
}
for _, alias := range host.Aliases {
if alias == domain {
if isSupportedType(host.IP, question) {
rr, _ := util.CreateAnswerFromQuestion(question, host.IP, r.ttl)
result = append(result, rr)
}
}
}
rr, _ := util.CreateAnswerFromQuestion(question, ip, r.ttl)
return
response := new(dns.Msg)
response.SetReply(req)
response.Answer = []dns.RR{rr}
return response
}
func (r *HostsFileResolver) Configuration() (result []string) {
if r.HostsFilePath == "" || len(r.hosts) == 0 {
if r.HostsFilePath == "" || r.hosts.isEmpty() {
return configDisabled
}
@@ -143,10 +147,11 @@ func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
filterLoopback: cfg.FilterLoopback,
}
if err := r.parseHostsFile(); err != nil {
if err := r.parseHostsFile(context.Background()); err != nil {
logger := log.PrefixedLog(hostsFileResolverLogger)
logger.Warnf("cannot parse hosts file: %s, hosts file resolving is disabled", r.HostsFilePath)
r.HostsFilePath = ""
logger.Warnf("hosts file resolving is disabled: %s", err)
r.HostsFilePath = "" // don't try parsing the file again
} else {
go r.periodicUpdate()
}
@ -154,72 +159,44 @@ func NewHostsFileResolver(cfg config.HostsFileConfig) *HostsFileResolver {
return &r
}
type host struct {
IP net.IP
Hostname string
Aliases []string
}
//nolint:funlen
func (r *HostsFileResolver) parseHostsFile() error {
const minColumnCount = 2
func (r *HostsFileResolver) parseHostsFile(ctx context.Context) error {
const maxErrorsPerFile = 5
if r.HostsFilePath == "" {
return nil
}
buf, err := os.ReadFile(r.HostsFilePath)
f, err := os.Open(r.HostsFilePath)
if err != nil {
return err
}
defer f.Close()
newHosts := make([]host, 0)
newHosts := newSplitHostsDataWithSameCapacity(r.hosts)
for _, line := range strings.Split(string(buf), "\n") {
trimmed := strings.TrimSpace(line)
p := parsers.AllowErrors(parsers.HostsFile(f), maxErrorsPerFile)
p.OnErr(func(err error) {
log.PrefixedLog(hostsFileResolverLogger).Warnf("error parsing %s: %s, trying to continue", r.HostsFilePath, err)
})
if len(trimmed) == 0 || trimmed[0] == '#' {
// Skip empty and commented lines
continue
err = parsers.ForEach[*parsers.HostsFileEntry](ctx, p, func(entry *parsers.HostsFileEntry) error {
if len(entry.Interface) != 0 {
// Ignore entries with a specific interface: we don't restrict what clients/interfaces we serve entries to,
// so this avoids returning entries that can't be accessed by the client.
return nil
}
// Find comment symbol at the end of the line
var fields []string
end := strings.IndexRune(trimmed, '#')
if end == -1 {
fields = strings.Fields(trimmed)
} else {
fields = strings.Fields(trimmed[:end])
// Ignore loopback, if so configured
if r.filterLoopback && (entry.IP.IsLoopback() || entry.Name == "localhost") {
return nil
}
if len(fields) < minColumnCount {
// Skip invalid entry
continue
}
newHosts.add(entry)
if net.ParseIP(fields[0]) == nil {
// Skip invalid IP address
continue
}
var h host
h.IP = net.ParseIP(fields[0])
h.Hostname = fields[1]
// Check if loopback
if r.filterLoopback && (loopback4.Contains(h.IP) || loopback6.Equal(h.IP)) {
continue
}
if len(fields) > minColumnCount {
for i := 2; i < len(fields); i++ {
h.Aliases = append(h.Aliases, fields[i])
}
}
newHosts = append(newHosts, h)
return nil
})
if err != nil {
return fmt.Errorf("error parsing %s: %w", r.HostsFilePath, err) // err is parsers.ErrTooManyErrors
}
r.hosts = newHosts
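The parsing pipeline above also works standalone. A sketch, assuming parsers.HostsFile accepts any io.Reader (the diff only shows it fed an *os.File) and with illustrative inline data:

package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/0xERR0R/blocky/lists/parsers"
)

func main() {
	data := "192.168.2.1 ipv4host ipv4host.local.lan\nnot-an-ip oops\n"

	// tolerate up to 5 bad lines, reporting each one
	p := parsers.AllowErrors(parsers.HostsFile(strings.NewReader(data)), 5)
	p.OnErr(func(err error) { fmt.Println("skipping bad line:", err) })

	err := parsers.ForEach[*parsers.HostsFileEntry](context.Background(), p,
		func(e *parsers.HostsFileEntry) error {
			fmt.Println(e.Name, e.IP, e.Aliases)
			return nil
		})
	if err != nil {
		fmt.Println(err) // parsers.ErrTooManyErrors once the error budget is spent
	}
}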
@ -238,7 +215,88 @@ func (r *HostsFileResolver) periodicUpdate() {
logger := log.PrefixedLog(hostsFileResolverLogger)
logger.WithField("file", r.HostsFilePath).Debug("refreshing hosts file")
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile())
util.LogOnError("can't refresh hosts file: ", r.parseHostsFile(context.Background()))
}
}
}
// stores hosts file data for IP versions separately
//
// Makes finding an IP for a question faster.
// Especially reverse lookups where we have to iterate through
// all the known hosts.
type splitHostsFileData struct {
v4 hostsFileData
v6 hostsFileData
}
func newSplitHostsDataWithSameCapacity(other splitHostsFileData) splitHostsFileData {
return splitHostsFileData{
v4: newHostsDataWithSameCapacity(other.v4),
v6: newHostsDataWithSameCapacity(other.v6),
}
}
func (d splitHostsFileData) isEmpty() bool {
return d.v4.isEmpty() && d.v6.isEmpty()
}
func (d splitHostsFileData) getIP(qType dns.Type, domain string) net.IP {
switch uint16(qType) {
case dns.TypeA:
return d.v4.getIP(domain)
case dns.TypeAAAA:
return d.v6.getIP(domain)
}
return nil
}
func (d splitHostsFileData) add(entry *parsers.HostsFileEntry) {
if entry.IP.To4() != nil {
d.v4.add(entry)
} else {
d.v6.add(entry)
}
}
type hostsFileData struct {
hosts map[string]hostData
aliases map[string]net.IP
}
type hostData struct {
IP net.IP
Aliases []string
}
func newHostsDataWithSameCapacity(other hostsFileData) hostsFileData {
return hostsFileData{
hosts: make(map[string]hostData, len(other.hosts)/memReleaseFactor),
aliases: make(map[string]net.IP, len(other.aliases)/memReleaseFactor),
}
}
func (d hostsFileData) isEmpty() bool {
return len(d.hosts) == 0 && len(d.aliases) == 0
}
func (d hostsFileData) getIP(hostname string) net.IP {
if hostData, ok := d.hosts[hostname]; ok {
return hostData.IP
}
if ip, ok := d.aliases[hostname]; ok {
return ip
}
return nil
}
func (d hostsFileData) add(entry *parsers.HostsFileEntry) {
d.hosts[entry.Name] = hostData{entry.IP, entry.Aliases}
for _, alias := range entry.Aliases {
d.aliases[alias] = entry.IP
}
}
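A sketch of how the pieces fit together, assuming parsers.HostsFileEntry is a plain struct with the fields used above (IP, Name, Aliases):

hosts := newSplitHostsDataWithSameCapacity(splitHostsFileData{})

hosts.add(&parsers.HostsFileEntry{
	IP:      net.ParseIP("192.168.2.1"),
	Name:    "ipv4host",
	Aliases: []string{"ipv4host.local.lan"},
}) // stored under v4 because IP.To4() != nil

hosts.getIP(dns.Type(dns.TypeA), "ipv4host")           // 192.168.2.1, from v4.hosts
hosts.getIP(dns.Type(dns.TypeA), "ipv4host.local.lan") // 192.168.2.1, from v4.aliases
hosts.getIP(dns.Type(dns.TypeAAAA), "ipv4host")        // nil: wrong IP version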

View File

@ -1,6 +1,7 @@
package resolver
import (
"context"
"fmt"
"math/rand"
"time"
@ -16,10 +17,11 @@ import (
var _ = Describe("HostsFileResolver", func() {
var (
sut *HostsFileResolver
m *mockResolver
tmpDir *TmpFolder
tmpFile *TmpFile
sut *HostsFileResolver
sutConfig config.HostsFileConfig
m *mockResolver
tmpDir *TmpFolder
tmpFile *TmpFile
)
TTL := uint32(time.Now().Second())
@ -32,13 +34,16 @@ var _ = Describe("HostsFileResolver", func() {
tmpFile = writeHostFile(tmpDir)
Expect(tmpFile.Error).Should(Succeed())
cfg := config.HostsFileConfig{
sutConfig = config.HostsFileConfig{
Filepath: tmpFile.Path,
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
RefreshPeriod: config.Duration(30 * time.Minute),
FilterLoopback: true,
}
sut = NewHostsFileResolver(cfg)
})
JustBeforeEach(func() {
sut = NewHostsFileResolver(sutConfig)
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
@ -47,17 +52,18 @@ var _ = Describe("HostsFileResolver", func() {
Describe("Using hosts file", func() {
When("Hosts file cannot be located", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{
sutConfig = config.HostsFileConfig{
Filepath: fmt.Sprintf("/tmp/blocky/file-%d", rand.Uint64()),
HostsTTL: config.Duration(time.Duration(TTL) * time.Second),
})
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
sut.Next(m)
}
})
It("should not parse any hosts", func() {
Expect(sut.HostsFilePath).Should(BeEmpty())
Expect(sut.hosts).Should(HaveLen(0))
Expect(sut.hosts.v4.hosts).Should(BeEmpty())
Expect(sut.hosts.v6.hosts).Should(BeEmpty())
Expect(sut.hosts.v4.aliases).Should(BeEmpty())
Expect(sut.hosts.v6.aliases).Should(BeEmpty())
Expect(sut.hosts.isEmpty()).Should(BeTrue())
})
It("should go to next resolver on query", func() {
Expect(sut.Resolve(newRequest("example.com.", A))).
@ -78,7 +84,7 @@ var _ = Describe("HostsFileResolver", func() {
sut.Next(m)
})
It("should not return an error", func() {
err := sut.parseHostsFile()
err := sut.parseHostsFile(context.Background())
Expect(err).Should(Succeed())
})
It("should go to next resolver on query", func() {
@ -95,7 +101,50 @@ var _ = Describe("HostsFileResolver", func() {
When("Hosts file can be located", func() {
It("should parse it successfully", func() {
Expect(sut).ShouldNot(BeNil())
Expect(sut.hosts).Should(HaveLen(4))
Expect(sut.hosts.v4.hosts).Should(HaveLen(5))
Expect(sut.hosts.v6.hosts).Should(HaveLen(2))
Expect(sut.hosts.v4.aliases).Should(HaveLen(4))
Expect(sut.hosts.v6.aliases).Should(HaveLen(2))
})
When("filterLoopback is false", func() {
BeforeEach(func() {
sutConfig.FilterLoopback = false
})
It("should parse it successfully", func() {
Expect(sut).ShouldNot(BeNil())
Expect(sut.hosts.v4.hosts).Should(HaveLen(7))
Expect(sut.hosts.v6.hosts).Should(HaveLen(3))
Expect(sut.hosts.v4.aliases).Should(HaveLen(5))
Expect(sut.hosts.v6.aliases).Should(HaveLen(2))
})
})
})
When("Hosts file has too many errors", func() {
BeforeEach(func() {
tmpFile = tmpDir.CreateStringFile("hosts-too-many-errors.txt",
"invalidip localhost",
"127.0.0.1 localhost", // ok
"127.0.0.1 # no host",
"127.0.0.1 invalidhost!",
"a.b.c.d localhost",
"127.0.0.x localhost",
"256.0.0.1 localhost",
)
Expect(tmpFile.Error).Should(Succeed())
sutConfig.Filepath = tmpFile.Path
})
It("should not be used", func() {
Expect(sut).ShouldNot(BeNil())
Expect(sut.HostsFilePath).Should(BeEmpty())
Expect(sut.hosts.v4.hosts).Should(HaveLen(0))
Expect(sut.hosts.v6.hosts).Should(HaveLen(0))
Expect(sut.hosts.v4.aliases).Should(HaveLen(0))
Expect(sut.hosts.v6.aliases).Should(HaveLen(0))
})
})
@ -153,6 +202,24 @@ var _ = Describe("HostsFileResolver", func() {
})
})
When("the domain is not known", func() {
It("calls the next resolver", func() {
resp, err := sut.Resolve(newRequest("not-in-hostsfile.tld.", A))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
})
})
When("the question type is not handled", func() {
It("calls the next resolver", func() {
resp, err := sut.Resolve(newRequest("localhost.", MX))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
})
})
When("Reverse DNS request is received", func() {
It("should resolve the defined domain name", func() {
By("ipv4 with one hostname", func() {
@ -193,6 +260,50 @@ var _ = Describe("HostsFileResolver", func() {
))
})
})
It("should ignore invalid PTR", func() {
resp, err := sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
})
When("filterLoopback is true", func() {
It("calls the next resolver", func() {
resp, err := sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
})
})
When("the IP is not known", func() {
It("calls the next resolver", func() {
resp, err := sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR))
Expect(err).Should(Succeed())
Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE))
m.AssertExpectations(GinkgoT())
})
})
When("filterLoopback is false", func() {
BeforeEach(func() {
sutConfig.FilterLoopback = false
})
It("resolve the defined domain name", func() {
Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))).
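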
Should(
SatisfyAll(
HaveResponseType(ResponseTypeHOSTSFILE),
HaveReturnCode(dns.RcodeSuccess),
WithTransform(ToAnswer, ContainElements(
BeDNSRecord("1.1.0.127.in-addr.arpa.", PTR, "localhost2."),
BeDNSRecord("1.1.0.127.in-addr.arpa.", PTR, "localhost2.local.lan."),
)),
))
})
})
})
})
@ -206,7 +317,7 @@ var _ = Describe("HostsFileResolver", func() {
When("hosts file is not provided", func() {
BeforeEach(func() {
sut = NewHostsFileResolver(config.HostsFileConfig{})
sutConfig = config.HostsFileConfig{}
})
It("should return 'disabled'", func() {
c := sut.Configuration()
@ -238,9 +349,17 @@ func writeHostFile(tmpDir *TmpFolder) *TmpFile {
"",
"faaf:faaf:faaf:faaf::1 ipv6host ipv6host.local.lan",
"192.168.2.1 ipv4host ipv4host.local.lan",
"faaf:faaf:faaf:faaf::2 dualhost dualhost.local.lan",
"192.168.2.2 dualhost dualhost.local.lan",
"10.0.0.1 router0 router1 router2",
"10.0.0.2 router3 # Another comment",
"10.0.0.3 # Invalid entry",
"10.0.0.3 router4#comment without a space",
"10.0.0.4 # Invalid entry",
"300.300.300.300 invalid4 # Invalid IPv4",
"abcd:efgh:ijkl::1 invalid6 # Invalud IPv6")
"abcd:efgh:ijkl::1 invalid6 # Invalid IPv6",
"1.2.3.4 localhost", // localhost name but not localhost IP
// from https://raw.githubusercontent.com/StevenBlack/hosts/master/hosts
"fe80::1%lo0 localhost", // interface name
)
}

93
util/arpa.go Normal file
View File

@ -0,0 +1,93 @@
package util
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
)
const (
IPv4PtrSuffix = ".in-addr.arpa."
IPv6PtrSuffix = ".ip6.arpa."
byteBits = 8
)
var (
ErrInvalidArpaAddrLen = errors.New("arpa hostname is not of expected length")
)
func ParseIPFromArpaAddr(arpa string) (net.IP, error) {
if strings.HasSuffix(arpa, IPv4PtrSuffix) {
return parseIPv4FromArpaAddr(arpa)
}
if strings.HasSuffix(arpa, IPv6PtrSuffix) {
return parseIPv6FromArpaAddr(arpa)
}
return nil, fmt.Errorf("invalid arpa hostname: %s", arpa)
}
func parseIPv4FromArpaAddr(arpa string) (net.IP, error) {
const base10 = 10
revAddr := strings.TrimSuffix(arpa, IPv4PtrSuffix)
parts := strings.Split(revAddr, ".")
if len(parts) != net.IPv4len {
return nil, ErrInvalidArpaAddrLen
}
buf := make([]byte, 0, net.IPv4len)
// Parse and add each byte, in reverse, to the buffer
for i := len(parts) - 1; i >= 0; i-- {
part, err := strconv.ParseUint(parts[i], base10, byteBits)
if err != nil {
return nil, err
}
buf = append(buf, byte(part))
}
return net.IPv4(buf[0], buf[1], buf[2], buf[3]), nil
}
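A worked example of the reversal:

// "4.3.2.1.in-addr.arpa." -> parts = [4 3 2 1]; appending in reverse
// yields buf = [1 2 3 4], i.e. net.IPv4(1, 2, 3, 4) == 1.2.3.4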
func parseIPv6FromArpaAddr(arpa string) (net.IP, error) {
const (
base16 = 16
ipv6Bytes = 2 * net.IPv6len
nibbleBits = byteBits / 2
)
revAddr := strings.TrimSuffix(arpa, IPv6PtrSuffix)
parts := strings.Split(revAddr, ".")
if len(parts) != ipv6Bytes {
return nil, ErrInvalidArpaAddrLen
}
buf := make([]byte, 0, net.IPv6len)
// Parse and add each byte, in reverse, to the buffer
for i := len(parts) - 1; i >= 0; i -= 2 {
msNibble, err := strconv.ParseUint(parts[i], base16, byteBits)
if err != nil {
return nil, err
}
lsNibble, err := strconv.ParseUint(parts[i-1], base16, byteBits)
if err != nil {
return nil, err
}
part := msNibble<<nibbleBits | lsNibble
buf = append(buf, byte(part))
}
return net.IP(buf), nil
}
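The nibble order is the subtle part: the reversed name begins with the least-significant nibble, so each iteration reads a pair back to front. A worked example using the address from the tests below:

// "1.0.0.0. ... .0.0.f.7.2.0.0.2.ip6.arpa." encodes 2002:7f00:1::1.
// The loop starts at the end of parts:
//   parts[31]="2", parts[30]="0" -> 0x2<<4 | 0x0 = 0x20 // first byte
//   parts[29]="0", parts[28]="2" -> 0x0<<4 | 0x2 = 0x02 // second byte
//   parts[27]="7", parts[26]="f" -> 0x7<<4 | 0xf = 0x7f // third byte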

98
util/arpa_test.go Normal file
View File

@ -0,0 +1,98 @@
package util
import (
"net"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ParseIPFromArpaAddr", func() {
Describe("IPv4", func() {
It("parses an IP correctly", func() {
ip, err := ParseIPFromArpaAddr("4.3.2.1.in-addr.arpa.")
Expect(err).Should(Succeed())
Expect(ip).Should(Equal(net.ParseIP("1.2.3.4")))
})
It("requires the arpa domain", func() {
_, err := ParseIPFromArpaAddr("4.3.2.1.in-addr.arpa.fail.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("4.3.2.1.in-addr.fail.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("4.3.2.1.fail.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all ip parts to be decimal numbers", func() {
_, err := ParseIPFromArpaAddr("a.3.2.1.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all parts to be present", func() {
_, err := ParseIPFromArpaAddr("3.2.1.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all parts to be non empty", func() {
_, err := ParseIPFromArpaAddr(".3.2.1.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("4..2.1.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("4.3..1.in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("4.3.2..in-addr.arpa.")
Expect(err).ShouldNot(Succeed())
})
})
Describe("IPv6", func() {
It("parses an IP correctly", func() {
ip, err := ParseIPFromArpaAddr("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).Should(Succeed())
Expect(ip).Should(Equal(net.ParseIP("2002:7f00:1::1")))
})
It("requires the arpa domain", func() {
_, err := ParseIPFromArpaAddr("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.fail.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.fail.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.fail.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all LSB parts to be hex numbers", func() {
_, err := ParseIPFromArpaAddr("g.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all MSB parts to be hex numbers", func() {
_, err := ParseIPFromArpaAddr("1.g.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all parts to be present", func() {
_, err := ParseIPFromArpaAddr("1.0.0.0.0.0.0.0.0.0.0.0.0.g.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
})
It("requires all parts to non empty", func() {
_, err := ParseIPFromArpaAddr(".0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..0.0.0.0.1.0.0.0.0.0.f.7.2.0.0.2.ip6.arpa.")
Expect(err).ShouldNot(Succeed())
_, err = ParseIPFromArpaAddr("0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..0.0.0.0.1.0.0.0.0.0.f.7.2.0.0..ip6.arpa.")
Expect(err).ShouldNot(Succeed())
})
})
})

View File

@ -0,0 +1,78 @@
package util
import (
"context"
"fmt"
"sync"
"time"
)
const mockCallTimeout = 2 * time.Second
type MockCallSequence[T any] struct {
driver func(chan<- T, chan<- error)
res chan T
err chan error
callCount uint
initOnce sync.Once
closeOnce sync.Once
}
func NewMockCallSequence[T any](driver func(chan<- T, chan<- error)) MockCallSequence[T] {
return MockCallSequence[T]{
driver: driver,
}
}
func (m *MockCallSequence[T]) Call() (T, error) {
m.callCount++
m.initOnce.Do(func() {
m.res = make(chan T)
m.err = make(chan error)
// This goroutine never stops
go func() {
defer m.Close()
m.driver(m.res, m.err)
}()
})
ctx, cancel := context.WithTimeout(context.Background(), mockCallTimeout)
defer cancel()
select {
case t, ok := <-m.res:
if !ok {
break
}
return t, nil
case err, ok := <-m.err:
if !ok {
break
}
var zero T
return zero, err
case <-ctx.Done():
panic(fmt.Sprintf("mock call sequence driver timed-out on call %d", m.CallCount()))
}
panic("mock call sequence called after driver returned (or sequence Close was called explicitly)")
}
func (m *MockCallSequence[T]) CallCount() uint {
return m.callCount
}
func (m *MockCallSequence[T]) Close() {
m.closeOnce.Do(func() {
close(m.res)
close(m.err)
})
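A hedged usage sketch for the mock, as an in-package example (assumes errors and fmt are imported; the driver values are illustrative):

func ExampleMockCallSequence() {
	seq := NewMockCallSequence(func(res chan<- string, errs chan<- error) {
		res <- "first response"           // consumed by the first Call
		errs <- errors.New("second call") // consumed by the second Call
		// driver returns here; a third Call would panic
	})
	defer seq.Close()

	v, _ := seq.Call()
	fmt.Println(v, seq.CallCount()) // first response 1

	_, err := seq.Call()
	fmt.Println(err, seq.CallCount()) // second call 2
}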
}