Refactoring: FileDownloader (#1281)

* enabled containedctx & contextcheck

* enabled noctx

* less background context

* context metrics test

* use ginkgo context instead of background

* fix redis e2e tests

* made downloader context aware
This commit is contained in:
Kwitsch 2023-11-29 18:18:29 +01:00 committed by GitHub
parent 976d6198f1
commit 3378316982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 232 additions and 212 deletions

View File

@ -55,9 +55,10 @@ linters:
- whitespace
- wsl
- ginkgolinter
disable:
- noctx
- containedctx
- contextcheck
disable:
- scopelint
- structcheck
- deadcode

View File

@ -18,15 +18,15 @@ var _ = Describe("Basic functional tests", func() {
var err error
Describe("Container start", func() {
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`)
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
})
When("wrong port configuration is provided", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -38,22 +38,22 @@ var _ = Describe("Basic functional tests", func() {
Expect(err).Should(HaveOccurred())
// check container exit status
state, err := blocky.State(context.Background())
state, err := blocky.State(ctx)
Expect(err).Should(Succeed())
Expect(state.ExitCode).Should(Equal(1))
DeferCleanup(blocky.Terminate)
})
It("should fail to start", func() {
It("should fail to start", func(ctx context.Context) {
Eventually(blocky.IsRunning, "5s", "2ms").Should(BeFalse())
Expect(getContainerLogs(blocky)).
Expect(getContainerLogs(ctx, blocky)).
Should(ContainElement(ContainSubstring("address already in use")))
})
})
When("Minimal configuration is provided", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -63,19 +63,19 @@ var _ = Describe("Basic functional tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("Should start and answer DNS queries", func() {
It("Should start and answer DNS queries", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.de.", A)
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
})
It("should return 'healthy' container status (healthcheck)", func() {
It("should return 'healthy' container status (healthcheck)", func(ctx context.Context) {
Eventually(func(g Gomega) string {
state, err := blocky.State(context.Background())
state, err := blocky.State(ctx)
g.Expect(err).NotTo(HaveOccurred())
return state.Health.Status
@ -84,8 +84,8 @@ var _ = Describe("Basic functional tests", func() {
})
Context("http port configuration", func() {
When("'httpPort' is not defined", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -96,8 +96,8 @@ var _ = Describe("Basic functional tests", func() {
DeferCleanup(blocky.Terminate)
})
It("should not open http port", func() {
host, port, err := getContainerHostPort(blocky, "4000/tcp")
It("should not open http port", func(ctx context.Context) {
host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp")
Expect(err).Should(Succeed())
_, err = http.Get(fmt.Sprintf("http://%s", net.JoinHostPort(host, port)))
@ -105,8 +105,8 @@ var _ = Describe("Basic functional tests", func() {
})
})
When("'httpPort' is defined", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -118,8 +118,8 @@ var _ = Describe("Basic functional tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should serve http content", func() {
host, port, err := getContainerHostPort(blocky, "4000/tcp")
It("should serve http content", func(ctx context.Context) {
host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp")
Expect(err).Should(Succeed())
url := fmt.Sprintf("http://%s", net.JoinHostPort(host, port))
@ -142,15 +142,15 @@ var _ = Describe("Basic functional tests", func() {
})
Describe("Logging", func() {
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`)
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
})
When("log privacy is enabled", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -162,27 +162,27 @@ var _ = Describe("Basic functional tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should not log answers and questions", func() {
It("should not log answers and questions", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.com.", A)
// do 2 requests
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("<=", 123)),
))
Expect(getContainerLogs(blocky)).Should(Not(ContainElement(ContainSubstring("google.com"))))
Expect(getContainerLogs(blocky)).Should(Not(ContainElement(ContainSubstring("1.2.3.4"))))
Expect(getContainerLogs(ctx, blocky)).Should(Not(ContainElement(ContainSubstring("google.com"))))
Expect(getContainerLogs(ctx, blocky)).Should(Not(ContainElement(ContainSubstring("1.2.3.4"))))
})
})
})

View File

