Compare commits

...

19 Commits

Author SHA1 Message Date
ThinkChaos 3cd36e3ad4
Merge 74b8931998 into 4ebe1ef21a 2024-04-26 15:47:39 -04:00
dependabot[bot] 4ebe1ef21a
build(deps): bump github.com/miekg/dns from 1.1.58 to 1.1.59 (#1452)
Bumps [github.com/miekg/dns](https://github.com/miekg/dns) from 1.1.58 to 1.1.59.
- [Changelog](https://github.com/miekg/dns/blob/master/Makefile.release)
- [Commits](https://github.com/miekg/dns/compare/v1.1.58...v1.1.59)

---
updated-dependencies:
- dependency-name: github.com/miekg/dns
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-26 12:49:13 +02:00
dependabot[bot] 7f20d17d2e
build(deps): bump github.com/onsi/gomega from 1.32.0 to 1.33.0 (#1455)
Bumps [github.com/onsi/gomega](https://github.com/onsi/gomega) from 1.32.0 to 1.33.0.
- [Release notes](https://github.com/onsi/gomega/releases)
- [Changelog](https://github.com/onsi/gomega/blob/master/CHANGELOG.md)
- [Commits](https://github.com/onsi/gomega/compare/v1.32.0...v1.33.0)

---
updated-dependencies:
- dependency-name: github.com/onsi/gomega
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-26 12:48:40 +02:00
dependabot[bot] cbbe8d46f0
build(deps): bump github.com/avast/retry-go/v4 from 4.5.1 to 4.6.0 (#1456)
Bumps [github.com/avast/retry-go/v4](https://github.com/avast/retry-go) from 4.5.1 to 4.6.0.
- [Release notes](https://github.com/avast/retry-go/releases)
- [Commits](https://github.com/avast/retry-go/compare/4.5.1...4.6.0)

---
updated-dependencies:
- dependency-name: github.com/avast/retry-go/v4
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-26 12:48:16 +02:00
dependabot[bot] 62b1354fba
build(deps): bump github.com/docker/docker (#1459)
Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.0.1+incompatible to 26.1.0+incompatible.
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v26.0.1...v26.1.0)

---
updated-dependencies:
- dependency-name: github.com/docker/docker
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-26 12:47:46 +02:00
Thomas Anderson e99c98b4c2
feat: log the rule which is the cause of blocking (#1460)
Co-authored-by: ThinkChaos <ThinkChaos@users.noreply.github.com>
2024-04-24 12:58:29 -04:00
ThinkChaos 74b8931998
refactor(service): make `Info` fields private 2024-04-03 17:49:59 -04:00
ThinkChaos 48cad3b786
refactor: add `service.SimpleHTTP` to reduce required boilerplate 2024-04-03 17:48:51 -04:00
ThinkChaos 1515e17f0f
refactor: switch HTTP API to Service pattern 2024-04-03 13:40:04 -04:00
ThinkChaos e1c717be70
refactor: switch metrics to Service pattern 2024-04-03 13:40:04 -04:00
ThinkChaos d7a2952b1d
refactor: switch DoH to Service pattern 2024-04-03 13:40:04 -04:00
ThinkChaos c11f9a1c98
refactor: add `service` package to prepare for split HTTP handling
Package service exposes types to abstract services from the networking.

The idea is that we build a set of services and a set of network
endpoints (Listener). The services are then assigned to endpoints based
on the address(es) they were configured for.

Actual service to endpoint binding is not handled by the abstractions in
this package as it is protocol specific.
The general pattern is to make a "server" that wraps a service, and can
then be started on an endpoint using a `Serve` method,
similar to `http.Server`.

To support exposing multiple compatible services on a single endpoint
(example: DoH + metrics on a single port),
services can implement `Merger`.
2024-04-03 13:40:04 -04:00
ThinkChaos 4b37b404bf
refactor: add `:` prefix to ports during config unmarshaling 2024-04-02 22:06:10 -04:00
ThinkChaos 36d443728d
refactor(server): move middleware setup to `httpServer` 2024-04-02 22:06:09 -04:00
ThinkChaos 35b1c16878
refactor(server): deduplicate HTTP server setup with new `httpServer` 2024-04-02 20:20:53 -04:00
ThinkChaos 17b2a94a64
refactor(server): setup TLS listeners manually to remove `ServeTLS` use 2024-04-02 20:20:12 -04:00
ThinkChaos c6e3de4ae0
refactor(server): deduplicate `tls.Config` setup 2024-04-02 20:20:12 -04:00
ThinkChaos c389a4a0f4
refactor(server): simplify HTTP router setup 2024-04-02 20:19:27 -04:00
ThinkChaos 7d510d009b
fix(server): typo causing HTTPS router to be used for HTTP server 2024-04-02 20:19:27 -04:00
26 changed files with 1041 additions and 323 deletions

View File

@ -53,7 +53,7 @@ type CacheControl interface {
FlushCaches(ctx context.Context)
}
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
func registerOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")

View File

@ -105,7 +105,7 @@ var _ = Describe("API implementation tests", func() {
Describe("RegisterOpenAPIEndpoints", func() {
It("adds routes", func() {
rtr := chi.NewRouter()
RegisterOpenAPIEndpoints(rtr, sut)
registerOpenAPIEndpoints(rtr, sut)
Expect(rtr.Routes()).ShouldNot(BeEmpty())
})

27
api/service.go Normal file
View File

@ -0,0 +1,27 @@
package api
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
)
// Service implements service.HTTPService.
type Service struct {
service.SimpleHTTP
}
func NewService(cfg config.APIService, server StrictServerInterface) *Service {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
s := &Service{
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
}
registerOpenAPIEndpoints(s.Router(), server)
return s
}

View File

@ -50,7 +50,12 @@ func (cache stringMap) contains(searchString string) bool {
})
if idx < searchBucketLen {
return cache[searchLen][idx*searchLen:idx*searchLen+searchLen] == strings.ToLower(normalized)
blockRule := cache[searchLen][idx*searchLen : idx*searchLen+searchLen]
if blockRule == normalized {
log.PrefixedLog("string_map").Debugf("block rule '%s' matched with '%s'", blockRule, searchString)
return true
}
}
return false
@ -132,7 +137,7 @@ func (cache regexCache) elementCount() int {
func (cache regexCache) contains(searchString string) bool {
for _, regex := range cache {
if regex.MatchString(searchString) {
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
log.PrefixedLog("regex_cache").Debugf("regex '%s' matched with '%s'", regex, searchString)
return true
}

View File

@ -10,6 +10,7 @@ import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/server"
"github.com/0xERR0R/blocky/util"
@ -47,6 +48,10 @@ func startServer(_ *cobra.Command, _ []string) error {
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
if cfg.Prometheus.Enable {
metrics.StartCollection()
}
srv, err := server.NewServer(ctx, cfg)
if err != nil {
return fmt.Errorf("can't start server: %w", err)

View File

@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
*l = strings.Split(addresses, ",")
// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}
return nil
}
@ -226,6 +233,7 @@ type Config struct {
Redis Redis `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports Ports `yaml:"ports"`
Services Services `yaml:"-"` // not user exposed yet
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
@ -255,6 +263,19 @@ type Config struct {
} `yaml:",inline"`
}
// Services holds network service related configuration.
//
// The actual config layout is not decided yet.
// See https://github.com/0xERR0R/blocky/issues/1206
//
// The `yaml` struct tags are just for manual testing,
// and require replacing `yaml:"-"` in Config to work.
type Services struct {
API APIService `yaml:"control-api"`
DoH DoHService `yaml:"dns-over-https"`
Metrics MetricsService `yaml:"metrics"`
}
type Ports struct {
DNS ListenConfig `yaml:"dns" default:"53"`
HTTP ListenConfig `yaml:"http"`
@ -594,6 +615,23 @@ func (cfg *Config) validate(logger *logrus.Entry) {
cfg.Upstreams.validate(logger)
}
// CopyPortsToServices sets Services values to match Ports.
//
// This should be replaced with a migration once everything from Ports is supported in Services.
// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once.
func (cfg *Config) CopyPortsToServices() {
httpAddrs := httpAddrs{
HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP},
HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
}
cfg.Services = Services{
API: APIService{Addrs: httpAddrs},
DoH: DoHService{Addrs: httpAddrs},
Metrics: MetricsService{Addrs: httpAddrs},
}
}
// ConvertPort converts string representation into a valid port (0 - 65535)
func ConvertPort(in string) (uint16, error) {
const (

View File

@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
Expect(*l).Should(ContainElements(":55", ":56"))
})
})
})
@ -958,7 +958,7 @@ bootstrapDns:
})
func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))

25
config/doh_service.go Normal file
View File

@ -0,0 +1,25 @@
package config
type (
APIService httpService
DoHService httpService
MetricsService httpService
)
// httpService can be used by any service that uses HTTP(S).
type httpService struct {
Addrs httpAddrs `yaml:"addrs"`
}
type httpAddrs struct {
HTTPAddrs `yaml:",inline"`
HTTPSAddrs `yaml:",inline"`
}
type HTTPAddrs struct {
HTTP ListenConfig `yaml:"http"`
}
type HTTPSAddrs struct {
HTTPS ListenConfig `yaml:"https"`
}

8
go.mod
View File

@ -6,7 +6,7 @@ require (
github.com/abice/go-enum v0.6.0
github.com/alicebob/miniredis/v2 v2.32.1
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
github.com/avast/retry-go/v4 v4.5.1
github.com/avast/retry-go/v4 v4.6.0
github.com/creasty/defaults v1.7.0
github.com/go-chi/chi/v5 v5.0.12
github.com/go-chi/cors v1.2.1
@ -17,10 +17,10 @@ require (
github.com/hashicorp/golang-lru v1.0.2
github.com/mattn/go-colorable v0.1.13
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
github.com/miekg/dns v1.1.58
github.com/miekg/dns v1.1.59
github.com/mroth/weightedrand/v2 v2.1.0
github.com/onsi/ginkgo/v2 v2.17.1
github.com/onsi/gomega v1.32.0
github.com/onsi/gomega v1.33.0
github.com/prometheus/client_golang v1.19.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.0
@ -38,7 +38,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
github.com/deepmap/oapi-codegen v1.16.2
github.com/docker/docker v26.0.1+incompatible
github.com/docker/docker v26.1.0+incompatible
github.com/docker/go-connections v0.5.0
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
github.com/oapi-codegen/runtime v1.1.1

16
go.sum
View File

@ -29,8 +29,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7D
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef h1:2JGTg6JapxP9/R33ZaagQtAM4EkkSYnIAlOG5EI8gkM=
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef/go.mod h1:JS7hed4L1fj0hXcyEejnW57/7LCetXggd+vwrRnYeII=
github.com/avast/retry-go/v4 v4.5.1 h1:AxIx0HGi4VZ3I02jr78j5lZ3M6x1E0Ivxa6b0pUUh7o=
github.com/avast/retry-go/v4 v4.5.1/go.mod h1:/sipNsvNB3RRuT5iNcb6h73nw3IBmXJ/H3XrCQYSOpc=
github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
@ -64,8 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v26.0.1+incompatible h1:t39Hm6lpXuXtgkF0dm1t9a5HkbUfdGy6XbWexmGr+hA=
github.com/docker/docker v26.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v26.1.0+incompatible h1:W1G9MPNbskA6VZWL7b3ZljTh0pXI68FpINx0GKaOdaM=
github.com/docker/docker v26.1.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@ -193,8 +193,8 @@ github.com/mattn/goveralls v0.0.12 h1:PEEeF0k1SsTjOBQ8FOmrOAoCu4ytuMaWCnWe94zxbC
github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
@ -225,8 +225,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8=
github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk=
github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg=
github.com/onsi/gomega v1.33.0 h1:snPCflnZrpMsy94p4lXVEkHo12lmPnc3vY5XBbreexE=
github.com/onsi/gomega v1.33.0/go.mod h1:+925n5YtiFsLzzafLUHzVMBpvvRAzrydIBiSIxjX3wY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
@ -29,20 +30,27 @@ const (
DS = dns.Type(dns.TypeDS)
)
// GetIntPort returns an port for the current testing
// GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as int
// the base port and returning it as int.
func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess()
}
// GetStringPort returns an port for the current testing
// GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string
// the base port and returning it as string.
func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port))
}
// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}
// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")

View File

@ -1,12 +1,8 @@
package metrics
import (
"github.com/0xERR0R/blocky/config"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
//nolint:gochecknoglobals
@ -17,12 +13,9 @@ func RegisterMetric(c prometheus.Collector) {
_ = reg.Register(c)
}
// Start starts prometheus endpoint
func Start(router *chi.Mux, cfg config.Metrics) {
if cfg.Enable {
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
_ = reg.Register(collectors.NewGoCollector())
router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(reg,
promhttp.HandlerFor(reg, promhttp.HandlerOpts{})))
}
func StartCollection() {
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
_ = reg.Register(collectors.NewGoCollector())
registerEventListeners()
}

View File

@ -11,8 +11,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
// RegisterEventListeners registers all metric handlers by the event bus
func RegisterEventListeners() {
// registerEventListeners registers all metric handlers on the event bus
func registerEventListeners() {
registerBlockingEventListeners()
registerCachingEventListeners()
registerApplicationEventListeners()

36
metrics/service.go Normal file
View File

@ -0,0 +1,36 @@
package metrics
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Service implements service.HTTPService.
type Service struct {
service.SimpleHTTP
}
func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
if !metricsCfg.Enable || len(endpoints) == 0 {
// Avoid setting up collectors and listeners
return new(Service)
}
s := &Service{
SimpleHTTP: service.NewSimpleHTTP("Metrics", endpoints),
}
s.Router().Handle(
metricsCfg.Path,
promhttp.InstrumentMetricHandler(reg, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})),
)
return s
}

126
server/doh.go Normal file
View File

@ -0,0 +1,126 @@
package server
import (
"encoding/base64"
"io"
"net/http"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/miekg/dns"
)
type dohService struct {
service.SimpleHTTP
handler dnsHandler
}
func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
s := &dohService{
SimpleHTTP: service.NewSimpleHTTP("DoH", endpoints),
handler: handler,
}
s.Router().Route("/dns-query", func(mux chi.Router) {
// Handlers for / also handle /dns-query without trailing slash
mux.Get("/", s.handleGET)
mux.Get("/{clientID}", s.handleGET)
mux.Post("/", s.handlePOST)
mux.Post("/{clientID}", s.handlePOST)
})
return s
}
func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)
return
}
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
return
}
rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
s.handler(ctx, dnsReq, httpMsgWriter{rw})
}
type httpMsgWriter struct {
rw http.ResponseWriter
}
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}
r.rw.Header().Set("content-type", dnsContentType)
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)
_, err = r.rw.Write(b)
return err
}

119
server/http.go Normal file
View File

@ -0,0 +1,119 @@
package server
import (
"context"
"net"
"net/http"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
)
// httpMiscService implements service.HTTPService.
//
// This supports the existing single HTTP/HTTPS endpoints
// that expose everything. The goal is to split it up
// and remove it.
type httpMiscService struct {
service.SimpleHTTP
}
func newHTTPMiscService(cfg *config.Config) *httpMiscService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
)
s := &httpMiscService{
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
}
configureHTTPRouter(s.Router(), cfg)
return s
}
// httpServer implements subServer for HTTP.
type httpServer struct {
service.HTTPService
inner http.Server
}
func newHTTPServer(svc service.HTTPService) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
return &httpServer{
HTTPService: svc,
inner: http.Server{
Handler: withCommonMiddleware(svc.Router()),
ReadHeaderTimeout: readHeaderTimeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
},
}
}
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()
s.inner.Close()
}()
return s.inner.Serve(l)
}
func withCommonMiddleware(inner http.Handler) *chi.Mux {
// Middleware must be defined before routes, so
// create a new router and mount the inner handler
mux := chi.NewMux()
mux.Use(
secureHeadersMiddleware,
newCORSMiddleware(),
)
mux.Mount("/", inner)
return mux
}
type httpMiddleware = func(http.Handler) http.Handler
func secureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
}
next.ServeHTTP(w, r)
})
}
func newCORSMiddleware() httpMiddleware {
const corsMaxAge = 5 * time.Minute
options := cors.Options{
AllowCredentials: true,
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
AllowedMethods: []string{"GET", "POST"},
AllowedOrigins: []string{"*"},
ExposedHeaders: []string{"Link"},
MaxAge: int(corsMaxAge.Seconds()),
}
return cors.New(options).Handler
}

View File

@ -18,15 +18,20 @@ import (
"net/http"
"runtime"
"runtime/debug"
"slices"
"strings"
"time"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/service"
"golang.org/x/exp/maps"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
@ -44,14 +49,18 @@ const (
// Server controls the endpoints for DNS and HTTP
type Server struct {
dnsServers []*dns.Server
httpListeners []net.Listener
httpsListeners []net.Listener
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
httpsMux *chi.Mux
cert tls.Certificate
dnsServers []*dns.Server
queryResolver resolver.ChainedResolver
cfg *config.Config
services map[service.Listener]service.Service
}
type subServer interface {
fmt.Stringer
service.Service
Serve(context.Context, net.Listener) error
}
func logger() *logrus.Entry {
@ -71,14 +80,6 @@ func tlsCipherSuites() []uint16 {
return tlsCipherSuites
}
func getServerAddress(addr string) string {
if !strings.Contains(addr, ":") {
addr = fmt.Sprintf(":%s", addr)
}
return addr
}
type NewServerFunc func(address string) (*dns.Server, error)
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
@ -99,39 +100,47 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
return
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
var cert tls.Certificate
cert, err := retrieveCertificate(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
}
// #nosec G402 // See TLSVersion.validate
res := &tls.Config{
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{cert},
}
return res, nil
}
// NewServer creates new server instance with passed config
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
cfg.CopyPortsToServices()
var tlsCfg *tls.Config
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
cert, err = retrieveCertificate(cfg)
tlsCfg, err = newTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
return nil, err
}
}
dnsServers, err := createServers(cfg, cert)
dnsServers, err := createServers(cfg, tlsCfg)
if err != nil {
return nil, fmt.Errorf("server creation failed: %w", err)
}
httpRouter := createHTTPRouter(cfg)
httpsRouter := createHTTPSRouter(cfg)
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
listeners, err := createListeners(ctx, cfg, tlsCfg)
if err != nil {
return nil, err
}
if len(httpListeners) != 0 || len(httpsListeners) != 0 {
metrics.Start(httpRouter, cfg.Prometheus)
metrics.Start(httpsRouter, cfg.Prometheus)
}
metrics.RegisterEventListeners()
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
if err != nil {
return nil, err
@ -151,27 +160,21 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
}
server = &Server{
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
httpMux: httpRouter,
httpsMux: httpsRouter,
cert: cert,
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
}
server.printConfiguration()
server.registerDNSHandlers(ctx)
err = server.registerAPIEndpoints(httpRouter)
services, err := server.createServices()
if err != nil {
return nil, err
}
err = server.registerAPIEndpoints(httpsRouter)
server.services, err = service.GroupByListener(services, listeners)
if err != nil {
return nil, err
}
@ -179,14 +182,35 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return server, err
}
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
func (s *Server) createServices() ([]service.Service, error) {
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}
res := []service.Service{
newHTTPMiscService(s.cfg),
newDoHService(s.cfg.Services.DoH, s.handleReq),
api.NewService(s.cfg.Services.API, openAPIImpl),
metrics.NewService(s.cfg.Services.Metrics, s.cfg.Prometheus),
}
// Remove services the user has not enabled
res = slices.DeleteFunc(res, func(svc service.Service) bool {
return len(svc.ExposeOn()) == 0
})
return res, nil
}
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
var dnsServers []*dns.Server
var err *multierror.Error
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
for _, address := range addresses {
server, err := newServer(getServerAddress(address))
server, err := newServer(address)
if err != nil {
return err
}
@ -201,52 +225,69 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
addServers(createUDPServer, cfg.Ports.DNS),
addServers(createTCPServer, cfg.Ports.DNS),
addServers(func(address string) (*dns.Server, error) {
return createTLSServer(cfg, address, cert)
return createTLSServer(address, tlsCfg)
}, cfg.Ports.TLS))
return dnsServers, err.ErrorOrNil()
}
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
res := make(map[string]service.Listener)
listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
return service.ListenTLS(ctx, endpoint, tlsCfg)
}
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
err := errors.Join(
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
newListeners(ctx, service.HTTPProtocol, cfg.Services.Metrics.Addrs.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Services.Metrics.Addrs.HTTPS, listenTLS, res),
)
if err != nil {
return nil, nil, err
return nil, err
}
return httpListeners, httpsListeners, nil
return maps.Values(res), nil
}
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)
for _, address := range addresses {
listener, err := net.Listen("tcp", getServerAddress(address))
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
func newListeners[T service.Listener](
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
) error {
for _, addr := range addrs {
key := fmt.Sprintf("%s:%s", proto, addr)
if _, ok := out[key]; ok {
// Avoid "address already in use"
// We instead try to merge services, see services.GroupByListener
continue
}
listeners = append(listeners, listener)
endpoint := service.Endpoint{
Protocol: proto,
AddrConf: addr,
}
l, err := listen(ctx, endpoint)
if err != nil {
return err // already has all info
}
out[key] = l
}
return listeners, nil
return nil
}
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{
Addr: address,
Net: "tcp-tls",
//nolint:gosec
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
},
Handler: dns.NewServeMux(),
Addr: address,
Net: "tcp-tls",
TLSConfig: tlsCfg,
Handler: dns.NewServeMux(),
NotifyStartedFunc: func() {
logger().Infof("TLS server is up and running on address %s", address)
},
@ -470,11 +511,15 @@ func toMB(b uint64) uint64 {
return b / bytesInKB / bytesInKB
}
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
func newSubServer(svc service.Service) (subServer, error) {
switch svc := svc.(type) {
case service.HTTPService:
return newHTTPServer(svc), nil
default:
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
}
}
// Start starts the server
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
@ -490,48 +535,22 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}()
}
for i, listener := range s.httpListeners {
listener := listener
address := s.cfg.Ports.HTTP[i]
for listener, svc := range s.services {
listener, svc := listener, svc
srv, err := newSubServer(svc)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)
return
}
go func() {
logger().Infof("http server is up and running on addr/port %s", address)
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())
srv := &http.Server{
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
Handler: s.httpsMux,
}
if err := srv.Serve(listener); err != nil {
errCh <- fmt.Errorf("start http listener failed: %w", err)
}
}()
}
for i, listener := range s.httpsListeners {
listener := listener
address := s.cfg.Ports.HTTPS[i]
go func() {
logger().Infof("https server is up and running on addr/port %s", address)
server := http.Server{
Handler: s.httpsMux,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
//nolint:gosec
TLSConfig: &tls.Config{
MinVersion: uint16(s.cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{s.cert},
},
}
if err := server.ServeTLS(listener, "", ""); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err)
err := srv.Serve(ctx, listener)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
}
}()
}
@ -630,6 +649,8 @@ type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}
type dnsHandler func(context.Context, *model.Request, msgWriter)
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {

View File

@ -2,13 +2,10 @@ package server
import (
"context"
"encoding/base64"
"fmt"
"html/template"
"io"
"net"
"net/http"
"time"
"github.com/0xERR0R/blocky/resolver"
@ -22,7 +19,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/miekg/dns"
)
@ -32,19 +28,8 @@ const (
dnsContentType = "application/dns-message"
htmlContentType = "text/html; charset=UTF-8"
yamlContentType = "text/yaml"
corsMaxAge = 5 * time.Minute
)
func secureHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
next.ServeHTTP(w, r)
})
}
func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) {
bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver)
if err != nil {
@ -64,108 +49,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
const pathDohQuery = "/dns-query"
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return err
}
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(pathDohQuery, s.dohPostRequestHandler)
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
return nil
}
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)
return
}
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
return
}
rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
}
type httpMsgWriter struct {
rw http.ResponseWriter
}
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}
r.rw.Header().Set("content-type", dnsContentType)
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)
_, err = r.rw.Write(b)
return err
}
func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
@ -177,27 +60,7 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()
configureSecureHeaderHandler(router)
registerHandlers(cfg, router)
return router
}
func createHTTPRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()
registerHandlers(cfg, router)
return router
}
func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureCorsHandler(router)
func configureHTTPRouter(router chi.Router, cfg *config.Config) {
configureDebugHandler(router)
configureDocsHandler(router)
@ -207,7 +70,7 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureRootHandler(cfg, router)
}
func configureDocsHandler(router *chi.Mux) {
func configureDocsHandler(router chi.Router) {
router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set(contentTypeHeader, yamlContentType)
_, err := writer.Write([]byte(docs.OpenAPI))
@ -215,7 +78,7 @@ func configureDocsHandler(router *chi.Mux) {
})
}
func configureStaticAssetsHandler(router *chi.Mux) {
func configureStaticAssetsHandler(router chi.Router) {
assets, err := web.Assets()
util.FatalOnError("unable to load static asset files", err)
@ -223,7 +86,7 @@ func configureStaticAssetsHandler(router *chi.Mux) {
router.Handle("/static/*", http.StripPrefix("/static/", fs))
}
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
func configureRootHandler(cfg *config.Config, router chi.Router) {
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set(contentTypeHeader, htmlContentType)
@ -282,22 +145,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit
}
}
func configureSecureHeaderHandler(router *chi.Mux) {
router.Use(secureHeader)
}
func configureDebugHandler(router *chi.Mux) {
func configureDebugHandler(router chi.Router) {
router.Mount("/debug", middleware.Profiler())
}
func configureCorsHandler(router *chi.Mux) {
crs := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: int(corsMaxAge.Seconds()),
})
router.Use(crs.Handler)
}

View File

@ -44,7 +44,7 @@ var (
)
var _ = BeforeSuite(func() {
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
queryURL = baseURL + "dns-query"
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background())
@ -147,10 +147,10 @@ var _ = BeforeSuite(func() {
},
Ports: config.Ports{
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
},
CertFile: certPem.Path,
KeyFile: keyPem.Path,
@ -634,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -678,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -741,17 +741,18 @@ var _ = Describe("Running DNS server", func() {
cfg.KeyFile = ""
cfg.CertFile = ""
cfg.Ports = config.Ports{
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
HTTPS: []string{":0"},
}
sut, err := NewServer(ctx, &cfg)
sut, err := newTLSConfig(&cfg)
Expect(err).Should(Succeed())
Expect(sut.cert.Certificate).ShouldNot(BeNil())
Expect(sut.Certificates).ShouldNot(BeEmpty())
})
})
})
func requestServer(request *dns.Msg) *dns.Msg {
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
if err != nil {
Log().Fatal("could not connect to server: ", err)
}

51
service/endpoint.go Normal file
View File

@ -0,0 +1,51 @@
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
"golang.org/x/exp/maps"
)
// Endpoint is a network endpoint on which to expose a service.
type Endpoint struct {
// Protocol is the protocol to be exposed on this endpoint.
Protocol string
// AddrConf is the network address as configured by the user.
AddrConf string
}
func EndpointsFromAddrs(proto string, addrs []string) []Endpoint {
return util.ForEach(addrs, func(addr string) Endpoint {
return Endpoint{
Protocol: proto,
AddrConf: addr,
}
})
}
func (e Endpoint) String() string {
addr := e.AddrConf
if strings.HasPrefix(addr, ":") {
addr = "*" + addr
}
return fmt.Sprintf("%s://%s", e.Protocol, addr)
}
type endpointSet map[Endpoint]struct{}
func (s endpointSet) ToSlice() []Endpoint {
return maps.Keys(s)
}
func (s endpointSet) IntersectSlice(others []Endpoint) {
for endpoint := range s {
if !slices.Contains(others, endpoint) {
delete(s, endpoint)
}
}
}

123
service/http.go Normal file
View File

@ -0,0 +1,123 @@
package service
import (
"errors"
"net/http"
"strings"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
)
const (
HTTPProtocol = "http"
HTTPSProtocol = "https"
)
// HTTPService is a Service using a HTTP router.
type HTTPService interface {
Service
Merger
// Router returns the service's router.
Router() chi.Router
}
// HTTPInfo can be embedded in structs to help implement HTTPService.
type HTTPInfo struct {
Info
mux *chi.Mux
}
func NewHTTPInfo(name string, endpoints []Endpoint) HTTPInfo {
return HTTPInfo{
Info: NewInfo(name, endpoints),
mux: chi.NewMux(),
}
}
func (i *HTTPInfo) Router() chi.Router { return i.mux }
var _ HTTPService = (*SimpleHTTP)(nil)
// SimpleHTTP implements HTTPService usinig the default HTTP merger.
type SimpleHTTP struct{ HTTPInfo }
func NewSimpleHTTP(name string, endpoints []Endpoint) SimpleHTTP {
return SimpleHTTP{HTTPInfo: NewHTTPInfo(name, endpoints)}
}
func (s *SimpleHTTP) Merge(other Service) (Merger, error) {
return MergeHTTP(s, other)
}
// MergeHTTP merges two compatible HTTPServices.
//
// The second parameter is of type `Service` to make it easy to call
// from a `Merger.Merge` implementation.
func MergeHTTP(a HTTPService, b Service) (Merger, error) {
return newHTTPMerger(a).Merge(b)
}
var _ HTTPService = (*httpMerger)(nil)
// httpMerger can merge HTTPServices by combining their routes.
type httpMerger struct {
inner []HTTPService
router chi.Router
endpoints endpointSet
}
func newHTTPMerger(first HTTPService) *httpMerger {
return &httpMerger{
inner: []HTTPService{first},
router: first.Router(),
}
}
func (m *httpMerger) String() string { return svcString(m) }
func (m *httpMerger) ServiceName() string {
names := util.ForEach(m.inner, func(svc HTTPService) string {
return svc.ServiceName()
})
return strings.Join(names, " & ")
}
func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() }
func (m *httpMerger) Router() chi.Router { return m.router }
func (m *httpMerger) Merge(other Service) (Merger, error) {
httpSvc, ok := other.(HTTPService)
if !ok {
return nil, errors.New("not an HTTPService")
}
type middleware = func(http.Handler) http.Handler
// Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
m.router.With(middlewares...).Method(method, route, handler)
// Expose /example/ as /example too
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
// The main point of this is for DoH (/dns-query).
if strings.HasSuffix(route, "/") {
route := strings.TrimSuffix(route, "/")
m.router.With(middlewares...).Method(method, route, handler)
}
return nil
})
m.inner = append(m.inner, httpSvc)
// Don't expose any service more than it expects
m.endpoints.IntersectSlice(other.ExposeOn())
return m, nil
}

74
service/listener.go Normal file
View File

@ -0,0 +1,74 @@
package service
import (
"context"
"crypto/tls"
"fmt"
"net"
)
// Listener is a net.Listener that provides information about
// what protocol and address it is configured for.
type Listener interface {
fmt.Stringer
net.Listener
// Exposes returns the endpoint for this listener.
//
// It can be used to find service(s) with matching configuration.
Exposes() Endpoint
}
// ListenerInfo can be embedded in structs to help implement Listener.
type ListenerInfo struct {
Endpoint
}
func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint }
// NetListener implements Listener using an existing net.Listener.
type NetListener struct {
net.Listener
ListenerInfo
}
func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener {
return &NetListener{
Listener: inner,
ListenerInfo: ListenerInfo{endpoint},
}
}
// TCPListener is a Listener for a TCP socket.
type TCPListener struct{ NetListener }
// ListenTCP creates a new TCPListener.
func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf)
if err != nil {
return nil, err // err already has all the info we could add
}
inner := NewNetListener(endpoint, l)
return &TCPListener{*inner}, nil
}
// TLSListener is a Listener using TLS over TCP.
type TLSListener struct{ NetListener }
// ListenTLS creates a new TLSListener.
func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) {
tcp, err := ListenTCP(ctx, endpoint)
if err != nil {
return nil, err
}
inner := tcp.NetListener
inner.Listener = tls.NewListener(inner.Listener, cfg)
return &TLSListener{inner}, nil
}

57
service/merge.go Normal file
View File

@ -0,0 +1,57 @@
package service
import "errors"
// Merger is a Service that can be merged with another compatible one.
type Merger interface {
Service
// Merge returns the result of merging the receiver with the other Service.
//
// Neither the receiver, nor the other Service should be used directly after
// calling this method.
Merge(other Service) (Merger, error)
}
// MergeAll merges the given services, if they are compatible.
//
// This allows using multiple compatible services with a single listener.
//
// All passed-in services must not be re-used.
func MergeAll(services []Service) (Service, error) {
switch len(services) {
case 0:
return nil, errors.New("no services given")
case 1:
return services[0], nil
}
merger, err := firstMerger(services)
if err != nil {
return nil, err
}
for _, svc := range services {
if svc == merger {
continue
}
merger, err = merger.Merge(svc)
if err != nil {
return nil, err
}
}
return merger, nil
}
func firstMerger(services []Service) (Merger, error) {
for _, t := range services {
if svc, ok := t.(Merger); ok {
return svc, nil
}
}
return nil, errors.New("no merger found")
}

113
service/service.go Normal file
View File

@ -0,0 +1,113 @@
// Package service exposes types to abstract services from the networking.
//
// The idea is that we build a set of services and a set of network endpoints (Listener).
// The services are then assigned to endpoints based on the address(es) they were configured for.
//
// Actual service to endpoint binding is not handled by the abstractions in this package as it is
// protocol specific.
// The general pattern is to make a "server" that wraps a service, and can then be started on an
// endpoint using a `Serve` method, similar to `http.Server`.
//
// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port),
// services can implement `Merger`.
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
)
// Service is a network exposed service.
//
// It contains only the logic and user configured addresses it should be exposed on.
// Is is meant to be associated to one or more sockets via those addresses.
// Actual association with a socket is protocol specific.
type Service interface {
fmt.Stringer
// ServiceName returns the user friendly name of the service.
ServiceName() string
// ExposeOn returns the set of endpoints the service should be exposed on.
//
// They can be used to find listener(s) with matching configuration.
ExposeOn() []Endpoint
}
func svcString(s Service) string {
endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() })
return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", "))
}
// Info can be embedded in structs to help implement Service.
type Info struct {
name string
endpoints []Endpoint
}
func NewInfo(name string, endpoints []Endpoint) Info {
return Info{
name: name,
endpoints: endpoints,
}
}
func (i *Info) ServiceName() string { return i.name }
func (i *Info) ExposeOn() []Endpoint { return i.endpoints }
func (i *Info) String() string { return svcString(i) }
// GroupByListener returns a map of listener and services grouped by configured address.
//
// Each input listener is a key in the map. The corresponding value is a service
// merged from all services with a matching address.
func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) {
res := make(map[Listener]Service, len(listeners))
unused := slices.Clone(services)
for _, listener := range listeners {
services := findAllCompatible(services, listener.Exposes())
if len(services) == 0 {
return nil, fmt.Errorf("found no compatible services for listener %s", listener)
}
svc, err := MergeAll(services)
if err != nil {
return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err)
}
res[listener] = svc
for _, svc := range services {
if i := slices.Index(unused, svc); i != -1 {
unused = slices.Delete(unused, i, i+1)
}
}
}
if len(unused) != 0 {
return nil, fmt.Errorf("found no compatible listener for services: %v", unused)
}
return res, nil
}
// findAllCompatible returns the subset of services that use the given Listener.
func findAllCompatible(services []Service, endpoint Endpoint) []Service {
res := make([]Service, 0, len(services))
for _, svc := range services {
if isExposedOn(svc, endpoint) {
res = append(res, svc)
}
}
return res
}
func isExposedOn(svc Service, endpoint Endpoint) bool {
return slices.Index(svc.ExposeOn(), endpoint) != -1
}

View File

@ -1,5 +1,10 @@
package trie
import (
"github.com/0xERR0R/blocky/log"
"strings"
)
// Trie stores a set of strings and can quickly check
// if it contains an element, or one of its parents.
//
@ -108,8 +113,12 @@ func (n *parent) insert(key string, split SplitFunc) {
}
func (n *parent) hasParentOf(key string, split SplitFunc) bool {
searchString := key
rule := ""
for {
label, rest := split(key)
rule = strings.Join([]string{label, rule}, ".")
child, ok := n.children[label]
if !ok {
@ -132,7 +141,14 @@ func (n *parent) hasParentOf(key string, split SplitFunc) bool {
case terminal:
// Continue down the trie
return child.hasParentOf(rest, split)
matched := child.hasParentOf(rest, split)
if matched {
rule = strings.Join([]string{child.String(), rule}, ".")
rule = strings.Trim(rule, ".")
log.PrefixedLog("trie").Debugf("wildcard block rule '%s' matched with '%s'", rule, searchString)
}
return matched
}
}
}

33
util/slices.go Normal file
View File

@ -0,0 +1,33 @@
package util
// ForEach implements the functional map operation, under a different
// name to avoid confusion with Go's map type.
func ForEach[T, U any](slice []T, convert func(T) U) []U {
res := make([]U, 0, len(slice))
for _, t := range slice {
u := convert(t)
res = append(res, u)
}
return res
}
// ConcatSlices returns a new slice with contents of all the inputs concatenated.
func ConcatSlices[T any](slices ...[]T) []T {
// Allocation is usually the bottleneck, so do it all at once
totalLen := 0
for _, slice := range slices {
totalLen += len(slice)
}
res := make([]T, 0, totalLen)
for _, slice := range slices {
res = append(res, slice...)
}
return res
}