refactor: add `service` package to prepare for split HTTP handling

Package service exposes types to abstract services from the networking.

The idea is that we build a set of services and a set of network
endpoints (Listener). The services are then assigned to endpoints based
on the address(es) they were configured for.

Actual service to endpoint binding is not handled by the abstractions in
this package as it is protocol specific.
The general pattern is to make a "server" that wraps a service, and can
then be started on an endpoint using a `Serve` method,
similar to `http.Server`.

To support exposing multiple compatible services on a single endpoint
(example: DoH + metrics on a single port),
services can implement `Merger`.
This commit is contained in:
ThinkChaos 2024-04-03 13:40:04 -04:00
parent 4b37b404bf
commit c11f9a1c98
No known key found for this signature in database
9 changed files with 566 additions and 81 deletions

View File

@ -6,17 +6,55 @@ import (
"net/http"
"time"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
)
type httpServer struct {
inner http.Server
name string
// httpMiscService implements service.HTTPService.
//
// This supports the existing single HTTP/HTTPS endpoints
// that expose everything. The goal is to split it up
// and remove it.
type httpMiscService struct {
service.HTTPInfo
}
func newHTTPServer(name string, handler http.Handler) *httpServer {
func newHTTPMiscService(
cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler,
) *httpMiscService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
)
return &httpMiscService{
HTTPInfo: service.HTTPInfo{
Info: service.Info{
Name: "HTTP",
Endpoints: endpoints,
},
Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler),
},
}
}
func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) {
return service.MergeHTTP(s, other)
}
// httpServer implements subServer for HTTP.
type httpServer struct {
service.HTTPService
inner http.Server
}
func newHTTPServer(svc service.HTTPService) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
@ -24,22 +62,17 @@ func newHTTPServer(name string, handler http.Handler) *httpServer {
)
return &httpServer{
HTTPService: svc,
inner: http.Server{
ReadTimeout: readTimeout,
Handler: withCommonMiddleware(svc.Router()),
ReadHeaderTimeout: readHeaderTimeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
Handler: withCommonMiddleware(handler),
},
name: name,
}
}
func (s *httpServer) String() string {
return s.name
}
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()

View File