@ -13,8 +13,8 @@ import (
var _ = Describe("External lists and query blocking", func() {
var blocky, httpServer, moka testcontainers.Container
var err error
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka", `A google/NOERROR("A 1.2.3.4 123")`)
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka", `A google/NOERROR("A 1.2.3.4 123")`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
@ -22,8 +22,8 @@ var _ = Describe("External lists and query blocking", func() {
Describe("List download on startup", func() {
When("external blacklist ist not available", func() {
Context("loading.strategy = blocking", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -45,22 +45,22 @@ var _ = Describe("External lists and query blocking", func() {
DeferCleanup(blocky.Terminate)
})
It("should start with warning in log work without errors", func() {
It("should start with warning in log work without errors", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.com.", A)
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.com.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
Expect(getContainerLogs(blocky)).Should(ContainElement(ContainSubstring("cannot open source: ")))
Expect(getContainerLogs(ctx, blocky)).Should(ContainElement(ContainSubstring("cannot open source: ")))
})
})
Context("loading.strategy = failOnError", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -81,17 +81,17 @@ var _ = Describe("External lists and query blocking", func() {
Expect(err).Should(HaveOccurred())
// check container exit status
state, err := blocky.State(context.Background())
state, err := blocky.State(ctx)
Expect(err).Should(Succeed())
Expect(state.ExitCode).Should(Equal(1))
DeferCleanup(blocky.Terminate)
})
It("should fail to start", func() {
It("should fail to start", func(ctx context.Context) {
Eventually(blocky.IsRunning, "5s", "2ms").Should(BeFalse())
Expect(getContainerLogs(blocky)).
Expect(getContainerLogs(ctx, blocky)).
Should(ContainElement(ContainSubstring("Error: can't start server: 1 error occurred")))
})
})
@ -99,13 +99,13 @@ var _ = Describe("External lists and query blocking", func() {
})
Describe("Query blocking against external blacklists", func() {
When("external blacklists are defined and available", func() {
BeforeEach(func() {
httpServer, err = createHTTPServerContainer("httpserver", tmpDir, "list.txt", "blockeddomain.com")
BeforeEach(func(ctx context.Context) {
httpServer, err = createHTTPServerContainer(ctx, "httpserver", tmpDir, "list.txt", "blockeddomain.com")
Expect(err).Should(Succeed())
DeferCleanup(httpServer.Terminate)
blocky, err = createBlockyContainer(tmpDir,
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -124,17 +124,17 @@ var _ = Describe("External lists and query blocking", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should download external list on startup and block queries", func() {
It("should download external list on startup and block queries", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("blockeddomain.com.", A)
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("blockeddomain.com.", A, "0.0.0.0"),
HaveTTL(BeNumerically("==", 6*60*60)),
))
Expect(getContainerLogs(blocky)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty())
})
})
})

View File

@ -39,9 +39,7 @@ const (
blockyImage = "blocky-e2e"
)
func createDNSMokkaContainer(alias string, rules ...string) (testcontainers.Container, error) {
ctx := context.Background()
func createDNSMokkaContainer(ctx context.Context, alias string, rules ...string) (testcontainers.Container, error) {
mokaRules := make(map[string]string)
for i, rule := range rules {
@ -63,7 +61,7 @@ func createDNSMokkaContainer(alias string, rules ...string) (testcontainers.Cont
})
}
func createHTTPServerContainer(alias string, tmpDir *helpertest.TmpFolder,
func createHTTPServerContainer(ctx context.Context, alias string, tmpDir *helpertest.TmpFolder,
filename string, lines ...string,
) (testcontainers.Container, error) {
f1 := tmpDir.CreateStringFile(filename,
@ -75,7 +73,6 @@ func createHTTPServerContainer(alias string, tmpDir *helpertest.TmpFolder,
const modeOwner = 700
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: staticServerImage,
Networks: []string{NetworkName},
@ -105,9 +102,7 @@ func WithNetwork(network string) testcontainers.CustomizeRequestOption {
}
}
func createRedisContainer() (*redis.RedisContainer, error) {
ctx := context.Background()
func createRedisContainer(ctx context.Context) (*redis.RedisContainer, error) {
return redis.RunContainer(ctx,
testcontainers.WithImage(redisImage),
redis.WithLogLevel(redis.LogLevelVerbose),
@ -115,9 +110,7 @@ func createRedisContainer() (*redis.RedisContainer, error) {
)
}
func createPostgresContainer() (*postgres.PostgresContainer, error) {
ctx := context.Background()
func createPostgresContainer(ctx context.Context) (*postgres.PostgresContainer, error) {
const waitLogOccurrence = 2
return postgres.RunContainer(ctx,
@ -134,9 +127,7 @@ func createPostgresContainer() (*postgres.PostgresContainer, error) {
)
}
func createMariaDBContainer() (*mariadb.MariaDBContainer, error) {
ctx := context.Background()
func createMariaDBContainer(ctx context.Context) (*mariadb.MariaDBContainer, error) {
return mariadb.RunContainer(ctx,
testcontainers.WithImage(mariaDBImage),
mariadb.WithDatabase("user"),
@ -151,7 +142,9 @@ const (
startupTimeout = 30 * time.Second
)
func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testcontainers.Container, error) {
func createBlockyContainer(ctx context.Context, tmpDir *helpertest.TmpFolder,
lines ...string,
) (testcontainers.Container, error) {
f1 := tmpDir.CreateStringFile("config1.yaml",
lines...,
)
@ -164,7 +157,6 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc
return nil, fmt.Errorf("can't create config struct %w", err)
}
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: blockyImage,
Networks: []string{NetworkName},
@ -192,7 +184,7 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc
})
if err != nil {
// attach container log if error occurs
if r, err := container.Logs(context.Background()); err == nil {
if r, err := container.Logs(ctx); err == nil {
if b, err := io.ReadAll(r); err == nil {
ginkgo.AddReportEntry("blocky container log", string(b))
}
@ -203,7 +195,7 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc
// check if DNS/HTTP interface is working.
// Sometimes the internal health check returns OK, but the container port is not mapped yet
err = checkBlockyReadiness(cfg, container)
err = checkBlockyReadiness(ctx, cfg, container)
if err != nil {
return container, fmt.Errorf("container not ready: %w", err)
@ -212,14 +204,14 @@ func createBlockyContainer(tmpDir *helpertest.TmpFolder, lines ...string) (testc
return container, nil
}
func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container) error {
func checkBlockyReadiness(ctx context.Context, cfg *config.Config, container testcontainers.Container) error {
var err error
const retryAttempts = 3
err = retry.Do(
func() error {
_, err = doDNSRequest(container, util.NewMsgWithQuestion("healthcheck.blocky.", dns.Type(dns.TypeA)))
_, err = doDNSRequest(ctx, container, util.NewMsgWithQuestion("healthcheck.blocky.", dns.Type(dns.TypeA)))
return err
},
@ -239,7 +231,7 @@ func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container
port := parts[len(parts)-1]
err = retry.Do(
func() error {
return doHTTPRequest(container, port)
return doHTTPRequest(ctx, container, port)
},
retry.OnRetry(func(n uint, err error) {
log.Infof("Performing retry HTTP request #%d: %s\n", n, err)
@ -256,13 +248,19 @@ func checkBlockyReadiness(cfg *config.Config, container testcontainers.Container
return nil
}
func doHTTPRequest(container testcontainers.Container, containerPort string) error {
host, port, err := getContainerHostPort(container, nat.Port(fmt.Sprintf("%s/tcp", containerPort)))
func doHTTPRequest(ctx context.Context, container testcontainers.Container, containerPort string) error {
host, port, err := getContainerHostPort(ctx, container, nat.Port(fmt.Sprintf("%s/tcp", containerPort)))
if err != nil {
return err
}
resp, err := http.Get(fmt.Sprintf("http://%s", net.JoinHostPort(host, port)))
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("http://%s", net.JoinHostPort(host, port)), nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
@ -276,7 +274,7 @@ func doHTTPRequest(container testcontainers.Container, containerPort string) err
return err
}
func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Msg, error) {
func doDNSRequest(ctx context.Context, container testcontainers.Container, message *dns.Msg) (*dns.Msg, error) {
const timeout = 5 * time.Second
c := &dns.Client{
@ -284,7 +282,7 @@ func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Ms
Timeout: timeout,
}
host, port, err := getContainerHostPort(container, "53/tcp")
host, port, err := getContainerHostPort(ctx, container, "53/tcp")
if err != nil {
return nil, err
}
@ -294,13 +292,13 @@ func doDNSRequest(container testcontainers.Container, message *dns.Msg) (*dns.Ms
return msg, err
}
func getContainerHostPort(c testcontainers.Container, p nat.Port) (host, port string, err error) {
res, err := c.MappedPort(context.Background(), p)
func getContainerHostPort(ctx context.Context, c testcontainers.Container, p nat.Port) (host, port string, err error) {
res, err := c.MappedPort(ctx, p)
if err != nil {
return "", "", err
}
host, err = c.Host(context.Background())
host, err = c.Host(ctx)
if err != nil {
return "", "", err
@ -309,8 +307,8 @@ func getContainerHostPort(c testcontainers.Container, p nat.Port) (host, port st
return host, res.Port(), err
}
func getContainerLogs(c testcontainers.Container) (lines []string, err error) {
if r, err := c.Logs(context.Background()); err == nil {
func getContainerLogs(ctx context.Context, c testcontainers.Container) (lines []string, err error) {
if r, err := c.Logs(ctx); err == nil {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()

View File

@ -28,10 +28,10 @@ var (
tmpDir *helpertest.TmpFolder
)
var _ = BeforeSuite(func() {
var _ = BeforeSuite(func(ctx context.Context) {
var err error
network, err = testcontainers.GenericNetwork(context.Background(), testcontainers.GenericNetworkRequest{
network, err = testcontainers.GenericNetwork(ctx, testcontainers.GenericNetworkRequest{
NetworkRequest: testcontainers.NetworkRequest{
Name: NetworkName,
CheckDuplicate: false,
@ -41,10 +41,10 @@ var _ = BeforeSuite(func() {
Expect(err).Should(Succeed())
DeferCleanup(func() {
DeferCleanup(func(ctx context.Context) {
err := retry.Do(
func() error {
return network.Remove(context.Background())
return network.Remove(ctx)
},
retry.Attempts(3),
retry.DelayType(retry.BackOffDelay),

View File

@ -2,6 +2,7 @@ package e2e
import (
"bufio"
"context"
"fmt"
"net"
"net/http"
@ -20,23 +21,24 @@ var _ = Describe("Metrics functional tests", func() {
var metricsURL string
Describe("Metrics", func() {
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`)
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
httpServer1, err = createHTTPServerContainer("httpserver1", tmpDir, "list1.txt", "domain1.com")
httpServer1, err = createHTTPServerContainer(ctx, "httpserver1", tmpDir, "list1.txt", "domain1.com")
Expect(err).Should(Succeed())
DeferCleanup(httpServer1.Terminate)
httpServer2, err = createHTTPServerContainer("httpserver2", tmpDir, "list2.txt", "domain1.com", "domain2", "domain3")
httpServer2, err = createHTTPServerContainer(ctx, "httpserver2", tmpDir, "list2.txt",
"domain1.com", "domain2", "domain3")
Expect(err).Should(Succeed())
DeferCleanup(httpServer2.Terminate)
blocky, err = createBlockyContainer(tmpDir,
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -56,26 +58,26 @@ var _ = Describe("Metrics functional tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
host, port, err := getContainerHostPort(blocky, "4000/tcp")
host, port, err := getContainerHostPort(ctx, blocky, "4000/tcp")
Expect(err).Should(Succeed())
metricsURL = fmt.Sprintf("http://%s/metrics", net.JoinHostPort(host, port))
})
When("Blocky is started", func() {
It("Should provide 'blocky_build_info' prometheus metrics", func() {
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
It("Should provide 'blocky_build_info' prometheus metrics", func(ctx context.Context) {
Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL).
Should(ContainElement(ContainSubstring("blocky_build_info")))
})
It("Should provide 'blocky_blocking_enabled' prometheus metrics", func() {
Eventually(fetchBlockyMetrics, "30s", "2ms").WithArguments(metricsURL).
It("Should provide 'blocky_blocking_enabled' prometheus metrics", func(ctx context.Context) {
Eventually(fetchBlockyMetrics, "30s", "2ms").WithArguments(ctx, metricsURL).
Should(ContainElement("blocky_blocking_enabled 1"))
})
})
When("Some query results are cached", func() {
BeforeEach(func() {
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
BeforeEach(func(ctx context.Context) {
Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 0"),
@ -84,18 +86,18 @@ var _ = Describe("Metrics functional tests", func() {
))
})
It("Should increment cache counts", func() {
It("Should increment cache counts", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.de.", A)
By("first query, should increment the cache miss count and the total count", func() {
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("==", 123)),
))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 1"),
@ -105,14 +107,14 @@ var _ = Describe("Metrics functional tests", func() {
})
By("Same query again, should increment the cache hit count", func() {
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
HaveTTL(BeNumerically("<=", 123)),
))
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_cache_entry_count 1"),
@ -124,8 +126,8 @@ var _ = Describe("Metrics functional tests", func() {
})
When("Lists are loaded", func() {
It("Should expose list cache sizes per group as metrics", func() {
Eventually(fetchBlockyMetrics).WithArguments(metricsURL).
It("Should expose list cache sizes per group as metrics", func(ctx context.Context) {
Eventually(fetchBlockyMetrics).WithArguments(ctx, metricsURL).
Should(
SatisfyAll(
ContainElement("blocky_blacklist_cache{group=\"group1\"} 1"),
@ -136,10 +138,15 @@ var _ = Describe("Metrics functional tests", func() {
})
})
func fetchBlockyMetrics(url string) ([]string, error) {
func fetchBlockyMetrics(ctx context.Context, url string) ([]string, error) {
var metrics []string
r, err := http.Get(url)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
r, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}

View File

@ -22,20 +22,20 @@ var _ = Describe("Query logs functional tests", func() {
var db *gorm.DB
var err error
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`, `A unknown/NXDOMAIN()`)
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`, `A unknown/NXDOMAIN()`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
})
Describe("Query logging into the mariaDB database", func() {
BeforeEach(func() {
mariaDB, err = createMariaDBContainer()
BeforeEach(func(ctx context.Context) {
mariaDB, err = createMariaDBContainer(ctx)
Expect(err).Should(Succeed())
DeferCleanup(mariaDB.Terminate)
blocky, err = createBlockyContainer(tmpDir,
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -53,7 +53,7 @@ var _ = Describe("Query logs functional tests", func() {
Expect(err).Should(Succeed())
connectionString, err := mariaDB.ConnectionString(context.Background(),
connectionString, err := mariaDB.ConnectionString(ctx,
"tls=false", "charset=utf8mb4", "parseTime=True", "loc=Local")
Expect(err).Should(Succeed())
@ -67,10 +67,12 @@ var _ = Describe("Query logs functional tests", func() {
Eventually(countEntries).WithArguments(db).Should(BeNumerically("==", 0))
})
When("Some queries were performed", func() {
It("Should store query log in the mariaDB database", func() {
It("Should store query log in the mariaDB database", func(ctx context.Context) {
By("Performing 2 queries", func() {
Expect(doDNSRequest(blocky, util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))).Should(Not(BeNil()))
Expect(doDNSRequest(blocky, util.NewMsgWithQuestion("unknown.domain.", dns.Type(dns.TypeA)))).Should(Not(BeNil()))
Expect(doDNSRequest(ctx, blocky,
util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA)))).Should(Not(BeNil()))
Expect(doDNSRequest(ctx, blocky,
util.NewMsgWithQuestion("unknown.domain.", dns.Type(dns.TypeA)))).Should(Not(BeNil()))
})
By("check entries count asynchronously, since blocky flushes log entries in bulk", func() {
@ -108,12 +110,12 @@ var _ = Describe("Query logs functional tests", func() {
})
Describe("Query logging into the postgres database", func() {
BeforeEach(func() {
postgresDB, err = createPostgresContainer()
BeforeEach(func(ctx context.Context) {
postgresDB, err = createPostgresContainer(ctx)
Expect(err).Should(Succeed())
DeferCleanup(postgresDB.Terminate)
blocky, err = createBlockyContainer(tmpDir,
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -129,7 +131,7 @@ var _ = Describe("Query logs functional tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
connectionString, err := postgresDB.ConnectionString(context.Background(), "sslmode=disable")
connectionString, err := postgresDB.ConnectionString(ctx, "sslmode=disable")
Expect(err).Should(Succeed())
// database might be slow on first start, retry here if necessary
@ -143,10 +145,10 @@ var _ = Describe("Query logs functional tests", func() {
})
When("Some queries were performed", func() {
msg := util.NewMsgWithQuestion("google.de.", dns.Type(dns.TypeA))
It("Should store query log in the postgres database", func() {
It("Should store query log in the postgres database", func(ctx context.Context) {
By("Performing 2 queries", func() {
Expect(doDNSRequest(blocky, msg)).Should(Not(BeNil()))
Expect(doDNSRequest(blocky, msg)).Should(Not(BeNil()))
Expect(doDNSRequest(ctx, blocky, msg)).Should(Not(BeNil()))
Expect(doDNSRequest(ctx, blocky, msg)).Should(Not(BeNil()))
})
By("check entries count asynchronously, since blocky flushes log entries in bulk", func() {

View File

@ -19,13 +19,13 @@ var _ = Describe("Redis configuration tests", func() {
var redisClient *redis.Client
var err error
BeforeEach(func() {
redisDB, err = createRedisContainer()
BeforeEach(func(ctx context.Context) {
redisDB, err = createRedisContainer(ctx)
Expect(err).Should(Succeed())
DeferCleanup(redisDB.Terminate)
redisConnectionString, err := redisDB.ConnectionString(context.Background())
redisConnectionString, err := redisDB.ConnectionString(ctx)
Expect(err).Should(Succeed())
redisConnectionString = strings.ReplaceAll(redisConnectionString, "redis://", "")
@ -34,20 +34,20 @@ var _ = Describe("Redis configuration tests", func() {
Addr: redisConnectionString,
})
Expect(dbSize(redisClient)).Should(BeNumerically("==", 0))
Expect(dbSize(ctx, redisClient)).Should(BeNumerically("==", 0))
moka, err = createDNSMokkaContainer("moka1", `A google/NOERROR("A 1.2.3.4 123")`)
moka, err = createDNSMokkaContainer(ctx, "moka1", `A google/NOERROR("A 1.2.3.4 123")`)
Expect(err).Should(Succeed())
DeferCleanup(func() {
_ = moka.Terminate(context.Background())
DeferCleanup(func(ctx context.Context) {
_ = moka.Terminate(ctx)
})
})
Describe("Cache sharing between blocky instances", func() {
When("Redis and 2 blocky instances are configured", func() {
BeforeEach(func() {
blocky1, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky1, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -61,7 +61,7 @@ var _ = Describe("Redis configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky1.Terminate)
blocky2, err = createBlockyContainer(tmpDir,
blocky2, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -75,10 +75,10 @@ var _ = Describe("Redis configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky2.Terminate)
})
It("2nd instance of blocky should use cache from redis", func() {
It("2nd instance of blocky should use cache from redis", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.de.", A)
By("Query first blocky instance, should store cache in redis", func() {
Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky1, msg).
Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky1, msg).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
@ -87,15 +87,15 @@ var _ = Describe("Redis configuration tests", func() {
})
By("Check redis, must contain one cache entry", func() {
Eventually(dbSize, "5s", "2ms").WithArguments(redisClient).Should(BeNumerically("==", 1))
Eventually(dbSize, "5s", "2ms").WithArguments(ctx, redisClient).Should(BeNumerically("==", 1))
})
By("Shutdown the upstream DNS server", func() {
Expect(moka.Terminate(context.Background())).Should(Succeed())
Expect(moka.Terminate(ctx)).Should(Succeed())
})
By("Query second blocky instance, should use cache from redis", func() {
Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky2, msg).
Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky2, msg).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
@ -104,8 +104,8 @@ var _ = Describe("Redis configuration tests", func() {
})
By("No warnings/errors in log", func() {
Expect(getContainerLogs(blocky1)).Should(BeEmpty())
Expect(getContainerLogs(blocky2)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky1)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky2)).Should(BeEmpty())
})
})
})
@ -113,8 +113,8 @@ var _ = Describe("Redis configuration tests", func() {
Describe("Cache loading on startup", func() {
When("Redis and 1 blocky instance are configured", func() {
BeforeEach(func() {
blocky1, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky1, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -128,10 +128,10 @@ var _ = Describe("Redis configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky1.Terminate)
})
It("should load cache from redis after start", func() {
It("should load cache from redis after start", func(ctx context.Context) {
msg := util.NewMsgWithQuestion("google.de.", A)
By("Query first blocky instance, should store cache in redis\"", func() {
Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky1, msg).
Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky1, msg).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
@ -140,11 +140,11 @@ var _ = Describe("Redis configuration tests", func() {
})
By("Check redis, must contain one cache entry", func() {
Eventually(dbSize).WithArguments(redisClient).Should(BeNumerically("==", 1))
Eventually(dbSize).WithArguments(ctx, redisClient).Should(BeNumerically("==", 1))
})
By("start other instance of blocky now -> it should load the cache from redis", func() {
blocky2, err = createBlockyContainer(tmpDir,
blocky2, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -160,11 +160,11 @@ var _ = Describe("Redis configuration tests", func() {
})
By("Shutdown the upstream DNS server", func() {
Expect(moka.Terminate(context.Background())).Should(Succeed())
Expect(moka.Terminate(ctx)).Should(Succeed())
})
By("Query second blocky instance", func() {
Eventually(doDNSRequest, "5s", "2ms").WithArguments(blocky2, msg).
Eventually(doDNSRequest, "5s", "2ms").WithArguments(ctx, blocky2, msg).
Should(
SatisfyAll(
BeDNSRecord("google.de.", A, "1.2.3.4"),
@ -173,14 +173,14 @@ var _ = Describe("Redis configuration tests", func() {
})
By("No warnings/errors in log", func() {
Expect(getContainerLogs(blocky1)).Should(BeEmpty())
Expect(getContainerLogs(blocky2)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky1)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky2)).Should(BeEmpty())
})
})
})
})
})
func dbSize(redisClient *redis.Client) (int64, error) {
return redisClient.DBSize(context.Background()).Result()
func dbSize(ctx context.Context, redisClient *redis.Client) (int64, error) {
return redisClient.DBSize(ctx).Result()
}

View File

@ -1,6 +1,8 @@
package e2e
import (
"context"
. "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
@ -15,8 +17,8 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Describe("'upstreams.startVerify' parameter handling", func() {
When("'upstreams.startVerify' is false and upstream server as IP is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -29,14 +31,14 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should start even if upstream server is not reachable", func() {
It("should start even if upstream server is not reachable", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeTrue())
Expect(getContainerLogs(blocky)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty())
})
})
When("'upstreams.startVerify' is false and upstream server as host name is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"log:",
" level: warn",
"upstreams:",
@ -49,14 +51,14 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should start even if upstream server is not reachable", func() {
It("should start even if upstream server is not reachable", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeTrue())
Expect(getContainerLogs(blocky)).Should(BeEmpty())
Expect(getContainerLogs(ctx, blocky)).Should(BeEmpty())
})
})
When("'upstreams.startVerify' is true and upstream as IP address server is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -67,15 +69,15 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(err).Should(HaveOccurred())
DeferCleanup(blocky.Terminate)
})
It("should not start", func() {
It("should not start", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Expect(getContainerLogs(ctx, blocky)).
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
When("'upstreams.startVerify' is true and upstream server as host name is not reachable", func() {
BeforeEach(func() {
blocky, err = createBlockyContainer(tmpDir,
BeforeEach(func(ctx context.Context) {
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -86,24 +88,24 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(err).Should(HaveOccurred())
DeferCleanup(blocky.Terminate)
})
It("should not start", func() {
It("should not start", func(ctx context.Context) {
Expect(blocky.IsRunning()).Should(BeFalse())
Expect(getContainerLogs(blocky)).
Expect(getContainerLogs(ctx, blocky)).
Should(ContainElement(ContainSubstring("no valid upstream for group default")))
})
})
})
Describe("'upstreams.timeout' parameter handling", func() {
var moka testcontainers.Container
BeforeEach(func() {
moka, err = createDNSMokkaContainer("moka1",
BeforeEach(func(ctx context.Context) {
moka, err = createDNSMokkaContainer(ctx, "moka1",
`A example.com/NOERROR("A 1.2.3.4 123")`,
`A delay.com/delay(NOERROR("A 1.1.1.1 100"), "300ms")`)
Expect(err).Should(Succeed())
DeferCleanup(moka.Terminate)
blocky, err = createBlockyContainer(tmpDir,
blocky, err = createBlockyContainer(ctx, tmpDir,
"upstreams:",
" groups:",
" default:",
@ -114,10 +116,10 @@ var _ = Describe("Upstream resolver configuration tests", func() {
Expect(err).Should(Succeed())
DeferCleanup(blocky.Terminate)
})
It("should consider the timeout parameter", func() {
It("should consider the timeout parameter", func(ctx context.Context) {
By("query without timeout", func() {
msg := util.NewMsgWithQuestion("example.com.", A)
Expect(doDNSRequest(blocky, msg)).
Expect(doDNSRequest(ctx, blocky, msg)).
Should(
SatisfyAll(
BeDNSRecord("example.com.", A, "1.2.3.4"),
@ -128,7 +130,7 @@ var _ = Describe("Upstream resolver configuration tests", func() {
By("query with timeout", func() {
msg := util.NewMsgWithQuestion("delay.com.", A)
resp, err := doDNSRequest(blocky, msg)
resp, err := doDNSRequest(ctx, blocky, msg)
Expect(err).Should(Succeed())
Expect(resp.Rcode).Should(Equal(dns.RcodeServerFailure))
})

View File

@ -1,6 +1,7 @@
package lists
import (
"context"
"errors"
"fmt"
"io"
@ -27,7 +28,7 @@ func (e *TransientError) Unwrap() error {
// FileDownloader is able to download some text file
type FileDownloader interface {
DownloadFile(link string) (io.ReadCloser, error)
DownloadFile(ctx context.Context, link string) (io.ReadCloser, error)
}
// httpDownloader downloads files via HTTP protocol
@ -52,12 +53,17 @@ func newDownloader(cfg config.DownloaderConfig, transport http.RoundTripper) *ht
}
}
func (d *httpDownloader) DownloadFile(link string) (io.ReadCloser, error) {
func (d *httpDownloader) DownloadFile(ctx context.Context, link string) (io.ReadCloser, error) {
var body io.ReadCloser
err := retry.Do(
func() error {
resp, httpErr := d.client.Get(link)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil)
if err != nil {
return err
}
resp, httpErr := d.client.Do(req)
if httpErr == nil {
if resp.StatusCode == http.StatusOK {
body = resp.Body

View File

@ -1,6 +1,7 @@
package lists
import (
"context"
"errors"
"io"
"net"
@ -80,8 +81,8 @@ var _ = Describe("Downloader", func() {
sut = newDownloader(sutConfig, nil)
})
It("Should return all lines from the file", func() {
reader, err := sut.DownloadFile(server.URL)
It("Should return all lines from the file", func(ctx context.Context) {
reader, err := sut.DownloadFile(ctx, server.URL)
Expect(err).Should(Succeed())
Expect(reader).Should(Not(BeNil()))
@ -101,8 +102,8 @@ var _ = Describe("Downloader", func() {
sutConfig.Attempts = 3
})
It("Should return error", func() {
reader, err := sut.DownloadFile(server.URL)
It("Should return error", func(ctx context.Context) {
reader, err := sut.DownloadFile(ctx, server.URL)
Expect(err).Should(HaveOccurred())
Expect(reader).Should(BeNil())
@ -115,8 +116,8 @@ var _ = Describe("Downloader", func() {
BeforeEach(func() {
sutConfig.Attempts = 1
})
It("Should return error", func() {
_, err := sut.DownloadFile("somewrongurl")
It("Should return error", func(ctx context.Context) {
_, err := sut.DownloadFile(ctx, "somewrongurl")
Expect(err).Should(HaveOccurred())
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Can't download file: "))
@ -149,8 +150,8 @@ var _ = Describe("Downloader", func() {
}))
DeferCleanup(server.Close)
})
It("Should perform a retry and return file content", func() {
reader, err := sut.DownloadFile(server.URL)
It("Should perform a retry and return file content", func(ctx context.Context) {
reader, err := sut.DownloadFile(ctx, server.URL)
Expect(err).Should(Succeed())
Expect(reader).Should(Not(BeNil()))
DeferCleanup(reader.Close)
@ -180,17 +181,18 @@ var _ = Describe("Downloader", func() {
}))
DeferCleanup(server.Close)
})
It("Should perform a retry until max retry attempt count is reached and return TransientError", func() {
reader, err := sut.DownloadFile(server.URL)
Expect(err).Should(HaveOccurred())
Expect(errors.As(err, new(*TransientError))).Should(BeTrue())
Expect(err.Error()).Should(ContainSubstring("Timeout"))
Expect(reader).Should(BeNil())
It("Should perform a retry until max retry attempt count is reached and return TransientError",
func(ctx context.Context) {
reader, err := sut.DownloadFile(ctx, server.URL)
Expect(err).Should(HaveOccurred())
Expect(errors.As(err, new(*TransientError))).Should(BeTrue())
Expect(err.Error()).Should(ContainSubstring("Timeout"))
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
})
// failed download event was emitted 3 times
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal(server.URL)))
})
})
When("DNS resolution of passed URL fails", func() {
BeforeEach(func() {
@ -200,19 +202,20 @@ var _ = Describe("Downloader", func() {
Cooldown: 200 * config.Duration(time.Millisecond),
}
})
It("Should perform a retry until max retry attempt count is reached and return DNSError", func() {
reader, err := sut.DownloadFile("http://some.domain.which.does.not.exist")
Expect(err).Should(HaveOccurred())
It("Should perform a retry until max retry attempt count is reached and return DNSError",
func(ctx context.Context) {
reader, err := sut.DownloadFile(ctx, "http://some.domain.which.does.not.exist")
Expect(err).Should(HaveOccurred())
var dnsError *net.DNSError
Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %w", err)
Expect(reader).Should(BeNil())
var dnsError *net.DNSError
Expect(errors.As(err, &dnsError)).Should(BeTrue(), "received error %w", err)
Expect(reader).Should(BeNil())
// failed download event was emitted 3 times
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist")))
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
})
// failed download event was emitted 3 times
Expect(failedDownloadCountEvtChannel).Should(HaveLen(3))
Expect(failedDownloadCountEvtChannel).Should(Receive(Equal("http://some.domain.which.does.not.exist")))
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
})
})
})
})

View File

@ -228,7 +228,7 @@ func (b *ListCache) parseFile(ctx context.Context, opener SourceOpener, resultCh
logger().Debug("starting processing of source")
r, err := opener.Open()
r, err := opener.Open(ctx)
if err != nil {
logger().Error("cannot open source: ", err)

View File

@ -455,7 +455,7 @@ func newMockDownloader(driver func(res chan<- string, err chan<- error)) *MockDo
return &MockDownloader{NewMockCallSequence(driver)}
}
func (m *MockDownloader) DownloadFile(_ string) (io.ReadCloser, error) {
func (m *MockDownloader) DownloadFile(_ context.Context, _ string) (io.ReadCloser, error) {
str, err := m.Call()
if err != nil {
return nil, err

View File

@ -1,6 +1,7 @@
package lists
import (
"context"
"fmt"
"io"
"os"
@ -12,7 +13,7 @@ import (
type SourceOpener interface {
fmt.Stringer
Open() (io.ReadCloser, error)
Open(ctx context.Context) (io.ReadCloser, error)
}
func NewSourceOpener(txtLocInfo string, source config.BytesSource, downloader FileDownloader) (SourceOpener, error) {
@ -35,7 +36,7 @@ type textOpener struct {
locInfo string
}
func (o *textOpener) Open() (io.ReadCloser, error) {
func (o *textOpener) Open(_ context.Context) (io.ReadCloser, error) {
return io.NopCloser(strings.NewReader(o.source.From)), nil
}
@ -48,8 +49,8 @@ type httpOpener struct {
downloader FileDownloader
}
func (o *httpOpener) Open() (io.ReadCloser, error) {
return o.downloader.DownloadFile(o.source.From)
func (o *httpOpener) Open(ctx context.Context) (io.ReadCloser, error) {
return o.downloader.DownloadFile(ctx, o.source.From)
}
func (o *httpOpener) String() string {
@ -60,7 +61,7 @@ type fileOpener struct {
source config.BytesSource
}
func (o *fileOpener) Open() (io.ReadCloser, error) {
func (o *fileOpener) Open(_ context.Context) (io.ReadCloser, error) {
return os.Open(o.source.From)
}

View File

@ -212,7 +212,7 @@ func (r *HostsFileResolver) loadSources(ctx context.Context) error {
func (r *HostsFileResolver) parseFile(
ctx context.Context, opener lists.SourceOpener, hostsChan chan<- *HostsFileEntry,
) error {
reader, err := opener.Open()
reader, err := opener.Open(ctx)
if err != nil {
return err
}