diff --git a/server/http.go b/server/http.go index cac0e810..7c4da323 100644 --- a/server/http.go +++ b/server/http.go @@ -6,17 +6,55 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/api" + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/service" + "github.com/0xERR0R/blocky/util" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" ) -type httpServer struct { - inner http.Server - - name string +// httpMiscService implements service.HTTPService. +// +// This supports the existing single HTTP/HTTPS endpoints +// that expose everything. The goal is to split it up +// and remove it. +type httpMiscService struct { + service.HTTPInfo } -func newHTTPServer(name string, handler http.Handler) *httpServer { +func newHTTPMiscService( + cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler, +) *httpMiscService { + endpoints := util.ConcatSlices( + service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP), + service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS), + ) + + return &httpMiscService{ + HTTPInfo: service.HTTPInfo{ + Info: service.Info{ + Name: "HTTP", + Endpoints: endpoints, + }, + + Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler), + }, + } +} + +func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) { + return service.MergeHTTP(s, other) +} + +// httpServer implements subServer for HTTP. +type httpServer struct { + service.HTTPService + + inner http.Server +} + +func newHTTPServer(svc service.HTTPService) *httpServer { const ( readHeaderTimeout = 20 * time.Second readTimeout = 20 * time.Second @@ -24,22 +62,17 @@ func newHTTPServer(name string, handler http.Handler) *httpServer { ) return &httpServer{ + HTTPService: svc, + inner: http.Server{ - ReadTimeout: readTimeout, + Handler: withCommonMiddleware(svc.Router()), ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: readTimeout, WriteTimeout: writeTimeout, - - Handler: withCommonMiddleware(handler), }, - - name: name, } } -func (s *httpServer) String() string { - return s.name -} - func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { go func() { <-ctx.Done() diff --git a/server/server.go b/server/server.go index 5c43bcac..cb27f749 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,7 @@ import ( "net/http" "runtime" "runtime/debug" + "slices" "strings" "time" @@ -27,6 +28,8 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" + "github.com/0xERR0R/blocky/service" + "golang.org/x/exp/maps" "github.com/0xERR0R/blocky/util" "github.com/google/uuid" @@ -49,7 +52,14 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config - servers map[net.Listener]*httpServer + services map[service.Listener]service.Service +} + +type subServer interface { + fmt.Stringer + service.Service + + Serve(context.Context, net.Listener) error } func logger() *logrus.Entry { @@ -108,8 +118,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config -// -//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -125,7 +133,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg) + listeners, err := createListeners(ctx, cfg, tlsCfg) if err != nil { return nil, err } @@ -154,41 +162,43 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err dnsServers: dnsServers, queryResolver: queryResolver, cfg: cfg, - - servers: make(map[net.Listener]*httpServer), } server.printConfiguration() server.registerDNSHandlers(ctx) - openAPIImpl, err := server.createOpenAPIInterfaceImpl() + services, err := server.createServices() if err != nil { return nil, err } - httpRouter := createHTTPRouter(cfg, openAPIImpl) - server.registerDoHEndpoints(httpRouter) - - if len(cfg.Ports.HTTP) != 0 { - srv := newHTTPServer("http", httpRouter) - - for _, l := range httpListeners { - server.servers[l] = srv - } - } - - if len(cfg.Ports.HTTPS) != 0 { - srv := newHTTPServer("https", httpRouter) - - for _, l := range httpsListeners { - server.servers[l] = srv - } + server.services, err = service.GroupByListener(services, listeners) + if err != nil { + return nil, err } return server, err } +func (s *Server) createServices() ([]service.Service, error) { + openAPIImpl, err := s.createOpenAPIInterfaceImpl() + if err != nil { + return nil, err + } + + res := []service.Service{ + newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq), + } + + // Remove services the user has not enabled + res = slices.DeleteFunc(res, func(svc service.Service) bool { + return len(svc.ExposeOn()) == 0 + }) + + return res, nil +} + func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) { var dnsServers []*dns.Server @@ -217,48 +227,51 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error return dnsServers, err.ErrorOrNil() } -func createHTTPListeners( - cfg *config.Config, tlsCfg *tls.Config, -) (httpListeners, httpsListeners []net.Listener, err error) { - httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP) - if err != nil { - return nil, nil, err +func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) { + res := make(map[string]service.Listener) + + listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) { + return service.ListenTLS(ctx, endpoint, tlsCfg) } - httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg) - if err != nil { - return nil, nil, err - } - - return httpListeners, httpsListeners, nil -} - -func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { - listeners := make([]net.Listener, 0, len(addresses)) - - for _, address := range addresses { - listener, err := net.Listen("tcp", address) - if err != nil { - return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) - } - - listeners = append(listeners, listener) - } - - return listeners, nil -} - -func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) { - listeners, err := newTCPListeners(proto, addresses) + err := errors.Join( + newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res), + newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res), + ) if err != nil { return nil, err } - for i, inner := range listeners { - listeners[i] = tls.NewListener(inner, tlsCfg) + return maps.Values(res), nil +} + +type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error) + +func newListeners[T service.Listener]( + ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener, +) error { + for _, addr := range addrs { + key := fmt.Sprintf("%s:%s", proto, addr) + if _, ok := out[key]; ok { + // Avoid "address already in use" + // We instead try to merge services, see services.GroupByListener + continue + } + + endpoint := service.Endpoint{ + Protocol: proto, + AddrConf: addr, + } + + l, err := listen(ctx, endpoint) + if err != nil { + return err // already has all info + } + + out[key] = l } - return listeners, nil + return nil } func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { @@ -490,6 +503,16 @@ func toMB(b uint64) uint64 { return b / bytesInKB / bytesInKB } +func newSubServer(svc service.Service) (subServer, error) { + switch svc := svc.(type) { + case service.HTTPService: + return newHTTPServer(svc), nil + + default: + return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc) + } +} + // Start starts the server func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") @@ -504,11 +527,18 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { }() } - for listener, srv := range s.servers { - listener, srv := listener, srv + for listener, svc := range s.services { + listener, svc := listener, svc + + srv, err := newSubServer(svc) + if err != nil { + errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err) + + return + } go func() { - logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr()) + logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes()) err := srv.Serve(ctx, listener) if err != nil { @@ -611,6 +641,8 @@ type msgWriter interface { WriteMsg(msg *dns.Msg) error } +type dnsHandler func(context.Context, *model.Request, msgWriter) + func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) { response, err := s.resolve(ctx, request) if err != nil { diff --git a/server/server_endpoints.go b/server/server_endpoints.go index af862a98..d61a44da 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -52,9 +52,11 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerDoHEndpoints(router *chi.Mux) { +func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) { const pathDohQuery = "/dns-query" + s := &dohServer{dnsHandler} + router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) @@ -63,7 +65,9 @@ func (s *Server) registerDoHEndpoints(router *chi.Mux) { router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) } -func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { +type dohServer struct{ handler dnsHandler } + +func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { dnsParam, ok := req.URL.Query()["dns"] if !ok || len(dnsParam[0]) < 1 { http.Error(rw, "dns param is missing", http.StatusBadRequest) @@ -87,7 +91,7 @@ func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) s.processDohMessage(rawMsg, rw, req) } -func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { +func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { contentType := req.Header.Get("Content-type") if contentType != dnsContentType { http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType) @@ -111,7 +115,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request s.processDohMessage(rawMsg, rw, req) } -func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { +func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { msg := new(dns.Msg) if err := msg.Unpack(rawMsg); err != nil { logger().Error("can't deserialize message: ", err) @@ -122,7 +126,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpRe ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) - s.handleReq(ctx, dnsReq, httpMsgWriter{rw}) + s.handler(ctx, dnsReq, httpMsgWriter{rw}) } type httpMsgWriter struct { @@ -156,7 +160,7 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux { router := chi.NewRouter() api.RegisterOpenAPIEndpoints(router, openAPIImpl) @@ -169,6 +173,8 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) configureRootHandler(cfg, router) + registerDoHEndpoints(router, dnsHandler) + metrics.Start(router, cfg.Prometheus) return router diff --git a/service/endpoint.go b/service/endpoint.go new file mode 100644 index 00000000..94b1b242 --- /dev/null +++ b/service/endpoint.go @@ -0,0 +1,51 @@ +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" + "golang.org/x/exp/maps" +) + +// Endpoint is a network endpoint on which to expose a service. +type Endpoint struct { + // Protocol is the protocol to be exposed on this endpoint. + Protocol string + + // AddrConf is the network address as configured by the user. + AddrConf string +} + +func EndpointsFromAddrs(proto string, addrs []string) []Endpoint { + return util.ForEach(addrs, func(addr string) Endpoint { + return Endpoint{ + Protocol: proto, + AddrConf: addr, + } + }) +} + +func (e Endpoint) String() string { + addr := e.AddrConf + if strings.HasPrefix(addr, ":") { + addr = "*" + addr + } + + return fmt.Sprintf("%s://%s", e.Protocol, addr) +} + +type endpointSet map[Endpoint]struct{} + +func (s endpointSet) ToSlice() []Endpoint { + return maps.Keys(s) +} + +func (s endpointSet) IntersectSlice(others []Endpoint) { + for endpoint := range s { + if !slices.Contains(others, endpoint) { + delete(s, endpoint) + } + } +} diff --git a/service/http.go b/service/http.go new file mode 100644 index 00000000..16d09fd1 --- /dev/null +++ b/service/http.go @@ -0,0 +1,93 @@ +package service + +import ( + "errors" + "net/http" + "strings" + + "github.com/0xERR0R/blocky/util" + "github.com/go-chi/chi/v5" +) + +const ( + HTTPProtocol = "http" + HTTPSProtocol = "https" +) + +// HTTPService is a Service using a HTTP router. +type HTTPService interface { + Service + Merger + + // Router returns the service's router. + Router() chi.Router +} + +// HTTPInfo can be embedded in structs to help implement HTTPService. +type HTTPInfo struct { + Info + + Mux *chi.Mux +} + +func (i *HTTPInfo) Router() chi.Router { return i.Mux } + +// MergeHTTP merges two compatible HTTPServices. +// +// The second parameter is of type `Service` to make it easy to call +// from a `Merger.Merge` implementation. +func MergeHTTP(a HTTPService, b Service) (Merger, error) { + return newHTTPMerger(a).Merge(b) +} + +var _ HTTPService = (*httpMerger)(nil) + +// httpMerger can merge HTTPServices by combining their routes. +type httpMerger struct { + inner []HTTPService + router chi.Router + endpoints endpointSet +} + +func newHTTPMerger(first HTTPService) *httpMerger { + return &httpMerger{ + inner: []HTTPService{first}, + router: first.Router(), + } +} + +func (m *httpMerger) String() string { return svcString(m) } + +func (m *httpMerger) ServiceName() string { + names := util.ForEach(m.inner, func(svc HTTPService) string { + return svc.ServiceName() + }) + + return strings.Join(names, " & ") +} + +func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() } +func (m *httpMerger) Router() chi.Router { return m.router } + +func (m *httpMerger) Merge(other Service) (Merger, error) { + httpSvc, ok := other.(HTTPService) + if !ok { + return nil, errors.New("not an HTTPService") + } + + type middleware = func(http.Handler) http.Handler + + // Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined + _ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error { + m.router.With(middlewares...).Method(method, route, handler) + + return nil + }) + + m.inner = append(m.inner, httpSvc) + + // Don't expose any service more than it expects + m.endpoints.IntersectSlice(other.ExposeOn()) + + return m, nil +} diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 00000000..507f6b33 --- /dev/null +++ b/service/listener.go @@ -0,0 +1,74 @@ +package service + +import ( + "context" + "crypto/tls" + "fmt" + "net" +) + +// Listener is a net.Listener that provides information about +// what protocol and address it is configured for. +type Listener interface { + fmt.Stringer + net.Listener + + // Exposes returns the endpoint for this listener. + // + // It can be used to find service(s) with matching configuration. + Exposes() Endpoint +} + +// ListenerInfo can be embedded in structs to help implement Listener. +type ListenerInfo struct { + Endpoint +} + +func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint } + +// NetListener implements Listener using an existing net.Listener. +type NetListener struct { + net.Listener + ListenerInfo +} + +func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener { + return &NetListener{ + Listener: inner, + ListenerInfo: ListenerInfo{endpoint}, + } +} + +// TCPListener is a Listener for a TCP socket. +type TCPListener struct{ NetListener } + +// ListenTCP creates a new TCPListener. +func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) { + var lc net.ListenConfig + + l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf) + if err != nil { + return nil, err // err already has all the info we could add + } + + inner := NewNetListener(endpoint, l) + + return &TCPListener{*inner}, nil +} + +// TLSListener is a Listener using TLS over TCP. +type TLSListener struct{ NetListener } + +// ListenTLS creates a new TLSListener. +func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) { + tcp, err := ListenTCP(ctx, endpoint) + if err != nil { + return nil, err + } + + inner := tcp.NetListener + + inner.Listener = tls.NewListener(inner.Listener, cfg) + + return &TLSListener{inner}, nil +} diff --git a/service/merge.go b/service/merge.go new file mode 100644 index 00000000..53c9d943 --- /dev/null +++ b/service/merge.go @@ -0,0 +1,57 @@ +package service + +import "errors" + +// Merger is a Service that can be merged with another compatible one. +type Merger interface { + Service + + // Merge returns the result of merging the receiver with the other Service. + // + // Neither the receiver, nor the other Service should be used directly after + // calling this method. + Merge(other Service) (Merger, error) +} + +// MergeAll merges the given services, if they are compatible. +// +// This allows using multiple compatible services with a single listener. +// +// All passed-in services must not be re-used. +func MergeAll(services []Service) (Service, error) { + switch len(services) { + case 0: + return nil, errors.New("no services given") + + case 1: + return services[0], nil + } + + merger, err := firstMerger(services) + if err != nil { + return nil, err + } + + for _, svc := range services { + if svc == merger { + continue + } + + merger, err = merger.Merge(svc) + if err != nil { + return nil, err + } + } + + return merger, nil +} + +func firstMerger(services []Service) (Merger, error) { + for _, t := range services { + if svc, ok := t.(Merger); ok { + return svc, nil + } + } + + return nil, errors.New("no merger found") +} diff --git a/service/service.go b/service/service.go new file mode 100644 index 00000000..45e95dd4 --- /dev/null +++ b/service/service.go @@ -0,0 +1,106 @@ +// Package service exposes types to abstract services from the networking. +// +// The idea is that we build a set of services and a set of network endpoints (Listener). +// The services are then assigned to endpoints based on the address(es) they were configured for. +// +// Actual service to endpoint binding is not handled by the abstractions in this package as it is +// protocol specific. +// The general pattern is to make a "server" that wraps a service, and can then be started on an +// endpoint using a `Serve` method, similar to `http.Server`. +// +// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port), +// services can implement `Merger`. +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" +) + +// Service is a network exposed service. +// +// It contains only the logic and user configured addresses it should be exposed on. +// Is is meant to be associated to one or more sockets via those addresses. +// Actual association with a socket is protocol specific. +type Service interface { + fmt.Stringer + + // ServiceName returns the user friendly name of the service. + ServiceName() string + + // ExposeOn returns the set of endpoints the service should be exposed on. + // + // They can be used to find listener(s) with matching configuration. + ExposeOn() []Endpoint +} + +func svcString(s Service) string { + endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() }) + + return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", ")) +} + +// Info can be embedded in structs to help implement Service. +type Info struct { + Name string + Endpoints []Endpoint +} + +func (i *Info) ServiceName() string { return i.Name } +func (i *Info) ExposeOn() []Endpoint { return i.Endpoints } +func (i *Info) String() string { return svcString(i) } + +// GroupByListener returns a map of listener and services grouped by configured address. +// +// Each input listener is a key in the map. The corresponding value is a service +// merged from all services with a matching address. +func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) { + res := make(map[Listener]Service, len(listeners)) + unused := slices.Clone(services) + + for _, listener := range listeners { + services := findAllCompatible(services, listener.Exposes()) + if len(services) == 0 { + return nil, fmt.Errorf("found no compatible services for listener %s", listener) + } + + svc, err := MergeAll(services) + if err != nil { + return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err) + } + + res[listener] = svc + + for _, svc := range services { + if i := slices.Index(unused, svc); i != -1 { + unused = slices.Delete(unused, i, i+1) + } + } + } + + if len(unused) != 0 { + return nil, fmt.Errorf("found no compatible listener for services: %v", unused) + } + + return res, nil +} + +// findAllCompatible returns the subset of services that use the given Listener. +func findAllCompatible(services []Service, endpoint Endpoint) []Service { + res := make([]Service, 0, len(services)) + + for _, svc := range services { + if isExposedOn(svc, endpoint) { + res = append(res, svc) + } + } + + return res +} + +func isExposedOn(svc Service, endpoint Endpoint) bool { + return slices.Index(svc.ExposeOn(), endpoint) != -1 +} diff --git a/util/slices.go b/util/slices.go new file mode 100644 index 00000000..fe10a830 --- /dev/null +++ b/util/slices.go @@ -0,0 +1,33 @@ +package util + +// ForEach implements the functional map operation, under a different +// name to avoid confusion with Go's map type. +func ForEach[T, U any](slice []T, convert func(T) U) []U { + res := make([]U, 0, len(slice)) + + for _, t := range slice { + u := convert(t) + + res = append(res, u) + } + + return res +} + +// ConcatSlices returns a new slice with contents of all the inputs concatenated. +func ConcatSlices[T any](slices ...[]T) []T { + // Allocation is usually the bottleneck, so do it all at once + totalLen := 0 + + for _, slice := range slices { + totalLen += len(slice) + } + + res := make([]T, 0, totalLen) + + for _, slice := range slices { + res = append(res, slice...) + } + + return res +}