@ -18,6 +18,7 @@ import (
"net/http"
"runtime"
"runtime/debug"
"slices"
"strings"
"time"
@ -27,6 +28,8 @@ import (
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/service"
"golang.org/x/exp/maps"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
@ -49,7 +52,14 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
servers map[net.Listener]*httpServer
services map[service.Listener]service.Service
}
type subServer interface {
fmt.Stringer
service.Service
Serve(context.Context, net.Listener) error
}
func logger() *logrus.Entry {
@ -108,8 +118,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config
@ -125,7 +133,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}
httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
listeners, err := createListeners(ctx, cfg, tlsCfg)
if err != nil {
return nil, err
}
@ -154,41 +162,43 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
servers: make(map[net.Listener]*httpServer),
}
server.printConfiguration()
server.registerDNSHandlers(ctx)
openAPIImpl, err := server.createOpenAPIInterfaceImpl()
services, err := server.createServices()
if err != nil {
return nil, err
}
httpRouter := createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(httpRouter)
if len(cfg.Ports.HTTP) != 0 {
srv := newHTTPServer("http", httpRouter)
for _, l := range httpListeners {
server.servers[l] = srv
}
}
if len(cfg.Ports.HTTPS) != 0 {
srv := newHTTPServer("https", httpRouter)
for _, l := range httpsListeners {
server.servers[l] = srv
}
server.services, err = service.GroupByListener(services, listeners)
if err != nil {
return nil, err
}
return server, err
}
func (s *Server) createServices() ([]service.Service, error) {
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}
res := []service.Service{
newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq),
}
// Remove services the user has not enabled
res = slices.DeleteFunc(res, func(svc service.Service) bool {
return len(svc.ExposeOn()) == 0
})
return res, nil
}
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
var dnsServers []*dns.Server
@ -217,48 +227,51 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
return dnsServers, err.ErrorOrNil()
}
func createHTTPListeners(
cfg *config.Config, tlsCfg *tls.Config,
) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
res := make(map[string]service.Listener)
listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
return service.ListenTLS(ctx, endpoint, tlsCfg)
}
httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
if err != nil {
return nil, nil, err
}
return httpListeners, httpsListeners, nil
}
func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))
for _, address := range addresses {
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
}
listeners = append(listeners, listener)
}
return listeners, nil
}
func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
listeners, err := newTCPListeners(proto, addresses)
err := errors.Join(
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
)
if err != nil {
return nil, err
}
for i, inner := range listeners {
listeners[i] = tls.NewListener(inner, tlsCfg)
return maps.Values(res), nil
}
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)
func newListeners[T service.Listener](
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
) error {
for _, addr := range addrs {
key := fmt.Sprintf("%s:%s", proto, addr)
if _, ok := out[key]; ok {
// Avoid "address already in use"
// We instead try to merge services, see services.GroupByListener
continue
}
endpoint := service.Endpoint{
Protocol: proto,
AddrConf: addr,
}
l, err := listen(ctx, endpoint)
if err != nil {
return err // already has all info
}
out[key] = l
}
return listeners, nil
return nil
}
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
@ -490,6 +503,16 @@ func toMB(b uint64) uint64 {
return b / bytesInKB / bytesInKB
}
func newSubServer(svc service.Service) (subServer, error) {
switch svc := svc.(type) {
case service.HTTPService:
return newHTTPServer(svc), nil
default:
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
}
}
// Start starts the server
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Info("Starting server")
@ -504,11 +527,18 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}()
}
for listener, srv := range s.servers {
listener, srv := listener, srv
for listener, svc := range s.services {
listener, svc := listener, svc
srv, err := newSubServer(svc)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)
return
}
go func() {
logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr())
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())
err := srv.Serve(ctx, listener)
if err != nil {
@ -611,6 +641,8 @@ type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}
type dnsHandler func(context.Context, *model.Request, msgWriter)
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {

View File

@ -52,9 +52,11 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}
func (s *Server) registerDoHEndpoints(router *chi.Mux) {
func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) {
const pathDohQuery = "/dns-query"
s := &dohServer{dnsHandler}
router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
@ -63,7 +65,9 @@ func (s *Server) registerDoHEndpoints(router *chi.Mux) {
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
}
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
type dohServer struct{ handler dnsHandler }
func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)
@ -87,7 +91,7 @@ func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request)
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
@ -111,7 +115,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
@ -122,7 +126,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpRe
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
s.handler(ctx, dnsReq, httpMsgWriter{rw})
}
type httpMsgWriter struct {
@ -156,7 +160,7 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux {
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux {
router := chi.NewRouter()
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
@ -169,6 +173,8 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface)
configureRootHandler(cfg, router)
registerDoHEndpoints(router, dnsHandler)
metrics.Start(router, cfg.Prometheus)
return router

51
service/endpoint.go Normal file
View File

@ -0,0 +1,51 @@
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
"golang.org/x/exp/maps"
)
// Endpoint is a network endpoint on which to expose a service.
type Endpoint struct {
// Protocol is the protocol to be exposed on this endpoint.
Protocol string
// AddrConf is the network address as configured by the user.
AddrConf string
}
func EndpointsFromAddrs(proto string, addrs []string) []Endpoint {
return util.ForEach(addrs, func(addr string) Endpoint {
return Endpoint{
Protocol: proto,
AddrConf: addr,
}
})
}
func (e Endpoint) String() string {
addr := e.AddrConf
if strings.HasPrefix(addr, ":") {
addr = "*" + addr
}
return fmt.Sprintf("%s://%s", e.Protocol, addr)
}
type endpointSet map[Endpoint]struct{}
func (s endpointSet) ToSlice() []Endpoint {
return maps.Keys(s)
}
func (s endpointSet) IntersectSlice(others []Endpoint) {
for endpoint := range s {
if !slices.Contains(others, endpoint) {
delete(s, endpoint)
}
}
}

93
service/http.go Normal file
View File

@ -0,0 +1,93 @@
package service
import (
"errors"
"net/http"
"strings"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
)
const (
HTTPProtocol = "http"
HTTPSProtocol = "https"
)
// HTTPService is a Service using a HTTP router.
type HTTPService interface {
Service
Merger
// Router returns the service's router.
Router() chi.Router
}
// HTTPInfo can be embedded in structs to help implement HTTPService.
type HTTPInfo struct {
Info
Mux *chi.Mux
}
func (i *HTTPInfo) Router() chi.Router { return i.Mux }
// MergeHTTP merges two compatible HTTPServices.
//
// The second parameter is of type `Service` to make it easy to call
// from a `Merger.Merge` implementation.
func MergeHTTP(a HTTPService, b Service) (Merger, error) {
return newHTTPMerger(a).Merge(b)
}
var _ HTTPService = (*httpMerger)(nil)
// httpMerger can merge HTTPServices by combining their routes.
type httpMerger struct {
inner []HTTPService
router chi.Router
endpoints endpointSet
}
func newHTTPMerger(first HTTPService) *httpMerger {
return &httpMerger{
inner: []HTTPService{first},
router: first.Router(),
}
}
func (m *httpMerger) String() string { return svcString(m) }
func (m *httpMerger) ServiceName() string {
names := util.ForEach(m.inner, func(svc HTTPService) string {
return svc.ServiceName()
})
return strings.Join(names, " & ")
}
func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() }
func (m *httpMerger) Router() chi.Router { return m.router }
func (m *httpMerger) Merge(other Service) (Merger, error) {
httpSvc, ok := other.(HTTPService)
if !ok {
return nil, errors.New("not an HTTPService")
}
type middleware = func(http.Handler) http.Handler
// Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
m.router.With(middlewares...).Method(method, route, handler)
return nil
})
m.inner = append(m.inner, httpSvc)
// Don't expose any service more than it expects
m.endpoints.IntersectSlice(other.ExposeOn())
return m, nil
}

