mirror of https://github.com/0xERR0R/blocky.git
Compare commits
19 Commits
c206364713
...
3cd36e3ad4
Author | SHA1 | Date |
---|---|---|
ThinkChaos | 3cd36e3ad4 | |
dependabot[bot] | 4ebe1ef21a | |
dependabot[bot] | 7f20d17d2e | |
dependabot[bot] | cbbe8d46f0 | |
dependabot[bot] | 62b1354fba | |
Thomas Anderson | e99c98b4c2 | |
ThinkChaos | 74b8931998 | |
ThinkChaos | 48cad3b786 | |
ThinkChaos | 1515e17f0f | |
ThinkChaos | e1c717be70 | |
ThinkChaos | d7a2952b1d | |
ThinkChaos | c11f9a1c98 | |
ThinkChaos | 4b37b404bf | |
ThinkChaos | 36d443728d | |
ThinkChaos | 35b1c16878 | |
ThinkChaos | 17b2a94a64 | |
ThinkChaos | c6e3de4ae0 | |
ThinkChaos | c389a4a0f4 | |
ThinkChaos | 7d510d009b |
|
@ -53,7 +53,7 @@ type CacheControl interface {
|
|||
FlushCaches(ctx context.Context)
|
||||
}
|
||||
|
||||
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
||||
func registerOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
||||
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}
|
||||
|
||||
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")
|
||||
|
|
|
@ -105,7 +105,7 @@ var _ = Describe("API implementation tests", func() {
|
|||
Describe("RegisterOpenAPIEndpoints", func() {
|
||||
It("adds routes", func() {
|
||||
rtr := chi.NewRouter()
|
||||
RegisterOpenAPIEndpoints(rtr, sut)
|
||||
registerOpenAPIEndpoints(rtr, sut)
|
||||
|
||||
Expect(rtr.Routes()).ShouldNot(BeEmpty())
|
||||
})
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
)
|
||||
|
||||
// Service implements service.HTTPService.
|
||||
type Service struct {
|
||||
service.SimpleHTTP
|
||||
}
|
||||
|
||||
func NewService(cfg config.APIService, server StrictServerInterface) *Service {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||
)
|
||||
|
||||
s := &Service{
|
||||
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
|
||||
}
|
||||
|
||||
registerOpenAPIEndpoints(s.Router(), server)
|
||||
|
||||
return s
|
||||
}
|
|
@ -50,7 +50,12 @@ func (cache stringMap) contains(searchString string) bool {
|
|||
})
|
||||
|
||||
if idx < searchBucketLen {
|
||||
return cache[searchLen][idx*searchLen:idx*searchLen+searchLen] == strings.ToLower(normalized)
|
||||
blockRule := cache[searchLen][idx*searchLen : idx*searchLen+searchLen]
|
||||
if blockRule == normalized {
|
||||
log.PrefixedLog("string_map").Debugf("block rule '%s' matched with '%s'", blockRule, searchString)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
@ -132,7 +137,7 @@ func (cache regexCache) elementCount() int {
|
|||
func (cache regexCache) contains(searchString string) bool {
|
||||
for _, regex := range cache {
|
||||
if regex.MatchString(searchString) {
|
||||
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
||||
log.PrefixedLog("regex_cache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/evt"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/metrics"
|
||||
"github.com/0xERR0R/blocky/server"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
|
||||
|
@ -47,6 +48,10 @@ func startServer(_ *cobra.Command, _ []string) error {
|
|||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
defer cancelFn()
|
||||
|
||||
if cfg.Prometheus.Enable {
|
||||
metrics.StartCollection()
|
||||
}
|
||||
|
||||
srv, err := server.NewServer(ctx, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't start server: %w", err)
|
||||
|
|
|
@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
|
|||
|
||||
*l = strings.Split(addresses, ",")
|
||||
|
||||
// Prefix all ports with :
|
||||
for i, addr := range *l {
|
||||
if !strings.ContainsRune(addr, ':') {
|
||||
(*l)[i] = ":" + addr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -226,6 +233,7 @@ type Config struct {
|
|||
Redis Redis `yaml:"redis"`
|
||||
Log log.Config `yaml:"log"`
|
||||
Ports Ports `yaml:"ports"`
|
||||
Services Services `yaml:"-"` // not user exposed yet
|
||||
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
||||
CertFile string `yaml:"certFile"`
|
||||
KeyFile string `yaml:"keyFile"`
|
||||
|
@ -255,6 +263,19 @@ type Config struct {
|
|||
} `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Services holds network service related configuration.
|
||||
//
|
||||
// The actual config layout is not decided yet.
|
||||
// See https://github.com/0xERR0R/blocky/issues/1206
|
||||
//
|
||||
// The `yaml` struct tags are just for manual testing,
|
||||
// and require replacing `yaml:"-"` in Config to work.
|
||||
type Services struct {
|
||||
API APIService `yaml:"control-api"`
|
||||
DoH DoHService `yaml:"dns-over-https"`
|
||||
Metrics MetricsService `yaml:"metrics"`
|
||||
}
|
||||
|
||||
type Ports struct {
|
||||
DNS ListenConfig `yaml:"dns" default:"53"`
|
||||
HTTP ListenConfig `yaml:"http"`
|
||||
|
@ -594,6 +615,23 @@ func (cfg *Config) validate(logger *logrus.Entry) {
|
|||
cfg.Upstreams.validate(logger)
|
||||
}
|
||||
|
||||
// CopyPortsToServices sets Services values to match Ports.
|
||||
//
|
||||
// This should be replaced with a migration once everything from Ports is supported in Services.
|
||||
// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once.
|
||||
func (cfg *Config) CopyPortsToServices() {
|
||||
httpAddrs := httpAddrs{
|
||||
HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP},
|
||||
HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
|
||||
}
|
||||
|
||||
cfg.Services = Services{
|
||||
API: APIService{Addrs: httpAddrs},
|
||||
DoH: DoHService{Addrs: httpAddrs},
|
||||
Metrics: MetricsService{Addrs: httpAddrs},
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||
func ConvertPort(in string) (uint16, error) {
|
||||
const (
|
||||
|
|
|
@ -462,7 +462,7 @@ bootstrapDns:
|
|||
err := l.UnmarshalText([]byte("55,:56"))
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(*l).Should(HaveLen(2))
|
||||
Expect(*l).Should(ContainElements("55", ":56"))
|
||||
Expect(*l).Should(ContainElements(":55", ":56"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
@ -958,7 +958,7 @@ bootstrapDns:
|
|||
})
|
||||
|
||||
func defaultTestFileConfig(config *Config) {
|
||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
|
||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
|
||||
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
||||
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
||||
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
package config
|
||||
|
||||
type (
|
||||
APIService httpService
|
||||
DoHService httpService
|
||||
MetricsService httpService
|
||||
)
|
||||
|
||||
// httpService can be used by any service that uses HTTP(S).
|
||||
type httpService struct {
|
||||
Addrs httpAddrs `yaml:"addrs"`
|
||||
}
|
||||
|
||||
type httpAddrs struct {
|
||||
HTTPAddrs `yaml:",inline"`
|
||||
HTTPSAddrs `yaml:",inline"`
|
||||
}
|
||||
|
||||
type HTTPAddrs struct {
|
||||
HTTP ListenConfig `yaml:"http"`
|
||||
}
|
||||
|
||||
type HTTPSAddrs struct {
|
||||
HTTPS ListenConfig `yaml:"https"`
|
||||
}
|
8
go.mod
8
go.mod
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/abice/go-enum v0.6.0
|
||||
github.com/alicebob/miniredis/v2 v2.32.1
|
||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
|
||||
github.com/avast/retry-go/v4 v4.5.1
|
||||
github.com/avast/retry-go/v4 v4.6.0
|
||||
github.com/creasty/defaults v1.7.0
|
||||
github.com/go-chi/chi/v5 v5.0.12
|
||||
github.com/go-chi/cors v1.2.1
|
||||
|
@ -17,10 +17,10 @@ require (
|
|||
github.com/hashicorp/golang-lru v1.0.2
|
||||
github.com/mattn/go-colorable v0.1.13
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mroth/weightedrand/v2 v2.1.0
|
||||
github.com/onsi/ginkgo/v2 v2.17.1
|
||||
github.com/onsi/gomega v1.32.0
|
||||
github.com/onsi/gomega v1.33.0
|
||||
github.com/prometheus/client_golang v1.19.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.8.0
|
||||
|
@ -38,7 +38,7 @@ require (
|
|||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
|
||||
github.com/deepmap/oapi-codegen v1.16.2
|
||||
github.com/docker/docker v26.0.1+incompatible
|
||||
github.com/docker/docker v26.1.0+incompatible
|
||||
github.com/docker/go-connections v0.5.0
|
||||
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
|
||||
github.com/oapi-codegen/runtime v1.1.1
|
||||
|
|
16
go.sum
16
go.sum
|
@ -29,8 +29,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7D
|
|||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef h1:2JGTg6JapxP9/R33ZaagQtAM4EkkSYnIAlOG5EI8gkM=
|
||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef/go.mod h1:JS7hed4L1fj0hXcyEejnW57/7LCetXggd+vwrRnYeII=
|
||||
github.com/avast/retry-go/v4 v4.5.1 h1:AxIx0HGi4VZ3I02jr78j5lZ3M6x1E0Ivxa6b0pUUh7o=
|
||||
github.com/avast/retry-go/v4 v4.5.1/go.mod h1:/sipNsvNB3RRuT5iNcb6h73nw3IBmXJ/H3XrCQYSOpc=
|
||||
github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
|
||||
github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||
|
@ -64,8 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
|
||||
github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v26.0.1+incompatible h1:t39Hm6lpXuXtgkF0dm1t9a5HkbUfdGy6XbWexmGr+hA=
|
||||
github.com/docker/docker v26.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v26.1.0+incompatible h1:W1G9MPNbskA6VZWL7b3ZljTh0pXI68FpINx0GKaOdaM=
|
||||
github.com/docker/docker v26.1.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
|
@ -193,8 +193,8 @@ github.com/mattn/goveralls v0.0.12 h1:PEEeF0k1SsTjOBQ8FOmrOAoCu4ytuMaWCnWe94zxbC
|
|||
github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ=
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
|
||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
|
||||
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
|
||||
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
|
||||
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
|
||||
|
@ -225,8 +225,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8=
|
||||
github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
|
||||
github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk=
|
||||
github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg=
|
||||
github.com/onsi/gomega v1.33.0 h1:snPCflnZrpMsy94p4lXVEkHo12lmPnc3vY5XBbreexE=
|
||||
github.com/onsi/gomega v1.33.0/go.mod h1:+925n5YtiFsLzzafLUHzVMBpvvRAzrydIBiSIxjX3wY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
|
@ -29,20 +30,27 @@ const (
|
|||
DS = dns.Type(dns.TypeDS)
|
||||
)
|
||||
|
||||
// GetIntPort returns an port for the current testing
|
||||
// GetIntPort returns a port for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as int
|
||||
// the base port and returning it as int.
|
||||
func GetIntPort(port int) int {
|
||||
return port + ginkgo.GinkgoParallelProcess()
|
||||
}
|
||||
|
||||
// GetStringPort returns an port for the current testing
|
||||
// GetStringPort returns a port for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as string
|
||||
// the base port and returning it as string.
|
||||
func GetStringPort(port int) string {
|
||||
return fmt.Sprintf("%d", GetIntPort(port))
|
||||
}
|
||||
|
||||
// GetHostPort returns a host:port string for the current testing
|
||||
// process by adding the current ginkgo parallel process to
|
||||
// the base port and returning it as string.
|
||||
func GetHostPort(host string, port int) string {
|
||||
return net.JoinHostPort(host, GetStringPort(port))
|
||||
}
|
||||
|
||||
// TempFile creates temp file with passed data
|
||||
func TempFile(data string) *os.File {
|
||||
f, err := os.CreateTemp("", "prefix")
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
|
@ -17,12 +13,9 @@ func RegisterMetric(c prometheus.Collector) {
|
|||
_ = reg.Register(c)
|
||||
}
|
||||
|
||||
// Start starts prometheus endpoint
|
||||
func Start(router *chi.Mux, cfg config.Metrics) {
|
||||
if cfg.Enable {
|
||||
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||
_ = reg.Register(collectors.NewGoCollector())
|
||||
router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(reg,
|
||||
promhttp.HandlerFor(reg, promhttp.HandlerOpts{})))
|
||||
}
|
||||
func StartCollection() {
|
||||
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||
_ = reg.Register(collectors.NewGoCollector())
|
||||
|
||||
registerEventListeners()
|
||||
}
|
||||
|
|
|
@ -11,8 +11,8 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// RegisterEventListeners registers all metric handlers by the event bus
|
||||
func RegisterEventListeners() {
|
||||
// registerEventListeners registers all metric handlers on the event bus
|
||||
func registerEventListeners() {
|
||||
registerBlockingEventListeners()
|
||||
registerCachingEventListeners()
|
||||
registerApplicationEventListeners()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Service implements service.HTTPService.
|
||||
type Service struct {
|
||||
service.SimpleHTTP
|
||||
}
|
||||
|
||||
func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||
)
|
||||
|
||||
if !metricsCfg.Enable || len(endpoints) == 0 {
|
||||
// Avoid setting up collectors and listeners
|
||||
return new(Service)
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
SimpleHTTP: service.NewSimpleHTTP("Metrics", endpoints),
|
||||
}
|
||||
|
||||
s.Router().Handle(
|
||||
metricsCfg.Path,
|
||||
promhttp.InstrumentMetricHandler(reg, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})),
|
||||
)
|
||||
|
||||
return s
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type dohService struct {
|
||||
service.SimpleHTTP
|
||||
|
||||
handler dnsHandler
|
||||
}
|
||||
|
||||
func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||
)
|
||||
|
||||
s := &dohService{
|
||||
SimpleHTTP: service.NewSimpleHTTP("DoH", endpoints),
|
||||
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
s.Router().Route("/dns-query", func(mux chi.Router) {
|
||||
// Handlers for / also handle /dns-query without trailing slash
|
||||
|
||||
mux.Get("/", s.handleGET)
|
||||
mux.Get("/{clientID}", s.handleGET)
|
||||
|
||||
mux.Post("/", s.handlePOST)
|
||||
mux.Post("/{clientID}", s.handlePOST)
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
|
||||
dnsParam, ok := req.URL.Query()["dns"]
|
||||
if !ok || len(dnsParam[0]) < 1 {
|
||||
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
||||
if err != nil {
|
||||
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-type")
|
||||
if contentType != dnsContentType {
|
||||
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(rawMsg); err != nil {
|
||||
logger().Error("can't deserialize message: ", err)
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
||||
|
||||
s.handler(ctx, dnsReq, httpMsgWriter{rw})
|
||||
}
|
||||
|
||||
type httpMsgWriter struct {
|
||||
rw http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.rw.Header().Set("content-type", dnsContentType)
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
||||
r.rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err = r.rw.Write(b)
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/cors"
|
||||
)
|
||||
|
||||
// httpMiscService implements service.HTTPService.
|
||||
//
|
||||
// This supports the existing single HTTP/HTTPS endpoints
|
||||
// that expose everything. The goal is to split it up
|
||||
// and remove it.
|
||||
type httpMiscService struct {
|
||||
service.SimpleHTTP
|
||||
}
|
||||
|
||||
func newHTTPMiscService(cfg *config.Config) *httpMiscService {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
|
||||
)
|
||||
|
||||
s := &httpMiscService{
|
||||
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
|
||||
}
|
||||
|
||||
configureHTTPRouter(s.Router(), cfg)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// httpServer implements subServer for HTTP.
|
||||
type httpServer struct {
|
||||
service.HTTPService
|
||||
|
||||
inner http.Server
|
||||
}
|
||||
|
||||
func newHTTPServer(svc service.HTTPService) *httpServer {
|
||||
const (
|
||||
readHeaderTimeout = 20 * time.Second
|
||||
readTimeout = 20 * time.Second
|
||||
writeTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
return &httpServer{
|
||||
HTTPService: svc,
|
||||
|
||||
inner: http.Server{
|
||||
Handler: withCommonMiddleware(svc.Router()),
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
s.inner.Close()
|
||||
}()
|
||||
|
||||
return s.inner.Serve(l)
|
||||
}
|
||||
|
||||
func withCommonMiddleware(inner http.Handler) *chi.Mux {
|
||||
// Middleware must be defined before routes, so
|
||||
// create a new router and mount the inner handler
|
||||
mux := chi.NewMux()
|
||||
|
||||
mux.Use(
|
||||
secureHeadersMiddleware,
|
||||
newCORSMiddleware(),
|
||||
)
|
||||
|
||||
mux.Mount("/", inner)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
type httpMiddleware = func(http.Handler) http.Handler
|
||||
|
||||
func secureHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
w.Header().Set("strict-transport-security", "max-age=63072000")
|
||||
w.Header().Set("x-frame-options", "DENY")
|
||||
w.Header().Set("x-content-type-options", "nosniff")
|
||||
w.Header().Set("x-xss-protection", "1; mode=block")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func newCORSMiddleware() httpMiddleware {
|
||||
const corsMaxAge = 5 * time.Minute
|
||||
|
||||
options := cors.Options{
|
||||
AllowCredentials: true,
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedOrigins: []string{"*"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
MaxAge: int(corsMaxAge.Seconds()),
|
||||
}
|
||||
|
||||
return cors.New(options).Handler
|
||||
}
|
255
server/server.go
255
server/server.go
|
@ -18,15 +18,20 @@ import (
|
|||
"net/http"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/api"
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"github.com/0xERR0R/blocky/metrics"
|
||||
"github.com/0xERR0R/blocky/model"
|
||||
"github.com/0xERR0R/blocky/redis"
|
||||
"github.com/0xERR0R/blocky/resolver"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
@ -44,14 +49,18 @@ const (
|
|||
|
||||
// Server controls the endpoints for DNS and HTTP
|
||||
type Server struct {
|
||||
dnsServers []*dns.Server
|
||||
httpListeners []net.Listener
|
||||
httpsListeners []net.Listener
|
||||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
httpMux *chi.Mux
|
||||
httpsMux *chi.Mux
|
||||
cert tls.Certificate
|
||||
dnsServers []*dns.Server
|
||||
queryResolver resolver.ChainedResolver
|
||||
cfg *config.Config
|
||||
|
||||
services map[service.Listener]service.Service
|
||||
}
|
||||
|
||||
type subServer interface {
|
||||
fmt.Stringer
|
||||
service.Service
|
||||
|
||||
Serve(context.Context, net.Listener) error
|
||||
}
|
||||
|
||||
func logger() *logrus.Entry {
|
||||
|
@ -71,14 +80,6 @@ func tlsCipherSuites() []uint16 {
|
|||
return tlsCipherSuites
|
||||
}
|
||||
|
||||
func getServerAddress(addr string) string {
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr = fmt.Sprintf(":%s", addr)
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
type NewServerFunc func(address string) (*dns.Server, error)
|
||||
|
||||
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
||||
|
@ -99,39 +100,47 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
||||
var cert tls.Certificate
|
||||
|
||||
cert, err := retrieveCertificate(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
||||
}
|
||||
|
||||
// #nosec G402 // See TLSVersion.validate
|
||||
res := &tls.Config{
|
||||
MinVersion: uint16(cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
cfg.CopyPortsToServices()
|
||||
|
||||
var tlsCfg *tls.Config
|
||||
|
||||
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
|
||||
cert, err = retrieveCertificate(cfg)
|
||||
tlsCfg, err = newTLSConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dnsServers, err := createServers(cfg, cert)
|
||||
dnsServers, err := createServers(cfg, tlsCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server creation failed: %w", err)
|
||||
}
|
||||
|
||||
httpRouter := createHTTPRouter(cfg)
|
||||
httpsRouter := createHTTPSRouter(cfg)
|
||||
|
||||
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
|
||||
listeners, err := createListeners(ctx, cfg, tlsCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(httpListeners) != 0 || len(httpsListeners) != 0 {
|
||||
metrics.Start(httpRouter, cfg.Prometheus)
|
||||
metrics.Start(httpsRouter, cfg.Prometheus)
|
||||
}
|
||||
|
||||
metrics.RegisterEventListeners()
|
||||
|
||||
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -151,27 +160,21 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
}
|
||||
|
||||
server = &Server{
|
||||
dnsServers: dnsServers,
|
||||
queryResolver: queryResolver,
|
||||
cfg: cfg,
|
||||
httpListeners: httpListeners,
|
||||
httpsListeners: httpsListeners,
|
||||
httpMux: httpRouter,
|
||||
httpsMux: httpsRouter,
|
||||
cert: cert,
|
||||
dnsServers: dnsServers,
|
||||
queryResolver: queryResolver,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
server.printConfiguration()
|
||||
|
||||
server.registerDNSHandlers(ctx)
|
||||
err = server.registerAPIEndpoints(httpRouter)
|
||||
|
||||
services, err := server.createServices()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = server.registerAPIEndpoints(httpsRouter)
|
||||
|
||||
server.services, err = service.GroupByListener(services, listeners)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -179,14 +182,35 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
|||
return server, err
|
||||
}
|
||||
|
||||
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
|
||||
func (s *Server) createServices() ([]service.Service, error) {
|
||||
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := []service.Service{
|
||||
newHTTPMiscService(s.cfg),
|
||||
newDoHService(s.cfg.Services.DoH, s.handleReq),
|
||||
api.NewService(s.cfg.Services.API, openAPIImpl),
|
||||
metrics.NewService(s.cfg.Services.Metrics, s.cfg.Prometheus),
|
||||
}
|
||||
|
||||
// Remove services the user has not enabled
|
||||
res = slices.DeleteFunc(res, func(svc service.Service) bool {
|
||||
return len(svc.ExposeOn()) == 0
|
||||
})
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
|
||||
var dnsServers []*dns.Server
|
||||
|
||||
var err *multierror.Error
|
||||
|
||||
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
||||
for _, address := range addresses {
|
||||
server, err := newServer(getServerAddress(address))
|
||||
server, err := newServer(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -201,52 +225,69 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
|
|||
addServers(createUDPServer, cfg.Ports.DNS),
|
||||
addServers(createTCPServer, cfg.Ports.DNS),
|
||||
addServers(func(address string) (*dns.Server, error) {
|
||||
return createTLSServer(cfg, address, cert)
|
||||
return createTLSServer(address, tlsCfg)
|
||||
}, cfg.Ports.TLS))
|
||||
|
||||
return dnsServers, err.ErrorOrNil()
|
||||
}
|
||||
|
||||
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
|
||||
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
|
||||
res := make(map[string]service.Listener)
|
||||
|
||||
listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
|
||||
return service.ListenTLS(ctx, endpoint, tlsCfg)
|
||||
}
|
||||
|
||||
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
|
||||
err := errors.Join(
|
||||
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
|
||||
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
|
||||
newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
|
||||
newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
|
||||
newListeners(ctx, service.HTTPProtocol, cfg.Services.Metrics.Addrs.HTTP, service.ListenTCP, res),
|
||||
newListeners(ctx, service.HTTPSProtocol, cfg.Services.Metrics.Addrs.HTTPS, listenTLS, res),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return httpListeners, httpsListeners, nil
|
||||
return maps.Values(res), nil
|
||||
}
|
||||
|
||||
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
|
||||
listeners := make([]net.Listener, 0, len(addresses))
|
||||
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)
|
||||
|
||||
for _, address := range addresses {
|
||||
listener, err := net.Listen("tcp", getServerAddress(address))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
|
||||
func newListeners[T service.Listener](
|
||||
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
|
||||
) error {
|
||||
for _, addr := range addrs {
|
||||
key := fmt.Sprintf("%s:%s", proto, addr)
|
||||
if _, ok := out[key]; ok {
|
||||
// Avoid "address already in use"
|
||||
// We instead try to merge services, see services.GroupByListener
|
||||
continue
|
||||
}
|
||||
|
||||
listeners = append(listeners, listener)
|
||||
endpoint := service.Endpoint{
|
||||
Protocol: proto,
|
||||
AddrConf: addr,
|
||||
}
|
||||
|
||||
l, err := listen(ctx, endpoint)
|
||||
if err != nil {
|
||||
return err // already has all info
|
||||
}
|
||||
|
||||
out[key] = l
|
||||
}
|
||||
|
||||
return listeners, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
|
||||
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
|
||||
return &dns.Server{
|
||||
Addr: address,
|
||||
Net: "tcp-tls",
|
||||
//nolint:gosec
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: uint16(cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
},
|
||||
Handler: dns.NewServeMux(),
|
||||
Addr: address,
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: tlsCfg,
|
||||
Handler: dns.NewServeMux(),
|
||||
NotifyStartedFunc: func() {
|
||||
logger().Infof("TLS server is up and running on address %s", address)
|
||||
},
|
||||
|
@ -470,11 +511,15 @@ func toMB(b uint64) uint64 {
|
|||
return b / bytesInKB / bytesInKB
|
||||
}
|
||||
|
||||
const (
|
||||
readHeaderTimeout = 20 * time.Second
|
||||
readTimeout = 20 * time.Second
|
||||
writeTimeout = 20 * time.Second
|
||||
)
|
||||
func newSubServer(svc service.Service) (subServer, error) {
|
||||
switch svc := svc.(type) {
|
||||
case service.HTTPService:
|
||||
return newHTTPServer(svc), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
||||
|
@ -490,48 +535,22 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
|||
}()
|
||||
}
|
||||
|
||||
for i, listener := range s.httpListeners {
|
||||
listener := listener
|
||||
address := s.cfg.Ports.HTTP[i]
|
||||
for listener, svc := range s.services {
|
||||
listener, svc := listener, svc
|
||||
|
||||
srv, err := newSubServer(svc)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger().Infof("http server is up and running on addr/port %s", address)
|
||||
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())
|
||||
|
||||
srv := &http.Server{
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
Handler: s.httpsMux,
|
||||
}
|
||||
|
||||
if err := srv.Serve(listener); err != nil {
|
||||
errCh <- fmt.Errorf("start http listener failed: %w", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i, listener := range s.httpsListeners {
|
||||
listener := listener
|
||||
address := s.cfg.Ports.HTTPS[i]
|
||||
|
||||
go func() {
|
||||
logger().Infof("https server is up and running on addr/port %s", address)
|
||||
|
||||
server := http.Server{
|
||||
Handler: s.httpsMux,
|
||||
ReadTimeout: readTimeout,
|
||||
ReadHeaderTimeout: readHeaderTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
//nolint:gosec
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: uint16(s.cfg.MinTLSServeVer),
|
||||
CipherSuites: tlsCipherSuites(),
|
||||
Certificates: []tls.Certificate{s.cert},
|
||||
},
|
||||
}
|
||||
|
||||
if err := server.ServeTLS(listener, "", ""); err != nil {
|
||||
errCh <- fmt.Errorf("start https listener failed: %w", err)
|
||||
err := srv.Serve(ctx, listener)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -630,6 +649,8 @@ type msgWriter interface {
|
|||
WriteMsg(msg *dns.Msg) error
|
||||
}
|
||||
|
||||
type dnsHandler func(context.Context, *model.Request, msgWriter)
|
||||
|
||||
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
|
||||
response, err := s.resolve(ctx, request)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,13 +2,10 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/0xERR0R/blocky/resolver"
|
||||
|
||||
|
@ -22,7 +19,6 @@ import (
|
|||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -32,19 +28,8 @@ const (
|
|||
dnsContentType = "application/dns-message"
|
||||
htmlContentType = "text/html; charset=UTF-8"
|
||||
yamlContentType = "text/yaml"
|
||||
corsMaxAge = 5 * time.Minute
|
||||
)
|
||||
|
||||
func secureHeader(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("strict-transport-security", "max-age=63072000")
|
||||
w.Header().Set("x-frame-options", "DENY")
|
||||
w.Header().Set("x-content-type-options", "nosniff")
|
||||
w.Header().Set("x-xss-protection", "1; mode=block")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) {
|
||||
bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver)
|
||||
if err != nil {
|
||||
|
@ -64,108 +49,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
|
|||
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
|
||||
}
|
||||
|
||||
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
|
||||
const pathDohQuery = "/dns-query"
|
||||
|
||||
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
|
||||
|
||||
router.Get(pathDohQuery, s.dohGetRequestHandler)
|
||||
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
|
||||
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
|
||||
router.Post(pathDohQuery, s.dohPostRequestHandler)
|
||||
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
|
||||
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
dnsParam, ok := req.URL.Query()["dns"]
|
||||
if !ok || len(dnsParam[0]) < 1 {
|
||||
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
||||
if err != nil {
|
||||
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-type")
|
||||
if contentType != dnsContentType {
|
||||
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(rawMsg); err != nil {
|
||||
logger().Error("can't deserialize message: ", err)
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
||||
|
||||
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
|
||||
}
|
||||
|
||||
type httpMsgWriter struct {
|
||||
rw http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.rw.Header().Set("content-type", dnsContentType)
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
||||
r.rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err = r.rw.Write(b)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) Query(
|
||||
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
||||
) (*model.Response, error) {
|
||||
|
@ -177,27 +60,7 @@ func (s *Server) Query(
|
|||
return s.resolve(ctx, req)
|
||||
}
|
||||
|
||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
||||
router := chi.NewRouter()
|
||||
|
||||
configureSecureHeaderHandler(router)
|
||||
|
||||
registerHandlers(cfg, router)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func createHTTPRouter(cfg *config.Config) *chi.Mux {
|
||||
router := chi.NewRouter()
|
||||
|
||||
registerHandlers(cfg, router)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func registerHandlers(cfg *config.Config, router *chi.Mux) {
|
||||
configureCorsHandler(router)
|
||||
|
||||
func configureHTTPRouter(router chi.Router, cfg *config.Config) {
|
||||
configureDebugHandler(router)
|
||||
|
||||
configureDocsHandler(router)
|
||||
|
@ -207,7 +70,7 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) {
|
|||
configureRootHandler(cfg, router)
|
||||
}
|
||||
|
||||
func configureDocsHandler(router *chi.Mux) {
|
||||
func configureDocsHandler(router chi.Router) {
|
||||
router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set(contentTypeHeader, yamlContentType)
|
||||
_, err := writer.Write([]byte(docs.OpenAPI))
|
||||
|
@ -215,7 +78,7 @@ func configureDocsHandler(router *chi.Mux) {
|
|||
})
|
||||
}
|
||||
|
||||
func configureStaticAssetsHandler(router *chi.Mux) {
|
||||
func configureStaticAssetsHandler(router chi.Router) {
|
||||
assets, err := web.Assets()
|
||||
util.FatalOnError("unable to load static asset files", err)
|
||||
|
||||
|
@ -223,7 +86,7 @@ func configureStaticAssetsHandler(router *chi.Mux) {
|
|||
router.Handle("/static/*", http.StripPrefix("/static/", fs))
|
||||
}
|
||||
|
||||
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
|
||||
func configureRootHandler(cfg *config.Config, router chi.Router) {
|
||||
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.Header().Set(contentTypeHeader, htmlContentType)
|
||||
|
||||
|
@ -282,22 +145,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit
|
|||
}
|
||||
}
|
||||
|
||||
func configureSecureHeaderHandler(router *chi.Mux) {
|
||||
router.Use(secureHeader)
|
||||
}
|
||||
|
||||
func configureDebugHandler(router *chi.Mux) {
|
||||
func configureDebugHandler(router chi.Router) {
|
||||
router.Mount("/debug", middleware.Profiler())
|
||||
}
|
||||
|
||||
func configureCorsHandler(router *chi.Mux) {
|
||||
crs := cors.New(cors.Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: int(corsMaxAge.Seconds()),
|
||||
})
|
||||
router.Use(crs.Handler)
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ var (
|
|||
)
|
||||
|
||||
var _ = BeforeSuite(func() {
|
||||
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
|
||||
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
|
||||
queryURL = baseURL + "dns-query"
|
||||
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
||||
ctx, cancelFn := context.WithCancel(context.Background())
|
||||
|
@ -147,10 +147,10 @@ var _ = BeforeSuite(func() {
|
|||
},
|
||||
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
|
||||
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
|
||||
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
|
||||
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
|
||||
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
|
||||
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
|
||||
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
|
||||
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
|
||||
},
|
||||
CertFile: certPem.Path,
|
||||
KeyFile: keyPem.Path,
|
||||
|
@ -634,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
||||
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -678,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
|
|||
},
|
||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||
Ports: config.Ports{
|
||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
||||
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -741,17 +741,18 @@ var _ = Describe("Running DNS server", func() {
|
|||
cfg.KeyFile = ""
|
||||
cfg.CertFile = ""
|
||||
cfg.Ports = config.Ports{
|
||||
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
|
||||
HTTPS: []string{":0"},
|
||||
}
|
||||
sut, err := NewServer(ctx, &cfg)
|
||||
|
||||
sut, err := newTLSConfig(&cfg)
|
||||
Expect(err).Should(Succeed())
|
||||
Expect(sut.cert.Certificate).ShouldNot(BeNil())
|
||||
Expect(sut.Certificates).ShouldNot(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func requestServer(request *dns.Msg) *dns.Msg {
|
||||
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
|
||||
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
|
||||
if err != nil {
|
||||
Log().Fatal("could not connect to server: ", err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// Endpoint is a network endpoint on which to expose a service.
|
||||
type Endpoint struct {
|
||||
// Protocol is the protocol to be exposed on this endpoint.
|
||||
Protocol string
|
||||
|
||||
// AddrConf is the network address as configured by the user.
|
||||
AddrConf string
|
||||
}
|
||||
|
||||
func EndpointsFromAddrs(proto string, addrs []string) []Endpoint {
|
||||
return util.ForEach(addrs, func(addr string) Endpoint {
|
||||
return Endpoint{
|
||||
Protocol: proto,
|
||||
AddrConf: addr,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (e Endpoint) String() string {
|
||||
addr := e.AddrConf
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
addr = "*" + addr
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", e.Protocol, addr)
|
||||
}
|
||||
|
||||
type endpointSet map[Endpoint]struct{}
|
||||
|
||||
func (s endpointSet) ToSlice() []Endpoint {
|
||||
return maps.Keys(s)
|
||||
}
|
||||
|
||||
func (s endpointSet) IntersectSlice(others []Endpoint) {
|
||||
for endpoint := range s {
|
||||
if !slices.Contains(others, endpoint) {
|
||||
delete(s, endpoint)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
HTTPProtocol = "http"
|
||||
HTTPSProtocol = "https"
|
||||
)
|
||||
|
||||
// HTTPService is a Service using a HTTP router.
|
||||
type HTTPService interface {
|
||||
Service
|
||||
Merger
|
||||
|
||||
// Router returns the service's router.
|
||||
Router() chi.Router
|
||||
}
|
||||
|
||||
// HTTPInfo can be embedded in structs to help implement HTTPService.
|
||||
type HTTPInfo struct {
|
||||
Info
|
||||
|
||||
mux *chi.Mux
|
||||
}
|
||||
|
||||
func NewHTTPInfo(name string, endpoints []Endpoint) HTTPInfo {
|
||||
return HTTPInfo{
|
||||
Info: NewInfo(name, endpoints),
|
||||
|
||||
mux: chi.NewMux(),
|
||||
}
|
||||
}
|
||||
|
||||
func (i *HTTPInfo) Router() chi.Router { return i.mux }
|
||||
|
||||
var _ HTTPService = (*SimpleHTTP)(nil)
|
||||
|
||||
// SimpleHTTP implements HTTPService usinig the default HTTP merger.
|
||||
type SimpleHTTP struct{ HTTPInfo }
|
||||
|
||||
func NewSimpleHTTP(name string, endpoints []Endpoint) SimpleHTTP {
|
||||
return SimpleHTTP{HTTPInfo: NewHTTPInfo(name, endpoints)}
|
||||
}
|
||||
|
||||
func (s *SimpleHTTP) Merge(other Service) (Merger, error) {
|
||||
return MergeHTTP(s, other)
|
||||
}
|
||||
|
||||
// MergeHTTP merges two compatible HTTPServices.
|
||||
//
|
||||
// The second parameter is of type `Service` to make it easy to call
|
||||
// from a `Merger.Merge` implementation.
|
||||
func MergeHTTP(a HTTPService, b Service) (Merger, error) {
|
||||
return newHTTPMerger(a).Merge(b)
|
||||
}
|
||||
|
||||
var _ HTTPService = (*httpMerger)(nil)
|
||||
|
||||
// httpMerger can merge HTTPServices by combining their routes.
|
||||
type httpMerger struct {
|
||||
inner []HTTPService
|
||||
router chi.Router
|
||||
endpoints endpointSet
|
||||
}
|
||||
|
||||
func newHTTPMerger(first HTTPService) *httpMerger {
|
||||
return &httpMerger{
|
||||
inner: []HTTPService{first},
|
||||
router: first.Router(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *httpMerger) String() string { return svcString(m) }
|
||||
|
||||
func (m *httpMerger) ServiceName() string {
|
||||
names := util.ForEach(m.inner, func(svc HTTPService) string {
|
||||
return svc.ServiceName()
|
||||
})
|
||||
|
||||
return strings.Join(names, " & ")
|
||||
}
|
||||
|
||||
func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() }
|
||||
func (m *httpMerger) Router() chi.Router { return m.router }
|
||||
|
||||
func (m *httpMerger) Merge(other Service) (Merger, error) {
|
||||
httpSvc, ok := other.(HTTPService)
|
||||
if !ok {
|
||||
return nil, errors.New("not an HTTPService")
|
||||
}
|
||||
|
||||
type middleware = func(http.Handler) http.Handler
|
||||
|
||||
// Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined
|
||||
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
|
||||
m.router.With(middlewares...).Method(method, route, handler)
|
||||
|
||||
// Expose /example/ as /example too
|
||||
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
|
||||
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
|
||||
// The main point of this is for DoH (/dns-query).
|
||||
if strings.HasSuffix(route, "/") {
|
||||
route := strings.TrimSuffix(route, "/")
|
||||
m.router.With(middlewares...).Method(method, route, handler)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
m.inner = append(m.inner, httpSvc)
|
||||
|
||||
// Don't expose any service more than it expects
|
||||
m.endpoints.IntersectSlice(other.ExposeOn())
|
||||
|
||||
return m, nil
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Listener is a net.Listener that provides information about
|
||||
// what protocol and address it is configured for.
|
||||
type Listener interface {
|
||||
fmt.Stringer
|
||||
net.Listener
|
||||
|
||||
// Exposes returns the endpoint for this listener.
|
||||
//
|
||||
// It can be used to find service(s) with matching configuration.
|
||||
Exposes() Endpoint
|
||||
}
|
||||
|
||||
// ListenerInfo can be embedded in structs to help implement Listener.
|
||||
type ListenerInfo struct {
|
||||
Endpoint
|
||||
}
|
||||
|
||||
func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint }
|
||||
|
||||
// NetListener implements Listener using an existing net.Listener.
|
||||
type NetListener struct {
|
||||
net.Listener
|
||||
ListenerInfo
|
||||
}
|
||||
|
||||
func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener {
|
||||
return &NetListener{
|
||||
Listener: inner,
|
||||
ListenerInfo: ListenerInfo{endpoint},
|
||||
}
|
||||
}
|
||||
|
||||
// TCPListener is a Listener for a TCP socket.
|
||||
type TCPListener struct{ NetListener }
|
||||
|
||||
// ListenTCP creates a new TCPListener.
|
||||
func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) {
|
||||
var lc net.ListenConfig
|
||||
|
||||
l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf)
|
||||
if err != nil {
|
||||
return nil, err // err already has all the info we could add
|
||||
}
|
||||
|
||||
inner := NewNetListener(endpoint, l)
|
||||
|
||||
return &TCPListener{*inner}, nil
|
||||
}
|
||||
|
||||
// TLSListener is a Listener using TLS over TCP.
|
||||
type TLSListener struct{ NetListener }
|
||||
|
||||
// ListenTLS creates a new TLSListener.
|
||||
func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) {
|
||||
tcp, err := ListenTCP(ctx, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inner := tcp.NetListener
|
||||
|
||||
inner.Listener = tls.NewListener(inner.Listener, cfg)
|
||||
|
||||
return &TLSListener{inner}, nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package service
|
||||
|
||||
import "errors"
|
||||
|
||||
// Merger is a Service that can be merged with another compatible one.
|
||||
type Merger interface {
|
||||
Service
|
||||
|
||||
// Merge returns the result of merging the receiver with the other Service.
|
||||
//
|
||||
// Neither the receiver, nor the other Service should be used directly after
|
||||
// calling this method.
|
||||
Merge(other Service) (Merger, error)
|
||||
}
|
||||
|
||||
// MergeAll merges the given services, if they are compatible.
|
||||
//
|
||||
// This allows using multiple compatible services with a single listener.
|
||||
//
|
||||
// All passed-in services must not be re-used.
|
||||
func MergeAll(services []Service) (Service, error) {
|
||||
switch len(services) {
|
||||
case 0:
|
||||
return nil, errors.New("no services given")
|
||||
|
||||
case 1:
|
||||
return services[0], nil
|
||||
}
|
||||
|
||||
merger, err := firstMerger(services)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == merger {
|
||||
continue
|
||||
}
|
||||
|
||||
merger, err = merger.Merge(svc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return merger, nil
|
||||
}
|
||||
|
||||
func firstMerger(services []Service) (Merger, error) {
|
||||
for _, t := range services {
|
||||
if svc, ok := t.(Merger); ok {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("no merger found")
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
// Package service exposes types to abstract services from the networking.
|
||||
//
|
||||
// The idea is that we build a set of services and a set of network endpoints (Listener).
|
||||
// The services are then assigned to endpoints based on the address(es) they were configured for.
|
||||
//
|
||||
// Actual service to endpoint binding is not handled by the abstractions in this package as it is
|
||||
// protocol specific.
|
||||
// The general pattern is to make a "server" that wraps a service, and can then be started on an
|
||||
// endpoint using a `Serve` method, similar to `http.Server`.
|
||||
//
|
||||
// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port),
|
||||
// services can implement `Merger`.
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
)
|
||||
|
||||
// Service is a network exposed service.
|
||||
//
|
||||
// It contains only the logic and user configured addresses it should be exposed on.
|
||||
// Is is meant to be associated to one or more sockets via those addresses.
|
||||
// Actual association with a socket is protocol specific.
|
||||
type Service interface {
|
||||
fmt.Stringer
|
||||
|
||||
// ServiceName returns the user friendly name of the service.
|
||||
ServiceName() string
|
||||
|
||||
// ExposeOn returns the set of endpoints the service should be exposed on.
|
||||
//
|
||||
// They can be used to find listener(s) with matching configuration.
|
||||
ExposeOn() []Endpoint
|
||||
}
|
||||
|
||||
func svcString(s Service) string {
|
||||
endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() })
|
||||
|
||||
return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", "))
|
||||
}
|
||||
|
||||
// Info can be embedded in structs to help implement Service.
|
||||
type Info struct {
|
||||
name string
|
||||
endpoints []Endpoint
|
||||
}
|
||||
|
||||
func NewInfo(name string, endpoints []Endpoint) Info {
|
||||
return Info{
|
||||
name: name,
|
||||
endpoints: endpoints,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Info) ServiceName() string { return i.name }
|
||||
func (i *Info) ExposeOn() []Endpoint { return i.endpoints }
|
||||
func (i *Info) String() string { return svcString(i) }
|
||||
|
||||
// GroupByListener returns a map of listener and services grouped by configured address.
|
||||
//
|
||||
// Each input listener is a key in the map. The corresponding value is a service
|
||||
// merged from all services with a matching address.
|
||||
func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) {
|
||||
res := make(map[Listener]Service, len(listeners))
|
||||
unused := slices.Clone(services)
|
||||
|
||||
for _, listener := range listeners {
|
||||
services := findAllCompatible(services, listener.Exposes())
|
||||
if len(services) == 0 {
|
||||
return nil, fmt.Errorf("found no compatible services for listener %s", listener)
|
||||
}
|
||||
|
||||
svc, err := MergeAll(services)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err)
|
||||
}
|
||||
|
||||
res[listener] = svc
|
||||
|
||||
for _, svc := range services {
|
||||
if i := slices.Index(unused, svc); i != -1 {
|
||||
unused = slices.Delete(unused, i, i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(unused) != 0 {
|
||||
return nil, fmt.Errorf("found no compatible listener for services: %v", unused)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// findAllCompatible returns the subset of services that use the given Listener.
|
||||
func findAllCompatible(services []Service, endpoint Endpoint) []Service {
|
||||
res := make([]Service, 0, len(services))
|
||||
|
||||
for _, svc := range services {
|
||||
if isExposedOn(svc, endpoint) {
|
||||
res = append(res, svc)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func isExposedOn(svc Service, endpoint Endpoint) bool {
|
||||
return slices.Index(svc.ExposeOn(), endpoint) != -1
|
||||
}
|
18
trie/trie.go
18
trie/trie.go
|
@ -1,5 +1,10 @@
|
|||
package trie
|
||||
|
||||
import (
|
||||
"github.com/0xERR0R/blocky/log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Trie stores a set of strings and can quickly check
|
||||
// if it contains an element, or one of its parents.
|
||||
//
|
||||
|
@ -108,8 +113,12 @@ func (n *parent) insert(key string, split SplitFunc) {
|
|||
}
|
||||
|
||||
func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
||||
searchString := key
|
||||
rule := ""
|
||||
|
||||
for {
|
||||
label, rest := split(key)
|
||||
rule = strings.Join([]string{label, rule}, ".")
|
||||
|
||||
child, ok := n.children[label]
|
||||
if !ok {
|
||||
|
@ -132,7 +141,14 @@ func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
|||
|
||||
case terminal:
|
||||
// Continue down the trie
|
||||
return child.hasParentOf(rest, split)
|
||||
matched := child.hasParentOf(rest, split)
|
||||
if matched {
|
||||
rule = strings.Join([]string{child.String(), rule}, ".")
|
||||
rule = strings.Trim(rule, ".")
|
||||
log.PrefixedLog("trie").Debugf("wildcard block rule '%s' matched with '%s'", rule, searchString)
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package util
|
||||
|
||||
// ForEach implements the functional map operation, under a different
|
||||
// name to avoid confusion with Go's map type.
|
||||
func ForEach[T, U any](slice []T, convert func(T) U) []U {
|
||||
res := make([]U, 0, len(slice))
|
||||
|
||||
for _, t := range slice {
|
||||
u := convert(t)
|
||||
|
||||
res = append(res, u)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// ConcatSlices returns a new slice with contents of all the inputs concatenated.
|
||||
func ConcatSlices[T any](slices ...[]T) []T {
|
||||
// Allocation is usually the bottleneck, so do it all at once
|
||||
totalLen := 0
|
||||
|
||||
for _, slice := range slices {
|
||||
totalLen += len(slice)
|
||||
}
|
||||
|
||||
res := make([]T, 0, totalLen)
|
||||
|
||||
for _, slice := range slices {
|
||||
res = append(res, slice...)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
Loading…
Reference in New Issue