From c11f9a1c989631dcfd7d736ff68baedcd2a5dd35 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Wed, 3 Apr 2024 13:40:04 -0400 Subject: [PATCH] refactor: add `service` package to prepare for split HTTP handling Package service exposes types to abstract services from the networking. The idea is that we build a set of services and a set of network endpoints (Listener). The services are then assigned to endpoints based on the address(es) they were configured for. Actual service to endpoint binding is not handled by the abstractions in this package as it is protocol specific. The general pattern is to make a "server" that wraps a service, and can then be started on an endpoint using a `Serve` method, similar to `http.Server`. To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port), services can implement `Merger`. --- server/http.go | 61 +++++++++++---- server/server.go | 154 ++++++++++++++++++++++--------------- server/server_endpoints.go | 18 +++-- service/endpoint.go | 51 ++++++++++++ service/http.go | 93 ++++++++++++++++++++++ service/listener.go | 74 ++++++++++++++++++ service/merge.go | 57 ++++++++++++++ service/service.go | 106 +++++++++++++++++++++++++ util/slices.go | 33 ++++++++ 9 files changed, 566 insertions(+), 81 deletions(-) create mode 100644 service/endpoint.go create mode 100644 service/http.go create mode 100644 service/listener.go create mode 100644 service/merge.go create mode 100644 service/service.go create mode 100644 util/slices.go diff --git a/server/http.go b/server/http.go index cac0e810..7c4da323 100644 --- a/server/http.go +++ b/server/http.go @@ -6,17 +6,55 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/api" + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/service" + "github.com/0xERR0R/blocky/util" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" ) -type httpServer struct { - inner http.Server - - name string +// httpMiscService implements service.HTTPService. +// +// This supports the existing single HTTP/HTTPS endpoints +// that expose everything. The goal is to split it up +// and remove it. +type httpMiscService struct { + service.HTTPInfo } -func newHTTPServer(name string, handler http.Handler) *httpServer { +func newHTTPMiscService( + cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler, +) *httpMiscService { + endpoints := util.ConcatSlices( + service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP), + service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS), + ) + + return &httpMiscService{ + HTTPInfo: service.HTTPInfo{ + Info: service.Info{ + Name: "HTTP", + Endpoints: endpoints, + }, + + Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler), + }, + } +} + +func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) { + return service.MergeHTTP(s, other) +} + +// httpServer implements subServer for HTTP. +type httpServer struct { + service.HTTPService + + inner http.Server +} + +func newHTTPServer(svc service.HTTPService) *httpServer { const ( readHeaderTimeout = 20 * time.Second readTimeout = 20 * time.Second @@ -24,22 +62,17 @@ func newHTTPServer(name string, handler http.Handler) *httpServer { ) return &httpServer{ + HTTPService: svc, + inner: http.Server{ - ReadTimeout: readTimeout, + Handler: withCommonMiddleware(svc.Router()), ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: readTimeout, WriteTimeout: writeTimeout, - - Handler: withCommonMiddleware(handler), }, - - name: name, } } -func (s *httpServer) String() string { - return s.name -} - func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { go func() { <-ctx.Done() diff --git a/server/server.go b/server/server.go index 5c43bcac..cb27f749 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,7 @@ import ( "net/http" "runtime" "runtime/debug" + "slices" "strings" "time" @@ -27,6 +28,8 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" + "github.com/0xERR0R/blocky/service" + "golang.org/x/exp/maps" "github.com/0xERR0R/blocky/util" "github.com/google/uuid" @@ -49,7 +52,14 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config - servers map[net.Listener]*httpServer + services map[service.Listener]service.Service +} + +type subServer interface { + fmt.Stringer + service.Service + + Serve(context.Context, net.Listener) error } func logger() *logrus.Entry { @@ -108,8 +118,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config -// -//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -125,7 +133,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg) + listeners, err := createListeners(ctx, cfg, tlsCfg) if err != nil { return nil, err } @@ -154,41 +162,43 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err dnsServers: dnsServers, queryResolver: queryResolver, cfg: cfg, - - servers: make(map[net.Listener]*httpServer), } server.printConfiguration() server.registerDNSHandlers(ctx) - openAPIImpl, err := server.createOpenAPIInterfaceImpl() + services, err := server.createServices() if err != nil { return nil, err } - httpRouter := createHTTPRouter(cfg, openAPIImpl) - server.registerDoHEndpoints(httpRouter) - - if len(cfg.Ports.HTTP) != 0 { - srv := newHTTPServer("http", httpRouter) - - for _, l := range httpListeners { - server.servers[l] = srv - } - } - - if len(cfg.Ports.HTTPS) != 0 { - srv := newHTTPServer("https", httpRouter) - - for _, l := range httpsListeners { - server.servers[l] = srv - } + server.services, err = service.GroupByListener(services, listeners) + if err != nil { + return nil, err } return server, err } +func (s *Server) createServices() ([]service.Service, error) { + openAPIImpl, err := s.createOpenAPIInterfaceImpl() + if err != nil { + return nil, err + } + + res := []service.Service{ + newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq), + } + + // Remove services the user has not enabled + res = slices.DeleteFunc(res, func(svc service.Service) bool { + return len(svc.ExposeOn()) == 0 + }) + + return res, nil +} + func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) { var dnsServers []*dns.Server @@ -217,48 +227,51 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error return dnsServers, err.ErrorOrNil() } -func createHTTPListeners( - cfg *config.Config, tlsCfg *tls.Config, -) (httpListeners, httpsListeners []net.Listener, err error) { - httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP) - if err != nil { - return nil, nil, err +func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) { + res := make(map[string]service.Listener) + + listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) { + return service.ListenTLS(ctx, endpoint, tlsCfg) } - httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg) - if err != nil { - return nil, nil, err - } - - return httpListeners, httpsListeners, nil -} - -func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { - listeners := make([]net.Listener, 0, len(addresses)) - - for _, address := range addresses { - listener, err := net.Listen("tcp", address) - if err != nil { - return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) - } - - listeners = append(listeners, listener) - } - - return listeners, nil -} - -func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) { - listeners, err := newTCPListeners(proto, addresses) + err := errors.Join( + newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res), + newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res), + ) if err != nil { return nil, err } - for i, inner := range listeners { - listeners[i] = tls.NewListener(inner, tlsCfg) + return maps.Values(res), nil +} + +type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error) + +func newListeners[T service.Listener]( + ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener, +) error { + for _, addr := range addrs { + key := fmt.Sprintf("%s:%s", proto, addr) + if _, ok := out[key]; ok { + // Avoid "address already in use" + // We instead try to merge services, see services.GroupByListener + continue + } + + endpoint := service.Endpoint{ + Protocol: proto, + AddrConf: addr, + } + + l, err := listen(ctx, endpoint) + if err != nil { + return err // already has all info + } + + out[key] = l } - return listeners, nil + return nil } func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { @@ -490,6 +503,16 @@ func toMB(b uint64) uint64 { return b / bytesInKB / bytesInKB } +func newSubServer(svc service.Service) (subServer, error) { + switch svc := svc.(type) { + case service.HTTPService: + return newHTTPServer(svc), nil + + default: + return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc) + } +} + // Start starts the server func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") @@ -504,11 +527,18 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { }() } - for listener, srv := range s.servers { - listener, srv := listener, srv + for listener, svc := range s.services { + listener, svc := listener, svc + + srv, err := newSubServer(svc) + if err != nil { + errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err) + + return + } go func() { - logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr()) + logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes()) err := srv.Serve(ctx, listener) if err != nil { @@ -611,6 +641,8 @@ type msgWriter interface { WriteMsg(msg *dns.Msg) error } +type dnsHandler func(context.Context, *model.Request, msgWriter) + func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) { response, err := s.resolve(ctx, request) if err != nil { diff --git a/server/server_endpoints.go b/server/server_endpoints.go index af862a98..d61a44da 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -52,9 +52,11 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerDoHEndpoints(router *chi.Mux) { +func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) { const pathDohQuery = "/dns-query" + s := &dohServer{dnsHandler} + router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) @@ -63,7 +65,9 @@ func (s *Server) registerDoHEndpoints(router *chi.Mux) { router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) } -func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { +type dohServer struct{ handler dnsHandler } + +func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { dnsParam, ok := req.URL.Query()["dns"] if !ok || len(dnsParam[0]) < 1 { http.Error(rw, "dns param is missing", http.StatusBadRequest) @@ -87,7 +91,7 @@ func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) s.processDohMessage(rawMsg, rw, req) } -func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { +func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { contentType := req.Header.Get("Content-type") if contentType != dnsContentType { http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType) @@ -111,7 +115,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request s.processDohMessage(rawMsg, rw, req) } -func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { +func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { msg := new(dns.Msg) if err := msg.Unpack(rawMsg); err != nil { logger().Error("can't deserialize message: ", err) @@ -122,7 +126,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpRe ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) - s.handleReq(ctx, dnsReq, httpMsgWriter{rw}) + s.handler(ctx, dnsReq, httpMsgWriter{rw}) } type httpMsgWriter struct { @@ -156,7 +160,7 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux { router := chi.NewRouter() api.RegisterOpenAPIEndpoints(router, openAPIImpl) @@ -169,6 +173,8 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) configureRootHandler(cfg, router) + registerDoHEndpoints(router, dnsHandler) + metrics.Start(router, cfg.Prometheus) return router diff --git a/service/endpoint.go b/service/endpoint.go new file mode 100644 index 00000000..94b1b242 --- /dev/null +++ b/service/endpoint.go @@ -0,0 +1,51 @@ +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" + "golang.org/x/exp/maps" +) + +// Endpoint is a network endpoint on which to expose a service. +type Endpoint struct { + // Protocol is the protocol to be exposed on this endpoint. + Protocol string + + // AddrConf is the network address as configured by the user. + AddrConf string +} + +func EndpointsFromAddrs(proto string, addrs []string) []Endpoint { + return util.ForEach(addrs, func(addr string) Endpoint { + return Endpoint{ + Protocol: proto, + AddrConf: addr, + } + }) +} + +func (e Endpoint) String() string { + addr := e.AddrConf + if strings.HasPrefix(addr, ":") { + addr = "*" + addr + } + + return fmt.Sprintf("%s://%s", e.Protocol, addr) +} + +type endpointSet map[Endpoint]struct{} + +func (s endpointSet) ToSlice() []Endpoint { + return maps.Keys(s) +} + +func (s endpointSet) IntersectSlice(others []Endpoint) { + for endpoint := range s { + if !slices.Contains(others, endpoint) { + delete(s, endpoint) + } + } +} diff --git a/service/http.go b/service/http.go new file mode 100644 index 00000000..16d09fd1 --- /dev/null +++ b/service/http.go @@ -0,0 +1,93 @@ +package service + +import ( + "errors" + "net/http" + "strings" + + "github.com/0xERR0R/blocky/util" + "github.com/go-chi/chi/v5" +) + +const ( + HTTPProtocol = "http" + HTTPSProtocol = "https" +) + +// HTTPService is a Service using a HTTP router. +type HTTPService interface { + Service + Merger + + // Router returns the service's router. + Router() chi.Router +} + +// HTTPInfo can be embedded in structs to help implement HTTPService. +type HTTPInfo struct { + Info + + Mux *chi.Mux +} + +func (i *HTTPInfo) Router() chi.Router { return i.Mux } + +// MergeHTTP merges two compatible HTTPServices. +// +// The second parameter is of type `Service` to make it easy to call +// from a `Merger.Merge` implementation. +func MergeHTTP(a HTTPService, b Service) (Merger, error) { + return newHTTPMerger(a).Merge(b) +} + +var _ HTTPService = (*httpMerger)(nil) + +// httpMerger can merge HTTPServices by combining their routes. +type httpMerger struct { + inner []HTTPService + router chi.Router + endpoints endpointSet +} + +func newHTTPMerger(first HTTPService) *httpMerger { + return &httpMerger{ + inner: []HTTPService{first}, + router: first.Router(), + } +} + +func (m *httpMerger) String() string { return svcString(m) } + +func (m *httpMerger) ServiceName() string { + names := util.ForEach(m.inner, func(svc HTTPService) string { + return svc.ServiceName() + }) + + return strings.Join(names, " & ") +} + +func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() } +func (m *httpMerger) Router() chi.Router { return m.router } + +func (m *httpMerger) Merge(other Service) (Merger, error) { + httpSvc, ok := other.(HTTPService) + if !ok { + return nil, errors.New("not an HTTPService") + } + + type middleware = func(http.Handler) http.Handler + + // Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined + _ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error { + m.router.With(middlewares...).Method(method, route, handler) + + return nil + }) + + m.inner = append(m.inner, httpSvc) + + // Don't expose any service more than it expects + m.endpoints.IntersectSlice(other.ExposeOn()) + + return m, nil +} diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 00000000..507f6b33 --- /dev/null +++ b/service/listener.go @@ -0,0 +1,74 @@ +package service + +import ( + "context" + "crypto/tls" + "fmt" + "net" +) + +// Listener is a net.Listener that provides information about +// what protocol and address it is configured for. +type Listener interface { + fmt.Stringer + net.Listener + + // Exposes returns the endpoint for this listener. + // + // It can be used to find service(s) with matching configuration. + Exposes() Endpoint +} + +// ListenerInfo can be embedded in structs to help implement Listener. +type ListenerInfo struct { + Endpoint +} + +func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint } + +// NetListener implements Listener using an existing net.Listener. +type NetListener struct { + net.Listener + ListenerInfo +} + +func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener { + return &NetListener{ + Listener: inner, + ListenerInfo: ListenerInfo{endpoint}, + } +} + +// TCPListener is a Listener for a TCP socket. +type TCPListener struct{ NetListener } + +// ListenTCP creates a new TCPListener. +func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) { + var lc net.ListenConfig + + l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf) + if err != nil { + return nil, err // err already has all the info we could add + } + + inner := NewNetListener(endpoint, l) + + return &TCPListener{*inner}, nil +} + +// TLSListener is a Listener using TLS over TCP. +type TLSListener struct{ NetListener } + +// ListenTLS creates a new TLSListener. +func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) { + tcp, err := ListenTCP(ctx, endpoint) + if err != nil { + return nil, err + } + + inner := tcp.NetListener + + inner.Listener = tls.NewListener(inner.Listener, cfg) + + return &TLSListener{inner}, nil +} diff --git a/service/merge.go b/service/merge.go new file mode 100644 index 00000000..53c9d943 --- /dev/null +++ b/service/merge.go @@ -0,0 +1,57 @@ +package service + +import "errors" + +// Merger is a Service that can be merged with another compatible one. +type Merger interface { + Service + + // Merge returns the result of merging the receiver with the other Service. + // + // Neither the receiver, nor the other Service should be used directly after + // calling this method. + Merge(other Service) (Merger, error) +} + +// MergeAll merges the given services, if they are compatible. +// +// This allows using multiple compatible services with a single listener. +// +// All passed-in services must not be re-used. +func MergeAll(services []Service) (Service, error) { + switch len(services) { + case 0: + return nil, errors.New("no services given") + + case 1: + return services[0], nil + } + + merger, err := firstMerger(services) + if err != nil { + return nil, err + } + + for _, svc := range services { + if svc == merger { + continue + } + + merger, err = merger.Merge(svc) + if err != nil { + return nil, err + } + } + + return merger, nil +} + +func firstMerger(services []Service) (Merger, error) { + for _, t := range services { + if svc, ok := t.(Merger); ok { + return svc, nil + } + } + + return nil, errors.New("no merger found") +} diff --git a/service/service.go b/service/service.go new file mode 100644 index 00000000..45e95dd4 --- /dev/null +++ b/service/service.go @@ -0,0 +1,106 @@ +// Package service exposes types to abstract services from the networking. +// +// The idea is that we build a set of services and a set of network endpoints (Listener). +// The services are then assigned to endpoints based on the address(es) they were configured for. +// +// Actual service to endpoint binding is not handled by the abstractions in this package as it is +// protocol specific. +// The general pattern is to make a "server" that wraps a service, and can then be started on an +// endpoint using a `Serve` method, similar to `http.Server`. +// +// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port), +// services can implement `Merger`. +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" +) + +// Service is a network exposed service. +// +// It contains only the logic and user configured addresses it should be exposed on. +// Is is meant to be associated to one or more sockets via those addresses. +// Actual association with a socket is protocol specific. +type Service interface { + fmt.Stringer + + // ServiceName returns the user friendly name of the service. + ServiceName() string + + // ExposeOn returns the set of endpoints the service should be exposed on. + // + // They can be used to find listener(s) with matching configuration. + ExposeOn() []Endpoint +} + +func svcString(s Service) string { + endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() }) + + return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", ")) +} + +// Info can be embedded in structs to help implement Service. +type Info struct { + Name string + Endpoints []Endpoint +} + +func (i *Info) ServiceName() string { return i.Name } +func (i *Info) ExposeOn() []Endpoint { return i.Endpoints } +func (i *Info) String() string { return svcString(i) } + +// GroupByListener returns a map of listener and services grouped by configured address. +// +// Each input listener is a key in the map. The corresponding value is a service +// merged from all services with a matching address. +func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) { + res := make(map[Listener]Service, len(listeners)) + unused := slices.Clone(services) + + for _, listener := range listeners { + services := findAllCompatible(services, listener.Exposes()) + if len(services) == 0 { + return nil, fmt.Errorf("found no compatible services for listener %s", listener) + } + + svc, err := MergeAll(services) + if err != nil { + return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err) + } + + res[listener] = svc + + for _, svc := range services { + if i := slices.Index(unused, svc); i != -1 { + unused = slices.Delete(unused, i, i+1) + } + } + } + + if len(unused) != 0 { + return nil, fmt.Errorf("found no compatible listener for services: %v", unused) + } + + return res, nil +} + +// findAllCompatible returns the subset of services that use the given Listener. +func findAllCompatible(services []Service, endpoint Endpoint) []Service { + res := make([]Service, 0, len(services)) + + for _, svc := range services { + if isExposedOn(svc, endpoint) { + res = append(res, svc) + } + } + + return res +} + +func isExposedOn(svc Service, endpoint Endpoint) bool { + return slices.Index(svc.ExposeOn(), endpoint) != -1 +} diff --git a/util/slices.go b/util/slices.go new file mode 100644 index 00000000..fe10a830 --- /dev/null +++ b/util/slices.go @@ -0,0 +1,33 @@ +package util + +// ForEach implements the functional map operation, under a different +// name to avoid confusion with Go's map type. +func ForEach[T, U any](slice []T, convert func(T) U) []U { + res := make([]U, 0, len(slice)) + + for _, t := range slice { + u := convert(t) + + res = append(res, u) + } + + return res +} + +// ConcatSlices returns a new slice with contents of all the inputs concatenated. +func ConcatSlices[T any](slices ...[]T) []T { + // Allocation is usually the bottleneck, so do it all at once + totalLen := 0 + + for _, slice := range slices { + totalLen += len(slice) + } + + res := make([]T, 0, totalLen) + + for _, slice := range slices { + res = append(res, slice...) + } + + return res +}