74
service/listener.go Normal file
View File

@ -0,0 +1,74 @@
package service
import (
"context"
"crypto/tls"
"fmt"
"net"
)
// Listener is a net.Listener that provides information about
// what protocol and address it is configured for.
type Listener interface {
fmt.Stringer
net.Listener
// Exposes returns the endpoint for this listener.
//
// It can be used to find service(s) with matching configuration.
Exposes() Endpoint
}
// ListenerInfo can be embedded in structs to help implement Listener.
type ListenerInfo struct {
Endpoint
}
func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint }
// NetListener implements Listener using an existing net.Listener.
type NetListener struct {
net.Listener
ListenerInfo
}
func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener {
return &NetListener{
Listener: inner,
ListenerInfo: ListenerInfo{endpoint},
}
}
// TCPListener is a Listener for a TCP socket.
type TCPListener struct{ NetListener }
// ListenTCP creates a new TCPListener.
func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf)
if err != nil {
return nil, err // err already has all the info we could add
}
inner := NewNetListener(endpoint, l)
return &TCPListener{*inner}, nil
}
// TLSListener is a Listener using TLS over TCP.
type TLSListener struct{ NetListener }
// ListenTLS creates a new TLSListener.
func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) {
tcp, err := ListenTCP(ctx, endpoint)
if err != nil {
return nil, err
}
inner := tcp.NetListener
inner.Listener = tls.NewListener(inner.Listener, cfg)
return &TLSListener{inner}, nil
}

57
service/merge.go Normal file
View File

