mirror of https://github.com/0xERR0R/blocky.git
Compare commits
19 Commits
c206364713
...
3cd36e3ad4
Author | SHA1 | Date |
---|---|---|
ThinkChaos | 3cd36e3ad4 | |
dependabot[bot] | 4ebe1ef21a | |
dependabot[bot] | 7f20d17d2e | |
dependabot[bot] | cbbe8d46f0 | |
dependabot[bot] | 62b1354fba | |
Thomas Anderson | e99c98b4c2 | |
ThinkChaos | 74b8931998 | |
ThinkChaos | 48cad3b786 | |
ThinkChaos | 1515e17f0f | |
ThinkChaos | e1c717be70 | |
ThinkChaos | d7a2952b1d | |
ThinkChaos | c11f9a1c98 | |
ThinkChaos | 4b37b404bf | |
ThinkChaos | 36d443728d | |
ThinkChaos | 35b1c16878 | |
ThinkChaos | 17b2a94a64 | |
ThinkChaos | c6e3de4ae0 | |
ThinkChaos | c389a4a0f4 | |
ThinkChaos | 7d510d009b |
|
@ -53,7 +53,7 @@ type CacheControl interface {
|
||||||
FlushCaches(ctx context.Context)
|
FlushCaches(ctx context.Context)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
func registerOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
|
||||||
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}
|
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}
|
||||||
|
|
||||||
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")
|
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")
|
||||||
|
|
|
@ -105,7 +105,7 @@ var _ = Describe("API implementation tests", func() {
|
||||||
Describe("RegisterOpenAPIEndpoints", func() {
|
Describe("RegisterOpenAPIEndpoints", func() {
|
||||||
It("adds routes", func() {
|
It("adds routes", func() {
|
||||||
rtr := chi.NewRouter()
|
rtr := chi.NewRouter()
|
||||||
RegisterOpenAPIEndpoints(rtr, sut)
|
registerOpenAPIEndpoints(rtr, sut)
|
||||||
|
|
||||||
Expect(rtr.Routes()).ShouldNot(BeEmpty())
|
Expect(rtr.Routes()).ShouldNot(BeEmpty())
|
||||||
})
|
})
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
"github.com/0xERR0R/blocky/service"
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service implements service.HTTPService.
|
||||||
|
type Service struct {
|
||||||
|
service.SimpleHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(cfg config.APIService, server StrictServerInterface) *Service {
|
||||||
|
endpoints := util.ConcatSlices(
|
||||||
|
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||||
|
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||||
|
)
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
|
||||||
|
}
|
||||||
|
|
||||||
|
registerOpenAPIEndpoints(s.Router(), server)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
|
@ -50,7 +50,12 @@ func (cache stringMap) contains(searchString string) bool {
|
||||||
})
|
})
|
||||||
|
|
||||||
if idx < searchBucketLen {
|
if idx < searchBucketLen {
|
||||||
return cache[searchLen][idx*searchLen:idx*searchLen+searchLen] == strings.ToLower(normalized)
|
blockRule := cache[searchLen][idx*searchLen : idx*searchLen+searchLen]
|
||||||
|
if blockRule == normalized {
|
||||||
|
log.PrefixedLog("string_map").Debugf("block rule '%s' matched with '%s'", blockRule, searchString)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -132,7 +137,7 @@ func (cache regexCache) elementCount() int {
|
||||||
func (cache regexCache) contains(searchString string) bool {
|
func (cache regexCache) contains(searchString string) bool {
|
||||||
for _, regex := range cache {
|
for _, regex := range cache {
|
||||||
if regex.MatchString(searchString) {
|
if regex.MatchString(searchString) {
|
||||||
log.PrefixedLog("regexCache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
log.PrefixedLog("regex_cache").Debugf("regex '%s' matched with '%s'", regex, searchString)
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/evt"
|
"github.com/0xERR0R/blocky/evt"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
"github.com/0xERR0R/blocky/metrics"
|
||||||
"github.com/0xERR0R/blocky/server"
|
"github.com/0xERR0R/blocky/server"
|
||||||
"github.com/0xERR0R/blocky/util"
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
|
||||||
|
@ -47,6 +48,10 @@ func startServer(_ *cobra.Command, _ []string) error {
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
defer cancelFn()
|
defer cancelFn()
|
||||||
|
|
||||||
|
if cfg.Prometheus.Enable {
|
||||||
|
metrics.StartCollection()
|
||||||
|
}
|
||||||
|
|
||||||
srv, err := server.NewServer(ctx, cfg)
|
srv, err := server.NewServer(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("can't start server: %w", err)
|
return fmt.Errorf("can't start server: %w", err)
|
||||||
|
|
|
@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
|
||||||
|
|
||||||
*l = strings.Split(addresses, ",")
|
*l = strings.Split(addresses, ",")
|
||||||
|
|
||||||
|
// Prefix all ports with :
|
||||||
|
for i, addr := range *l {
|
||||||
|
if !strings.ContainsRune(addr, ':') {
|
||||||
|
(*l)[i] = ":" + addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,6 +233,7 @@ type Config struct {
|
||||||
Redis Redis `yaml:"redis"`
|
Redis Redis `yaml:"redis"`
|
||||||
Log log.Config `yaml:"log"`
|
Log log.Config `yaml:"log"`
|
||||||
Ports Ports `yaml:"ports"`
|
Ports Ports `yaml:"ports"`
|
||||||
|
Services Services `yaml:"-"` // not user exposed yet
|
||||||
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
||||||
CertFile string `yaml:"certFile"`
|
CertFile string `yaml:"certFile"`
|
||||||
KeyFile string `yaml:"keyFile"`
|
KeyFile string `yaml:"keyFile"`
|
||||||
|
@ -255,6 +263,19 @@ type Config struct {
|
||||||
} `yaml:",inline"`
|
} `yaml:",inline"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Services holds network service related configuration.
|
||||||
|
//
|
||||||
|
// The actual config layout is not decided yet.
|
||||||
|
// See https://github.com/0xERR0R/blocky/issues/1206
|
||||||
|
//
|
||||||
|
// The `yaml` struct tags are just for manual testing,
|
||||||
|
// and require replacing `yaml:"-"` in Config to work.
|
||||||
|
type Services struct {
|
||||||
|
API APIService `yaml:"control-api"`
|
||||||
|
DoH DoHService `yaml:"dns-over-https"`
|
||||||
|
Metrics MetricsService `yaml:"metrics"`
|
||||||
|
}
|
||||||
|
|
||||||
type Ports struct {
|
type Ports struct {
|
||||||
DNS ListenConfig `yaml:"dns" default:"53"`
|
DNS ListenConfig `yaml:"dns" default:"53"`
|
||||||
HTTP ListenConfig `yaml:"http"`
|
HTTP ListenConfig `yaml:"http"`
|
||||||
|
@ -594,6 +615,23 @@ func (cfg *Config) validate(logger *logrus.Entry) {
|
||||||
cfg.Upstreams.validate(logger)
|
cfg.Upstreams.validate(logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyPortsToServices sets Services values to match Ports.
|
||||||
|
//
|
||||||
|
// This should be replaced with a migration once everything from Ports is supported in Services.
|
||||||
|
// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once.
|
||||||
|
func (cfg *Config) CopyPortsToServices() {
|
||||||
|
httpAddrs := httpAddrs{
|
||||||
|
HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP},
|
||||||
|
HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Services = Services{
|
||||||
|
API: APIService{Addrs: httpAddrs},
|
||||||
|
DoH: DoHService{Addrs: httpAddrs},
|
||||||
|
Metrics: MetricsService{Addrs: httpAddrs},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ConvertPort converts string representation into a valid port (0 - 65535)
|
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||||
func ConvertPort(in string) (uint16, error) {
|
func ConvertPort(in string) (uint16, error) {
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -462,7 +462,7 @@ bootstrapDns:
|
||||||
err := l.UnmarshalText([]byte("55,:56"))
|
err := l.UnmarshalText([]byte("55,:56"))
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(*l).Should(HaveLen(2))
|
Expect(*l).Should(HaveLen(2))
|
||||||
Expect(*l).Should(ContainElements("55", ":56"))
|
Expect(*l).Should(ContainElements(":55", ":56"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -958,7 +958,7 @@ bootstrapDns:
|
||||||
})
|
})
|
||||||
|
|
||||||
func defaultTestFileConfig(config *Config) {
|
func defaultTestFileConfig(config *Config) {
|
||||||
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
|
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
|
||||||
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
|
||||||
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
|
||||||
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
type (
|
||||||
|
APIService httpService
|
||||||
|
DoHService httpService
|
||||||
|
MetricsService httpService
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpService can be used by any service that uses HTTP(S).
|
||||||
|
type httpService struct {
|
||||||
|
Addrs httpAddrs `yaml:"addrs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpAddrs struct {
|
||||||
|
HTTPAddrs `yaml:",inline"`
|
||||||
|
HTTPSAddrs `yaml:",inline"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPAddrs struct {
|
||||||
|
HTTP ListenConfig `yaml:"http"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HTTPSAddrs struct {
|
||||||
|
HTTPS ListenConfig `yaml:"https"`
|
||||||
|
}
|
8
go.mod
8
go.mod
|
@ -6,7 +6,7 @@ require (
|
||||||
github.com/abice/go-enum v0.6.0
|
github.com/abice/go-enum v0.6.0
|
||||||
github.com/alicebob/miniredis/v2 v2.32.1
|
github.com/alicebob/miniredis/v2 v2.32.1
|
||||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
|
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
|
||||||
github.com/avast/retry-go/v4 v4.5.1
|
github.com/avast/retry-go/v4 v4.6.0
|
||||||
github.com/creasty/defaults v1.7.0
|
github.com/creasty/defaults v1.7.0
|
||||||
github.com/go-chi/chi/v5 v5.0.12
|
github.com/go-chi/chi/v5 v5.0.12
|
||||||
github.com/go-chi/cors v1.2.1
|
github.com/go-chi/cors v1.2.1
|
||||||
|
@ -17,10 +17,10 @@ require (
|
||||||
github.com/hashicorp/golang-lru v1.0.2
|
github.com/hashicorp/golang-lru v1.0.2
|
||||||
github.com/mattn/go-colorable v0.1.13
|
github.com/mattn/go-colorable v0.1.13
|
||||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
|
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect
|
||||||
github.com/miekg/dns v1.1.58
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mroth/weightedrand/v2 v2.1.0
|
github.com/mroth/weightedrand/v2 v2.1.0
|
||||||
github.com/onsi/ginkgo/v2 v2.17.1
|
github.com/onsi/ginkgo/v2 v2.17.1
|
||||||
github.com/onsi/gomega v1.32.0
|
github.com/onsi/gomega v1.33.0
|
||||||
github.com/prometheus/client_golang v1.19.0
|
github.com/prometheus/client_golang v1.19.0
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/spf13/cobra v1.8.0
|
github.com/spf13/cobra v1.8.0
|
||||||
|
@ -38,7 +38,7 @@ require (
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
|
github.com/ThinkChaos/parcour v0.0.0-20230710171753-fbf917c9eaef
|
||||||
github.com/deepmap/oapi-codegen v1.16.2
|
github.com/deepmap/oapi-codegen v1.16.2
|
||||||
github.com/docker/docker v26.0.1+incompatible
|
github.com/docker/docker v26.1.0+incompatible
|
||||||
github.com/docker/go-connections v0.5.0
|
github.com/docker/go-connections v0.5.0
|
||||||
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
|
github.com/dosgo/zigtool v0.0.0-20210923085854-9c6fc1d62198
|
||||||
github.com/oapi-codegen/runtime v1.1.1
|
github.com/oapi-codegen/runtime v1.1.1
|
||||||
|
|
16
go.sum
16
go.sum
|
@ -29,8 +29,8 @@ github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7D
|
||||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk=
|
||||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef h1:2JGTg6JapxP9/R33ZaagQtAM4EkkSYnIAlOG5EI8gkM=
|
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef h1:2JGTg6JapxP9/R33ZaagQtAM4EkkSYnIAlOG5EI8gkM=
|
||||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef/go.mod h1:JS7hed4L1fj0hXcyEejnW57/7LCetXggd+vwrRnYeII=
|
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef/go.mod h1:JS7hed4L1fj0hXcyEejnW57/7LCetXggd+vwrRnYeII=
|
||||||
github.com/avast/retry-go/v4 v4.5.1 h1:AxIx0HGi4VZ3I02jr78j5lZ3M6x1E0Ivxa6b0pUUh7o=
|
github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
|
||||||
github.com/avast/retry-go/v4 v4.5.1/go.mod h1:/sipNsvNB3RRuT5iNcb6h73nw3IBmXJ/H3XrCQYSOpc=
|
github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
|
||||||
|
@ -64,8 +64,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
|
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
|
||||||
github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
github.com/docker/docker v26.0.1+incompatible h1:t39Hm6lpXuXtgkF0dm1t9a5HkbUfdGy6XbWexmGr+hA=
|
github.com/docker/docker v26.1.0+incompatible h1:W1G9MPNbskA6VZWL7b3ZljTh0pXI68FpINx0GKaOdaM=
|
||||||
github.com/docker/docker v26.0.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
github.com/docker/docker v26.1.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
@ -193,8 +193,8 @@ github.com/mattn/goveralls v0.0.12 h1:PEEeF0k1SsTjOBQ8FOmrOAoCu4ytuMaWCnWe94zxbC
|
||||||
github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ=
|
github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ=
|
||||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
|
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
|
||||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
|
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
|
||||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
|
||||||
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
|
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
|
||||||
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
|
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
|
||||||
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
|
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
|
||||||
|
@ -225,8 +225,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||||
github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8=
|
github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8=
|
||||||
github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
|
github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs=
|
||||||
github.com/onsi/gomega v1.32.0 h1:JRYU78fJ1LPxlckP6Txi/EYqJvjtMrDC04/MM5XRHPk=
|
github.com/onsi/gomega v1.33.0 h1:snPCflnZrpMsy94p4lXVEkHo12lmPnc3vY5XBbreexE=
|
||||||
github.com/onsi/gomega v1.32.0/go.mod h1:a4x4gW6Pz2yK1MAmvluYme5lvYTn61afQ2ETw/8n4Lg=
|
github.com/onsi/gomega v1.33.0/go.mod h1:+925n5YtiFsLzzafLUHzVMBpvvRAzrydIBiSIxjX3wY=
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
@ -29,20 +30,27 @@ const (
|
||||||
DS = dns.Type(dns.TypeDS)
|
DS = dns.Type(dns.TypeDS)
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetIntPort returns an port for the current testing
|
// GetIntPort returns a port for the current testing
|
||||||
// process by adding the current ginkgo parallel process to
|
// process by adding the current ginkgo parallel process to
|
||||||
// the base port and returning it as int
|
// the base port and returning it as int.
|
||||||
func GetIntPort(port int) int {
|
func GetIntPort(port int) int {
|
||||||
return port + ginkgo.GinkgoParallelProcess()
|
return port + ginkgo.GinkgoParallelProcess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStringPort returns an port for the current testing
|
// GetStringPort returns a port for the current testing
|
||||||
// process by adding the current ginkgo parallel process to
|
// process by adding the current ginkgo parallel process to
|
||||||
// the base port and returning it as string
|
// the base port and returning it as string.
|
||||||
func GetStringPort(port int) string {
|
func GetStringPort(port int) string {
|
||||||
return fmt.Sprintf("%d", GetIntPort(port))
|
return fmt.Sprintf("%d", GetIntPort(port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetHostPort returns a host:port string for the current testing
|
||||||
|
// process by adding the current ginkgo parallel process to
|
||||||
|
// the base port and returning it as string.
|
||||||
|
func GetHostPort(host string, port int) string {
|
||||||
|
return net.JoinHostPort(host, GetStringPort(port))
|
||||||
|
}
|
||||||
|
|
||||||
// TempFile creates temp file with passed data
|
// TempFile creates temp file with passed data
|
||||||
func TempFile(data string) *os.File {
|
func TempFile(data string) *os.File {
|
||||||
f, err := os.CreateTemp("", "prefix")
|
f, err := os.CreateTemp("", "prefix")
|
||||||
|
|
|
@ -1,12 +1,8 @@
|
||||||
package metrics
|
package metrics
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/0xERR0R/blocky/config"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:gochecknoglobals
|
//nolint:gochecknoglobals
|
||||||
|
@ -17,12 +13,9 @@ func RegisterMetric(c prometheus.Collector) {
|
||||||
_ = reg.Register(c)
|
_ = reg.Register(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts prometheus endpoint
|
func StartCollection() {
|
||||||
func Start(router *chi.Mux, cfg config.Metrics) {
|
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
||||||
if cfg.Enable {
|
_ = reg.Register(collectors.NewGoCollector())
|
||||||
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
|
||||||
_ = reg.Register(collectors.NewGoCollector())
|
registerEventListeners()
|
||||||
router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(reg,
|
|
||||||
promhttp.HandlerFor(reg, promhttp.HandlerOpts{})))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterEventListeners registers all metric handlers by the event bus
|
// registerEventListeners registers all metric handlers on the event bus
|
||||||
func RegisterEventListeners() {
|
func registerEventListeners() {
|
||||||
registerBlockingEventListeners()
|
registerBlockingEventListeners()
|
||||||
registerCachingEventListeners()
|
registerCachingEventListeners()
|
||||||
registerApplicationEventListeners()
|
registerApplicationEventListeners()
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
"github.com/0xERR0R/blocky/service"
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service implements service.HTTPService.
|
||||||
|
type Service struct {
|
||||||
|
service.SimpleHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service {
|
||||||
|
endpoints := util.ConcatSlices(
|
||||||
|
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||||
|
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !metricsCfg.Enable || len(endpoints) == 0 {
|
||||||
|
// Avoid setting up collectors and listeners
|
||||||
|
return new(Service)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Service{
|
||||||
|
SimpleHTTP: service.NewSimpleHTTP("Metrics", endpoints),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Router().Handle(
|
||||||
|
metricsCfg.Path,
|
||||||
|
promhttp.InstrumentMetricHandler(reg, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})),
|
||||||
|
)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
"github.com/0xERR0R/blocky/service"
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dohService struct {
|
||||||
|
service.SimpleHTTP
|
||||||
|
|
||||||
|
handler dnsHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
|
||||||
|
endpoints := util.ConcatSlices(
|
||||||
|
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||||
|
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||||
|
)
|
||||||
|
|
||||||
|
s := &dohService{
|
||||||
|
SimpleHTTP: service.NewSimpleHTTP("DoH", endpoints),
|
||||||
|
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Router().Route("/dns-query", func(mux chi.Router) {
|
||||||
|
// Handlers for / also handle /dns-query without trailing slash
|
||||||
|
|
||||||
|
mux.Get("/", s.handleGET)
|
||||||
|
mux.Get("/{clientID}", s.handleGET)
|
||||||
|
|
||||||
|
mux.Post("/", s.handlePOST)
|
||||||
|
mux.Post("/{clientID}", s.handlePOST)
|
||||||
|
})
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
dnsParam, ok := req.URL.Query()["dns"]
|
||||||
|
if !ok || len(dnsParam[0]) < 1 {
|
||||||
|
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawMsg) > dohMessageLimit {
|
||||||
|
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.processDohMessage(rawMsg, rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
contentType := req.Header.Get("Content-type")
|
||||||
|
if contentType != dnsContentType {
|
||||||
|
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawMsg, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawMsg) > dohMessageLimit {
|
||||||
|
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.processDohMessage(rawMsg, rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(rawMsg); err != nil {
|
||||||
|
logger().Error("can't deserialize message: ", err)
|
||||||
|
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
||||||
|
|
||||||
|
s.handler(ctx, dnsReq, httpMsgWriter{rw})
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpMsgWriter struct {
|
||||||
|
rw http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
||||||
|
b, err := msg.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rw.Header().Set("content-type", dnsContentType)
|
||||||
|
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
||||||
|
r.rw.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
_, err = r.rw.Write(b)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/config"
|
||||||
|
"github.com/0xERR0R/blocky/service"
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-chi/cors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpMiscService implements service.HTTPService.
|
||||||
|
//
|
||||||
|
// This supports the existing single HTTP/HTTPS endpoints
|
||||||
|
// that expose everything. The goal is to split it up
|
||||||
|
// and remove it.
|
||||||
|
type httpMiscService struct {
|
||||||
|
service.SimpleHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPMiscService(cfg *config.Config) *httpMiscService {
|
||||||
|
endpoints := util.ConcatSlices(
|
||||||
|
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
|
||||||
|
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
|
||||||
|
)
|
||||||
|
|
||||||
|
s := &httpMiscService{
|
||||||
|
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
|
||||||
|
}
|
||||||
|
|
||||||
|
configureHTTPRouter(s.Router(), cfg)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpServer implements subServer for HTTP.
|
||||||
|
type httpServer struct {
|
||||||
|
service.HTTPService
|
||||||
|
|
||||||
|
inner http.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPServer(svc service.HTTPService) *httpServer {
|
||||||
|
const (
|
||||||
|
readHeaderTimeout = 20 * time.Second
|
||||||
|
readTimeout = 20 * time.Second
|
||||||
|
writeTimeout = 20 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
return &httpServer{
|
||||||
|
HTTPService: svc,
|
||||||
|
|
||||||
|
inner: http.Server{
|
||||||
|
Handler: withCommonMiddleware(svc.Router()),
|
||||||
|
ReadHeaderTimeout: readHeaderTimeout,
|
||||||
|
ReadTimeout: readTimeout,
|
||||||
|
WriteTimeout: writeTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
s.inner.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s.inner.Serve(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func withCommonMiddleware(inner http.Handler) *chi.Mux {
|
||||||
|
// Middleware must be defined before routes, so
|
||||||
|
// create a new router and mount the inner handler
|
||||||
|
mux := chi.NewMux()
|
||||||
|
|
||||||
|
mux.Use(
|
||||||
|
secureHeadersMiddleware,
|
||||||
|
newCORSMiddleware(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mux.Mount("/", inner)
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpMiddleware = func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
func secureHeadersMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.TLS != nil {
|
||||||
|
w.Header().Set("strict-transport-security", "max-age=63072000")
|
||||||
|
w.Header().Set("x-frame-options", "DENY")
|
||||||
|
w.Header().Set("x-content-type-options", "nosniff")
|
||||||
|
w.Header().Set("x-xss-protection", "1; mode=block")
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCORSMiddleware() httpMiddleware {
|
||||||
|
const corsMaxAge = 5 * time.Minute
|
||||||
|
|
||||||
|
options := cors.Options{
|
||||||
|
AllowCredentials: true,
|
||||||
|
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||||
|
AllowedMethods: []string{"GET", "POST"},
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
ExposedHeaders: []string{"Link"},
|
||||||
|
MaxAge: int(corsMaxAge.Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
return cors.New(options).Handler
|
||||||
|
}
|
255
server/server.go
255
server/server.go
|
@ -18,15 +18,20 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/api"
|
||||||
"github.com/0xERR0R/blocky/config"
|
"github.com/0xERR0R/blocky/config"
|
||||||
"github.com/0xERR0R/blocky/log"
|
"github.com/0xERR0R/blocky/log"
|
||||||
"github.com/0xERR0R/blocky/metrics"
|
"github.com/0xERR0R/blocky/metrics"
|
||||||
"github.com/0xERR0R/blocky/model"
|
"github.com/0xERR0R/blocky/model"
|
||||||
"github.com/0xERR0R/blocky/redis"
|
"github.com/0xERR0R/blocky/redis"
|
||||||
"github.com/0xERR0R/blocky/resolver"
|
"github.com/0xERR0R/blocky/resolver"
|
||||||
|
"github.com/0xERR0R/blocky/service"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/util"
|
"github.com/0xERR0R/blocky/util"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
@ -44,14 +49,18 @@ const (
|
||||||
|
|
||||||
// Server controls the endpoints for DNS and HTTP
|
// Server controls the endpoints for DNS and HTTP
|
||||||
type Server struct {
|
type Server struct {
|
||||||
dnsServers []*dns.Server
|
dnsServers []*dns.Server
|
||||||
httpListeners []net.Listener
|
queryResolver resolver.ChainedResolver
|
||||||
httpsListeners []net.Listener
|
cfg *config.Config
|
||||||
queryResolver resolver.ChainedResolver
|
|
||||||
cfg *config.Config
|
services map[service.Listener]service.Service
|
||||||
httpMux *chi.Mux
|
}
|
||||||
httpsMux *chi.Mux
|
|
||||||
cert tls.Certificate
|
type subServer interface {
|
||||||
|
fmt.Stringer
|
||||||
|
service.Service
|
||||||
|
|
||||||
|
Serve(context.Context, net.Listener) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func logger() *logrus.Entry {
|
func logger() *logrus.Entry {
|
||||||
|
@ -71,14 +80,6 @@ func tlsCipherSuites() []uint16 {
|
||||||
return tlsCipherSuites
|
return tlsCipherSuites
|
||||||
}
|
}
|
||||||
|
|
||||||
func getServerAddress(addr string) string {
|
|
||||||
if !strings.Contains(addr, ":") {
|
|
||||||
addr = fmt.Sprintf(":%s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type NewServerFunc func(address string) (*dns.Server, error)
|
type NewServerFunc func(address string) (*dns.Server, error)
|
||||||
|
|
||||||
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
||||||
|
@ -99,39 +100,47 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates new server instance with passed config
|
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
||||||
//
|
|
||||||
//nolint:funlen
|
|
||||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
|
||||||
var cert tls.Certificate
|
var cert tls.Certificate
|
||||||
|
|
||||||
|
cert, err := retrieveCertificate(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// #nosec G402 // See TLSVersion.validate
|
||||||
|
res := &tls.Config{
|
||||||
|
MinVersion: uint16(cfg.MinTLSServeVer),
|
||||||
|
CipherSuites: tlsCipherSuites(),
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates new server instance with passed config
|
||||||
|
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||||
|
cfg.CopyPortsToServices()
|
||||||
|
|
||||||
|
var tlsCfg *tls.Config
|
||||||
|
|
||||||
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
|
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
|
||||||
cert, err = retrieveCertificate(cfg)
|
tlsCfg, err = newTLSConfig(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't retrieve cert: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServers, err := createServers(cfg, cert)
|
dnsServers, err := createServers(cfg, tlsCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("server creation failed: %w", err)
|
return nil, fmt.Errorf("server creation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRouter := createHTTPRouter(cfg)
|
listeners, err := createListeners(ctx, cfg, tlsCfg)
|
||||||
httpsRouter := createHTTPSRouter(cfg)
|
|
||||||
|
|
||||||
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(httpListeners) != 0 || len(httpsListeners) != 0 {
|
|
||||||
metrics.Start(httpRouter, cfg.Prometheus)
|
|
||||||
metrics.Start(httpsRouter, cfg.Prometheus)
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.RegisterEventListeners()
|
|
||||||
|
|
||||||
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
|
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -151,27 +160,21 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
||||||
}
|
}
|
||||||
|
|
||||||
server = &Server{
|
server = &Server{
|
||||||
dnsServers: dnsServers,
|
dnsServers: dnsServers,
|
||||||
queryResolver: queryResolver,
|
queryResolver: queryResolver,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
httpListeners: httpListeners,
|
|
||||||
httpsListeners: httpsListeners,
|
|
||||||
httpMux: httpRouter,
|
|
||||||
httpsMux: httpsRouter,
|
|
||||||
cert: cert,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
server.printConfiguration()
|
server.printConfiguration()
|
||||||
|
|
||||||
server.registerDNSHandlers(ctx)
|
server.registerDNSHandlers(ctx)
|
||||||
err = server.registerAPIEndpoints(httpRouter)
|
|
||||||
|
|
||||||
|
services, err := server.createServices()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = server.registerAPIEndpoints(httpsRouter)
|
server.services, err = service.GroupByListener(services, listeners)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -179,14 +182,35 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
|
||||||
return server, err
|
return server, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
|
func (s *Server) createServices() ([]service.Service, error) {
|
||||||
|
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := []service.Service{
|
||||||
|
newHTTPMiscService(s.cfg),
|
||||||
|
newDoHService(s.cfg.Services.DoH, s.handleReq),
|
||||||
|
api.NewService(s.cfg.Services.API, openAPIImpl),
|
||||||
|
metrics.NewService(s.cfg.Services.Metrics, s.cfg.Prometheus),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove services the user has not enabled
|
||||||
|
res = slices.DeleteFunc(res, func(svc service.Service) bool {
|
||||||
|
return len(svc.ExposeOn()) == 0
|
||||||
|
})
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
|
||||||
var dnsServers []*dns.Server
|
var dnsServers []*dns.Server
|
||||||
|
|
||||||
var err *multierror.Error
|
var err *multierror.Error
|
||||||
|
|
||||||
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
server, err := newServer(getServerAddress(address))
|
server, err := newServer(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -201,52 +225,69 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
|
||||||
addServers(createUDPServer, cfg.Ports.DNS),
|
addServers(createUDPServer, cfg.Ports.DNS),
|
||||||
addServers(createTCPServer, cfg.Ports.DNS),
|
addServers(createTCPServer, cfg.Ports.DNS),
|
||||||
addServers(func(address string) (*dns.Server, error) {
|
addServers(func(address string) (*dns.Server, error) {
|
||||||
return createTLSServer(cfg, address, cert)
|
return createTLSServer(address, tlsCfg)
|
||||||
}, cfg.Ports.TLS))
|
}, cfg.Ports.TLS))
|
||||||
|
|
||||||
return dnsServers, err.ErrorOrNil()
|
return dnsServers, err.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
|
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
|
||||||
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
|
res := make(map[string]service.Listener)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
|
||||||
|
return service.ListenTLS(ctx, endpoint, tlsCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
|
err := errors.Join(
|
||||||
|
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
|
||||||
|
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
|
||||||
|
newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
|
||||||
|
newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
|
||||||
|
newListeners(ctx, service.HTTPProtocol, cfg.Services.Metrics.Addrs.HTTP, service.ListenTCP, res),
|
||||||
|
newListeners(ctx, service.HTTPSProtocol, cfg.Services.Metrics.Addrs.HTTPS, listenTLS, res),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return httpListeners, httpsListeners, nil
|
return maps.Values(res), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
|
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)
|
||||||
listeners := make([]net.Listener, 0, len(addresses))
|
|
||||||
|
|
||||||
for _, address := range addresses {
|
func newListeners[T service.Listener](
|
||||||
listener, err := net.Listen("tcp", getServerAddress(address))
|
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
|
||||||
if err != nil {
|
) error {
|
||||||
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
|
for _, addr := range addrs {
|
||||||
|
key := fmt.Sprintf("%s:%s", proto, addr)
|
||||||
|
if _, ok := out[key]; ok {
|
||||||
|
// Avoid "address already in use"
|
||||||
|
// We instead try to merge services, see services.GroupByListener
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listeners = append(listeners, listener)
|
endpoint := service.Endpoint{
|
||||||
|
Protocol: proto,
|
||||||
|
AddrConf: addr,
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := listen(ctx, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err // already has all info
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = l
|
||||||
}
|
}
|
||||||
|
|
||||||
return listeners, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
|
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
|
||||||
return &dns.Server{
|
return &dns.Server{
|
||||||
Addr: address,
|
Addr: address,
|
||||||
Net: "tcp-tls",
|
Net: "tcp-tls",
|
||||||
//nolint:gosec
|
TLSConfig: tlsCfg,
|
||||||
TLSConfig: &tls.Config{
|
Handler: dns.NewServeMux(),
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
MinVersion: uint16(cfg.MinTLSServeVer),
|
|
||||||
CipherSuites: tlsCipherSuites(),
|
|
||||||
},
|
|
||||||
Handler: dns.NewServeMux(),
|
|
||||||
NotifyStartedFunc: func() {
|
NotifyStartedFunc: func() {
|
||||||
logger().Infof("TLS server is up and running on address %s", address)
|
logger().Infof("TLS server is up and running on address %s", address)
|
||||||
},
|
},
|
||||||
|
@ -470,11 +511,15 @@ func toMB(b uint64) uint64 {
|
||||||
return b / bytesInKB / bytesInKB
|
return b / bytesInKB / bytesInKB
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
func newSubServer(svc service.Service) (subServer, error) {
|
||||||
readHeaderTimeout = 20 * time.Second
|
switch svc := svc.(type) {
|
||||||
readTimeout = 20 * time.Second
|
case service.HTTPService:
|
||||||
writeTimeout = 20 * time.Second
|
return newHTTPServer(svc), nil
|
||||||
)
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the server
|
// Start starts the server
|
||||||
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
||||||
|
@ -490,48 +535,22 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, listener := range s.httpListeners {
|
for listener, svc := range s.services {
|
||||||
listener := listener
|
listener, svc := listener, svc
|
||||||
address := s.cfg.Ports.HTTP[i]
|
|
||||||
|
srv, err := newSubServer(svc)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
logger().Infof("http server is up and running on addr/port %s", address)
|
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())
|
||||||
|
|
||||||
srv := &http.Server{
|
err := srv.Serve(ctx, listener)
|
||||||
ReadTimeout: readTimeout,
|
if err != nil {
|
||||||
ReadHeaderTimeout: readHeaderTimeout,
|
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
|
||||||
WriteTimeout: writeTimeout,
|
|
||||||
Handler: s.httpsMux,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := srv.Serve(listener); err != nil {
|
|
||||||
errCh <- fmt.Errorf("start http listener failed: %w", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, listener := range s.httpsListeners {
|
|
||||||
listener := listener
|
|
||||||
address := s.cfg.Ports.HTTPS[i]
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
logger().Infof("https server is up and running on addr/port %s", address)
|
|
||||||
|
|
||||||
server := http.Server{
|
|
||||||
Handler: s.httpsMux,
|
|
||||||
ReadTimeout: readTimeout,
|
|
||||||
ReadHeaderTimeout: readHeaderTimeout,
|
|
||||||
WriteTimeout: writeTimeout,
|
|
||||||
//nolint:gosec
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
MinVersion: uint16(s.cfg.MinTLSServeVer),
|
|
||||||
CipherSuites: tlsCipherSuites(),
|
|
||||||
Certificates: []tls.Certificate{s.cert},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := server.ServeTLS(listener, "", ""); err != nil {
|
|
||||||
errCh <- fmt.Errorf("start https listener failed: %w", err)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -630,6 +649,8 @@ type msgWriter interface {
|
||||||
WriteMsg(msg *dns.Msg) error
|
WriteMsg(msg *dns.Msg) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dnsHandler func(context.Context, *model.Request, msgWriter)
|
||||||
|
|
||||||
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
|
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
|
||||||
response, err := s.resolve(ctx, request)
|
response, err := s.resolve(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -2,13 +2,10 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/0xERR0R/blocky/resolver"
|
"github.com/0xERR0R/blocky/resolver"
|
||||||
|
|
||||||
|
@ -22,7 +19,6 @@ import (
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/go-chi/cors"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,19 +28,8 @@ const (
|
||||||
dnsContentType = "application/dns-message"
|
dnsContentType = "application/dns-message"
|
||||||
htmlContentType = "text/html; charset=UTF-8"
|
htmlContentType = "text/html; charset=UTF-8"
|
||||||
yamlContentType = "text/yaml"
|
yamlContentType = "text/yaml"
|
||||||
corsMaxAge = 5 * time.Minute
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func secureHeader(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("strict-transport-security", "max-age=63072000")
|
|
||||||
w.Header().Set("x-frame-options", "DENY")
|
|
||||||
w.Header().Set("x-content-type-options", "nosniff")
|
|
||||||
w.Header().Set("x-xss-protection", "1; mode=block")
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) {
|
func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) {
|
||||||
bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver)
|
bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -64,108 +49,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
|
||||||
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
|
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
|
|
||||||
const pathDohQuery = "/dns-query"
|
|
||||||
|
|
||||||
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
|
|
||||||
|
|
||||||
router.Get(pathDohQuery, s.dohGetRequestHandler)
|
|
||||||
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
|
|
||||||
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
|
|
||||||
router.Post(pathDohQuery, s.dohPostRequestHandler)
|
|
||||||
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
|
|
||||||
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
dnsParam, ok := req.URL.Query()["dns"]
|
|
||||||
if !ok || len(dnsParam[0]) < 1 {
|
|
||||||
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rawMsg) > dohMessageLimit {
|
|
||||||
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.processDohMessage(rawMsg, rw, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
contentType := req.Header.Get("Content-type")
|
|
||||||
if contentType != dnsContentType {
|
|
||||||
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawMsg, err := io.ReadAll(req.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rawMsg) > dohMessageLimit {
|
|
||||||
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.processDohMessage(rawMsg, rw, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(rawMsg); err != nil {
|
|
||||||
logger().Error("can't deserialize message: ", err)
|
|
||||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
|
||||||
|
|
||||||
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
|
|
||||||
}
|
|
||||||
|
|
||||||
type httpMsgWriter struct {
|
|
||||||
rw http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
|
||||||
b, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rw.Header().Set("content-type", dnsContentType)
|
|
||||||
|
|
||||||
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
|
||||||
r.rw.WriteHeader(http.StatusOK)
|
|
||||||
|
|
||||||
_, err = r.rw.Write(b)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Query(
|
func (s *Server) Query(
|
||||||
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
||||||
) (*model.Response, error) {
|
) (*model.Response, error) {
|
||||||
|
@ -177,27 +60,7 @@ func (s *Server) Query(
|
||||||
return s.resolve(ctx, req)
|
return s.resolve(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
func configureHTTPRouter(router chi.Router, cfg *config.Config) {
|
||||||
router := chi.NewRouter()
|
|
||||||
|
|
||||||
configureSecureHeaderHandler(router)
|
|
||||||
|
|
||||||
registerHandlers(cfg, router)
|
|
||||||
|
|
||||||
return router
|
|
||||||
}
|
|
||||||
|
|
||||||
func createHTTPRouter(cfg *config.Config) *chi.Mux {
|
|
||||||
router := chi.NewRouter()
|
|
||||||
|
|
||||||
registerHandlers(cfg, router)
|
|
||||||
|
|
||||||
return router
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerHandlers(cfg *config.Config, router *chi.Mux) {
|
|
||||||
configureCorsHandler(router)
|
|
||||||
|
|
||||||
configureDebugHandler(router)
|
configureDebugHandler(router)
|
||||||
|
|
||||||
configureDocsHandler(router)
|
configureDocsHandler(router)
|
||||||
|
@ -207,7 +70,7 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) {
|
||||||
configureRootHandler(cfg, router)
|
configureRootHandler(cfg, router)
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureDocsHandler(router *chi.Mux) {
|
func configureDocsHandler(router chi.Router) {
|
||||||
router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) {
|
router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
writer.Header().Set(contentTypeHeader, yamlContentType)
|
writer.Header().Set(contentTypeHeader, yamlContentType)
|
||||||
_, err := writer.Write([]byte(docs.OpenAPI))
|
_, err := writer.Write([]byte(docs.OpenAPI))
|
||||||
|
@ -215,7 +78,7 @@ func configureDocsHandler(router *chi.Mux) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureStaticAssetsHandler(router *chi.Mux) {
|
func configureStaticAssetsHandler(router chi.Router) {
|
||||||
assets, err := web.Assets()
|
assets, err := web.Assets()
|
||||||
util.FatalOnError("unable to load static asset files", err)
|
util.FatalOnError("unable to load static asset files", err)
|
||||||
|
|
||||||
|
@ -223,7 +86,7 @@ func configureStaticAssetsHandler(router *chi.Mux) {
|
||||||
router.Handle("/static/*", http.StripPrefix("/static/", fs))
|
router.Handle("/static/*", http.StripPrefix("/static/", fs))
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
|
func configureRootHandler(cfg *config.Config, router chi.Router) {
|
||||||
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
|
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
writer.Header().Set(contentTypeHeader, htmlContentType)
|
writer.Header().Set(contentTypeHeader, htmlContentType)
|
||||||
|
|
||||||
|
@ -282,22 +145,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureSecureHeaderHandler(router *chi.Mux) {
|
func configureDebugHandler(router chi.Router) {
|
||||||
router.Use(secureHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func configureDebugHandler(router *chi.Mux) {
|
|
||||||
router.Mount("/debug", middleware.Profiler())
|
router.Mount("/debug", middleware.Profiler())
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureCorsHandler(router *chi.Mux) {
|
|
||||||
crs := cors.New(cors.Options{
|
|
||||||
AllowedOrigins: []string{"*"},
|
|
||||||
AllowedMethods: []string{"GET", "POST"},
|
|
||||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
|
||||||
ExposedHeaders: []string{"Link"},
|
|
||||||
AllowCredentials: true,
|
|
||||||
MaxAge: int(corsMaxAge.Seconds()),
|
|
||||||
})
|
|
||||||
router.Use(crs.Handler)
|
|
||||||
}
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = BeforeSuite(func() {
|
var _ = BeforeSuite(func() {
|
||||||
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
|
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
|
||||||
queryURL = baseURL + "dns-query"
|
queryURL = baseURL + "dns-query"
|
||||||
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
|
||||||
ctx, cancelFn := context.WithCancel(context.Background())
|
ctx, cancelFn := context.WithCancel(context.Background())
|
||||||
|
@ -147,10 +147,10 @@ var _ = BeforeSuite(func() {
|
||||||
},
|
},
|
||||||
|
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
|
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
|
||||||
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
|
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
|
||||||
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
|
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
|
||||||
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
|
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
|
||||||
},
|
},
|
||||||
CertFile: certPem.Path,
|
CertFile: certPem.Path,
|
||||||
KeyFile: keyPem.Path,
|
KeyFile: keyPem.Path,
|
||||||
|
@ -634,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
},
|
},
|
||||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -678,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
|
||||||
},
|
},
|
||||||
Blocking: config.Blocking{BlockType: "zeroIp"},
|
Blocking: config.Blocking{BlockType: "zeroIp"},
|
||||||
Ports: config.Ports{
|
Ports: config.Ports{
|
||||||
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
|
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -741,17 +741,18 @@ var _ = Describe("Running DNS server", func() {
|
||||||
cfg.KeyFile = ""
|
cfg.KeyFile = ""
|
||||||
cfg.CertFile = ""
|
cfg.CertFile = ""
|
||||||
cfg.Ports = config.Ports{
|
cfg.Ports = config.Ports{
|
||||||
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
|
HTTPS: []string{":0"},
|
||||||
}
|
}
|
||||||
sut, err := NewServer(ctx, &cfg)
|
|
||||||
|
sut, err := newTLSConfig(&cfg)
|
||||||
Expect(err).Should(Succeed())
|
Expect(err).Should(Succeed())
|
||||||
Expect(sut.cert.Certificate).ShouldNot(BeNil())
|
Expect(sut.Certificates).ShouldNot(BeEmpty())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
func requestServer(request *dns.Msg) *dns.Msg {
|
func requestServer(request *dns.Msg) *dns.Msg {
|
||||||
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
|
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log().Fatal("could not connect to server: ", err)
|
Log().Fatal("could not connect to server: ", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Endpoint is a network endpoint on which to expose a service.
|
||||||
|
type Endpoint struct {
|
||||||
|
// Protocol is the protocol to be exposed on this endpoint.
|
||||||
|
Protocol string
|
||||||
|
|
||||||
|
// AddrConf is the network address as configured by the user.
|
||||||
|
AddrConf string
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndpointsFromAddrs(proto string, addrs []string) []Endpoint {
|
||||||
|
return util.ForEach(addrs, func(addr string) Endpoint {
|
||||||
|
return Endpoint{
|
||||||
|
Protocol: proto,
|
||||||
|
AddrConf: addr,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Endpoint) String() string {
|
||||||
|
addr := e.AddrConf
|
||||||
|
if strings.HasPrefix(addr, ":") {
|
||||||
|
addr = "*" + addr
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s://%s", e.Protocol, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type endpointSet map[Endpoint]struct{}
|
||||||
|
|
||||||
|
func (s endpointSet) ToSlice() []Endpoint {
|
||||||
|
return maps.Keys(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s endpointSet) IntersectSlice(others []Endpoint) {
|
||||||
|
for endpoint := range s {
|
||||||
|
if !slices.Contains(others, endpoint) {
|
||||||
|
delete(s, endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,123 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HTTPProtocol = "http"
|
||||||
|
HTTPSProtocol = "https"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPService is a Service using a HTTP router.
|
||||||
|
type HTTPService interface {
|
||||||
|
Service
|
||||||
|
Merger
|
||||||
|
|
||||||
|
// Router returns the service's router.
|
||||||
|
Router() chi.Router
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPInfo can be embedded in structs to help implement HTTPService.
|
||||||
|
type HTTPInfo struct {
|
||||||
|
Info
|
||||||
|
|
||||||
|
mux *chi.Mux
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPInfo(name string, endpoints []Endpoint) HTTPInfo {
|
||||||
|
return HTTPInfo{
|
||||||
|
Info: NewInfo(name, endpoints),
|
||||||
|
|
||||||
|
mux: chi.NewMux(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *HTTPInfo) Router() chi.Router { return i.mux }
|
||||||
|
|
||||||
|
var _ HTTPService = (*SimpleHTTP)(nil)
|
||||||
|
|
||||||
|
// SimpleHTTP implements HTTPService usinig the default HTTP merger.
|
||||||
|
type SimpleHTTP struct{ HTTPInfo }
|
||||||
|
|
||||||
|
func NewSimpleHTTP(name string, endpoints []Endpoint) SimpleHTTP {
|
||||||
|
return SimpleHTTP{HTTPInfo: NewHTTPInfo(name, endpoints)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SimpleHTTP) Merge(other Service) (Merger, error) {
|
||||||
|
return MergeHTTP(s, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeHTTP merges two compatible HTTPServices.
|
||||||
|
//
|
||||||
|
// The second parameter is of type `Service` to make it easy to call
|
||||||
|
// from a `Merger.Merge` implementation.
|
||||||
|
func MergeHTTP(a HTTPService, b Service) (Merger, error) {
|
||||||
|
return newHTTPMerger(a).Merge(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ HTTPService = (*httpMerger)(nil)
|
||||||
|
|
||||||
|
// httpMerger can merge HTTPServices by combining their routes.
|
||||||
|
type httpMerger struct {
|
||||||
|
inner []HTTPService
|
||||||
|
router chi.Router
|
||||||
|
endpoints endpointSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPMerger(first HTTPService) *httpMerger {
|
||||||
|
return &httpMerger{
|
||||||
|
inner: []HTTPService{first},
|
||||||
|
router: first.Router(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *httpMerger) String() string { return svcString(m) }
|
||||||
|
|
||||||
|
func (m *httpMerger) ServiceName() string {
|
||||||
|
names := util.ForEach(m.inner, func(svc HTTPService) string {
|
||||||
|
return svc.ServiceName()
|
||||||
|
})
|
||||||
|
|
||||||
|
return strings.Join(names, " & ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() }
|
||||||
|
func (m *httpMerger) Router() chi.Router { return m.router }
|
||||||
|
|
||||||
|
func (m *httpMerger) Merge(other Service) (Merger, error) {
|
||||||
|
httpSvc, ok := other.(HTTPService)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("not an HTTPService")
|
||||||
|
}
|
||||||
|
|
||||||
|
type middleware = func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined
|
||||||
|
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
|
||||||
|
m.router.With(middlewares...).Method(method, route, handler)
|
||||||
|
|
||||||
|
// Expose /example/ as /example too
|
||||||
|
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
|
||||||
|
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
|
||||||
|
// The main point of this is for DoH (/dns-query).
|
||||||
|
if strings.HasSuffix(route, "/") {
|
||||||
|
route := strings.TrimSuffix(route, "/")
|
||||||
|
m.router.With(middlewares...).Method(method, route, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
m.inner = append(m.inner, httpSvc)
|
||||||
|
|
||||||
|
// Don't expose any service more than it expects
|
||||||
|
m.endpoints.IntersectSlice(other.ExposeOn())
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Listener is a net.Listener that provides information about
|
||||||
|
// what protocol and address it is configured for.
|
||||||
|
type Listener interface {
|
||||||
|
fmt.Stringer
|
||||||
|
net.Listener
|
||||||
|
|
||||||
|
// Exposes returns the endpoint for this listener.
|
||||||
|
//
|
||||||
|
// It can be used to find service(s) with matching configuration.
|
||||||
|
Exposes() Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenerInfo can be embedded in structs to help implement Listener.
|
||||||
|
type ListenerInfo struct {
|
||||||
|
Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint }
|
||||||
|
|
||||||
|
// NetListener implements Listener using an existing net.Listener.
|
||||||
|
type NetListener struct {
|
||||||
|
net.Listener
|
||||||
|
ListenerInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener {
|
||||||
|
return &NetListener{
|
||||||
|
Listener: inner,
|
||||||
|
ListenerInfo: ListenerInfo{endpoint},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPListener is a Listener for a TCP socket.
|
||||||
|
type TCPListener struct{ NetListener }
|
||||||
|
|
||||||
|
// ListenTCP creates a new TCPListener.
|
||||||
|
func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) {
|
||||||
|
var lc net.ListenConfig
|
||||||
|
|
||||||
|
l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err // err already has all the info we could add
|
||||||
|
}
|
||||||
|
|
||||||
|
inner := NewNetListener(endpoint, l)
|
||||||
|
|
||||||
|
return &TCPListener{*inner}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSListener is a Listener using TLS over TCP.
|
||||||
|
type TLSListener struct{ NetListener }
|
||||||
|
|
||||||
|
// ListenTLS creates a new TLSListener.
|
||||||
|
func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) {
|
||||||
|
tcp, err := ListenTCP(ctx, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
inner := tcp.NetListener
|
||||||
|
|
||||||
|
inner.Listener = tls.NewListener(inner.Listener, cfg)
|
||||||
|
|
||||||
|
return &TLSListener{inner}, nil
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// Merger is a Service that can be merged with another compatible one.
|
||||||
|
type Merger interface {
|
||||||
|
Service
|
||||||
|
|
||||||
|
// Merge returns the result of merging the receiver with the other Service.
|
||||||
|
//
|
||||||
|
// Neither the receiver, nor the other Service should be used directly after
|
||||||
|
// calling this method.
|
||||||
|
Merge(other Service) (Merger, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeAll merges the given services, if they are compatible.
|
||||||
|
//
|
||||||
|
// This allows using multiple compatible services with a single listener.
|
||||||
|
//
|
||||||
|
// All passed-in services must not be re-used.
|
||||||
|
func MergeAll(services []Service) (Service, error) {
|
||||||
|
switch len(services) {
|
||||||
|
case 0:
|
||||||
|
return nil, errors.New("no services given")
|
||||||
|
|
||||||
|
case 1:
|
||||||
|
return services[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
merger, err := firstMerger(services)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, svc := range services {
|
||||||
|
if svc == merger {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
merger, err = merger.Merge(svc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merger, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstMerger(services []Service) (Merger, error) {
|
||||||
|
for _, t := range services {
|
||||||
|
if svc, ok := t.(Merger); ok {
|
||||||
|
return svc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no merger found")
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
// Package service exposes types to abstract services from the networking.
|
||||||
|
//
|
||||||
|
// The idea is that we build a set of services and a set of network endpoints (Listener).
|
||||||
|
// The services are then assigned to endpoints based on the address(es) they were configured for.
|
||||||
|
//
|
||||||
|
// Actual service to endpoint binding is not handled by the abstractions in this package as it is
|
||||||
|
// protocol specific.
|
||||||
|
// The general pattern is to make a "server" that wraps a service, and can then be started on an
|
||||||
|
// endpoint using a `Serve` method, similar to `http.Server`.
|
||||||
|
//
|
||||||
|
// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port),
|
||||||
|
// services can implement `Merger`.
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/0xERR0R/blocky/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service is a network exposed service.
|
||||||
|
//
|
||||||
|
// It contains only the logic and user configured addresses it should be exposed on.
|
||||||
|
// Is is meant to be associated to one or more sockets via those addresses.
|
||||||
|
// Actual association with a socket is protocol specific.
|
||||||
|
type Service interface {
|
||||||
|
fmt.Stringer
|
||||||
|
|
||||||
|
// ServiceName returns the user friendly name of the service.
|
||||||
|
ServiceName() string
|
||||||
|
|
||||||
|
// ExposeOn returns the set of endpoints the service should be exposed on.
|
||||||
|
//
|
||||||
|
// They can be used to find listener(s) with matching configuration.
|
||||||
|
ExposeOn() []Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func svcString(s Service) string {
|
||||||
|
endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() })
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info can be embedded in structs to help implement Service.
|
||||||
|
type Info struct {
|
||||||
|
name string
|
||||||
|
endpoints []Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInfo(name string, endpoints []Endpoint) Info {
|
||||||
|
return Info{
|
||||||
|
name: name,
|
||||||
|
endpoints: endpoints,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Info) ServiceName() string { return i.name }
|
||||||
|
func (i *Info) ExposeOn() []Endpoint { return i.endpoints }
|
||||||
|
func (i *Info) String() string { return svcString(i) }
|
||||||
|
|
||||||
|
// GroupByListener returns a map of listener and services grouped by configured address.
|
||||||
|
//
|
||||||
|
// Each input listener is a key in the map. The corresponding value is a service
|
||||||
|
// merged from all services with a matching address.
|
||||||
|
func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) {
|
||||||
|
res := make(map[Listener]Service, len(listeners))
|
||||||
|
unused := slices.Clone(services)
|
||||||
|
|
||||||
|
for _, listener := range listeners {
|
||||||
|
services := findAllCompatible(services, listener.Exposes())
|
||||||
|
if len(services) == 0 {
|
||||||
|
return nil, fmt.Errorf("found no compatible services for listener %s", listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc, err := MergeAll(services)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res[listener] = svc
|
||||||
|
|
||||||
|
for _, svc := range services {
|
||||||
|
if i := slices.Index(unused, svc); i != -1 {
|
||||||
|
unused = slices.Delete(unused, i, i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unused) != 0 {
|
||||||
|
return nil, fmt.Errorf("found no compatible listener for services: %v", unused)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findAllCompatible returns the subset of services that use the given Listener.
|
||||||
|
func findAllCompatible(services []Service, endpoint Endpoint) []Service {
|
||||||
|
res := make([]Service, 0, len(services))
|
||||||
|
|
||||||
|
for _, svc := range services {
|
||||||
|
if isExposedOn(svc, endpoint) {
|
||||||
|
res = append(res, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func isExposedOn(svc Service, endpoint Endpoint) bool {
|
||||||
|
return slices.Index(svc.ExposeOn(), endpoint) != -1
|
||||||
|
}
|
18
trie/trie.go
18
trie/trie.go
|
@ -1,5 +1,10 @@
|
||||||
package trie
|
package trie
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/0xERR0R/blocky/log"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// Trie stores a set of strings and can quickly check
|
// Trie stores a set of strings and can quickly check
|
||||||
// if it contains an element, or one of its parents.
|
// if it contains an element, or one of its parents.
|
||||||
//
|
//
|
||||||
|
@ -108,8 +113,12 @@ func (n *parent) insert(key string, split SplitFunc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
||||||
|
searchString := key
|
||||||
|
rule := ""
|
||||||
|
|
||||||
for {
|
for {
|
||||||
label, rest := split(key)
|
label, rest := split(key)
|
||||||
|
rule = strings.Join([]string{label, rule}, ".")
|
||||||
|
|
||||||
child, ok := n.children[label]
|
child, ok := n.children[label]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -132,7 +141,14 @@ func (n *parent) hasParentOf(key string, split SplitFunc) bool {
|
||||||
|
|
||||||
case terminal:
|
case terminal:
|
||||||
// Continue down the trie
|
// Continue down the trie
|
||||||
return child.hasParentOf(rest, split)
|
matched := child.hasParentOf(rest, split)
|
||||||
|
if matched {
|
||||||
|
rule = strings.Join([]string{child.String(), rule}, ".")
|
||||||
|
rule = strings.Trim(rule, ".")
|
||||||
|
log.PrefixedLog("trie").Debugf("wildcard block rule '%s' matched with '%s'", rule, searchString)
|
||||||
|
}
|
||||||
|
|
||||||
|
return matched
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
// ForEach implements the functional map operation, under a different
|
||||||
|
// name to avoid confusion with Go's map type.
|
||||||
|
func ForEach[T, U any](slice []T, convert func(T) U) []U {
|
||||||
|
res := make([]U, 0, len(slice))
|
||||||
|
|
||||||
|
for _, t := range slice {
|
||||||
|
u := convert(t)
|
||||||
|
|
||||||
|
res = append(res, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcatSlices returns a new slice with contents of all the inputs concatenated.
|
||||||
|
func ConcatSlices[T any](slices ...[]T) []T {
|
||||||
|
// Allocation is usually the bottleneck, so do it all at once
|
||||||
|
totalLen := 0
|
||||||
|
|
||||||
|
for _, slice := range slices {
|
||||||
|
totalLen += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
res := make([]T, 0, totalLen)
|
||||||
|
|
||||||
|
for _, slice := range slices {
|
||||||
|
res = append(res, slice...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
Loading…
Reference in New Issue