@ -0,0 +1,57 @@
package service
import "errors"
// Merger is a Service that can be merged with another compatible one.
type Merger interface {
Service
// Merge returns the result of merging the receiver with the other Service.
//
// Neither the receiver, nor the other Service should be used directly after
// calling this method.
Merge(other Service) (Merger, error)
}
// MergeAll merges the given services, if they are compatible.
//
// This allows using multiple compatible services with a single listener.
//
// All passed-in services must not be re-used.
func MergeAll(services []Service) (Service, error) {
switch len(services) {
case 0:
return nil, errors.New("no services given")
case 1:
return services[0], nil
}
merger, err := firstMerger(services)
if err != nil {
return nil, err
}
for _, svc := range services {
if svc == merger {
continue
}
merger, err = merger.Merge(svc)
if err != nil {
return nil, err
}
}
return merger, nil
}
func firstMerger(services []Service) (Merger, error) {
for _, t := range services {
if svc, ok := t.(Merger); ok {
return svc, nil
}
}
return nil, errors.New("no merger found")
}

106
service/service.go Normal file
View File

@ -0,0 +1,106 @@
// Package service exposes types to abstract services from the networking.
//
// The idea is that we build a set of services and a set of network endpoints (Listener).
// The services are then assigned to endpoints based on the address(es) they were configured for.
//
// Actual service to endpoint binding is not handled by the abstractions in this package as it is
// protocol specific.
// The general pattern is to make a "server" that wraps a service, and can then be started on an
// endpoint using a `Serve` method, similar to `http.Server`.
//
// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port),
// services can implement `Merger`.
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
)
// Service is a network exposed service.
//
// It contains only the logic and user configured addresses it should be exposed on.
// Is is meant to be associated to one or more sockets via those addresses.
// Actual association with a socket is protocol specific.
type Service interface {
fmt.Stringer
// ServiceName returns the user friendly name of the service.
ServiceName() string
// ExposeOn returns the set of endpoints the service should be exposed on.
//
// They can be used to find listener(s) with matching configuration.
ExposeOn() []Endpoint
}
func svcString(s Service) string {
endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() })
return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", "))
}
// Info can be embedded in structs to help implement Service.
type Info struct {
Name string
Endpoints []Endpoint
}
func (i *Info) ServiceName() string { return i.Name }
func (i *Info) ExposeOn() []Endpoint { return i.Endpoints }
func (i *Info) String() string { return svcString(i) }
// GroupByListener returns a map of listener and services grouped by configured address.
//
// Each input listener is a key in the map. The corresponding value is a service
// merged from all services with a matching address.
func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) {
res := make(map[Listener]Service, len(listeners))
unused := slices.Clone(services)
for _, listener := range listeners {
services := findAllCompatible(services, listener.Exposes())
if len(services) == 0 {
return nil, fmt.Errorf("found no compatible services for listener %s", listener)
}
svc, err := MergeAll(services)
if err != nil {
return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err)
}
res[listener] = svc
for _, svc := range services {
if i := slices.Index(unused, svc); i != -1 {
unused = slices.Delete(unused, i, i+1)
}
}
}
if len(unused) != 0 {
return nil, fmt.Errorf("found no compatible listener for services: %v", unused)
}
return res, nil
}
// findAllCompatible returns the subset of services that use the given Listener.
func findAllCompatible(services []Service, endpoint Endpoint) []Service {
res := make([]Service, 0, len(services))
for _, svc := range services {
if isExposedOn(svc, endpoint) {
res = append(res, svc)
}
}
return res
}
func isExposedOn(svc Service, endpoint Endpoint) bool {
return slices.Index(svc.ExposeOn(), endpoint) != -1
}

33
util/slices.go Normal file
View File

@ -0,0 +1,33 @@
package util
// ForEach implements the functional map operation, under a different
// name to avoid confusion with Go's map type.
func ForEach[T, U any](slice []T, convert func(T) U) []U {
res := make([]U, 0, len(slice))
for _, t := range slice {
u := convert(t)
res = append(res, u)
}
return res
}
// ConcatSlices returns a new slice with contents of all the inputs concatenated.
func ConcatSlices[T any](slices ...[]T) []T {
// Allocation is usually the bottleneck, so do it all at once
totalLen := 0
for _, slice := range slices {
totalLen += len(slice)
}
res := make([]T, 0, totalLen)
for _, slice := range slices {
res = append(res, slice...)
}
return res
}