This commit is contained in:
ThinkChaos 2024-04-26 15:47:39 -04:00 committed by GitHub
commit 3cd36e3ad4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1005 additions and 308 deletions

View File

@ -53,7 +53,7 @@ type CacheControl interface {
FlushCaches(ctx context.Context)
}
func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
func registerOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")

View File

@ -105,7 +105,7 @@ var _ = Describe("API implementation tests", func() {
Describe("RegisterOpenAPIEndpoints", func() {
It("adds routes", func() {
rtr := chi.NewRouter()
RegisterOpenAPIEndpoints(rtr, sut)
registerOpenAPIEndpoints(rtr, sut)
Expect(rtr.Routes()).ShouldNot(BeEmpty())
})

27
api/service.go Normal file
View File

@ -0,0 +1,27 @@
package api
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
)
// Service implements service.HTTPService.
type Service struct {
service.SimpleHTTP
}
func NewService(cfg config.APIService, server StrictServerInterface) *Service {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
s := &Service{
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
}
registerOpenAPIEndpoints(s.Router(), server)
return s
}

View File

@ -10,6 +10,7 @@ import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/evt"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/server"
"github.com/0xERR0R/blocky/util"
@ -47,6 +48,10 @@ func startServer(_ *cobra.Command, _ []string) error {
ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()
if cfg.Prometheus.Enable {
metrics.StartCollection()
}
srv, err := server.NewServer(ctx, cfg)
if err != nil {
return fmt.Errorf("can't start server: %w", err)

View File

@ -170,6 +170,13 @@ func (l *ListenConfig) UnmarshalText(data []byte) error {
*l = strings.Split(addresses, ",")
// Prefix all ports with :
for i, addr := range *l {
if !strings.ContainsRune(addr, ':') {
(*l)[i] = ":" + addr
}
}
return nil
}
@ -226,6 +233,7 @@ type Config struct {
Redis Redis `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports Ports `yaml:"ports"`
Services Services `yaml:"-"` // not user exposed yet
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
@ -255,6 +263,19 @@ type Config struct {
} `yaml:",inline"`
}
// Services holds network service related configuration.
//
// The actual config layout is not decided yet.
// See https://github.com/0xERR0R/blocky/issues/1206
//
// The `yaml` struct tags are just for manual testing,
// and require replacing `yaml:"-"` in Config to work.
type Services struct {
API APIService `yaml:"control-api"`
DoH DoHService `yaml:"dns-over-https"`
Metrics MetricsService `yaml:"metrics"`
}
type Ports struct {
DNS ListenConfig `yaml:"dns" default:"53"`
HTTP ListenConfig `yaml:"http"`
@ -594,6 +615,23 @@ func (cfg *Config) validate(logger *logrus.Entry) {
cfg.Upstreams.validate(logger)
}
// CopyPortsToServices sets Services values to match Ports.
//
// This should be replaced with a migration once everything from Ports is supported in Services.
// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once.
func (cfg *Config) CopyPortsToServices() {
httpAddrs := httpAddrs{
HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP},
HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
}
cfg.Services = Services{
API: APIService{Addrs: httpAddrs},
DoH: DoHService{Addrs: httpAddrs},
Metrics: MetricsService{Addrs: httpAddrs},
}
}
// ConvertPort converts string representation into a valid port (0 - 65535)
func ConvertPort(in string) (uint16, error) {
const (

View File

@ -462,7 +462,7 @@ bootstrapDns:
err := l.UnmarshalText([]byte("55,:56"))
Expect(err).Should(Succeed())
Expect(*l).Should(HaveLen(2))
Expect(*l).Should(ContainElements("55", ":56"))
Expect(*l).Should(ContainElements(":55", ":56"))
})
})
})
@ -958,7 +958,7 @@ bootstrapDns:
})
func defaultTestFileConfig(config *Config) {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Ports.DNS).Should(Equal(ListenConfig{":55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.Init.Strategy).Should(Equal(InitStrategyFailOnError))
Expect(config.Upstreams.UserAgent).Should(Equal("testBlocky"))
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))

25
config/doh_service.go Normal file
View File

@ -0,0 +1,25 @@
package config
type (
APIService httpService
DoHService httpService
MetricsService httpService
)
// httpService can be used by any service that uses HTTP(S).
type httpService struct {
Addrs httpAddrs `yaml:"addrs"`
}
type httpAddrs struct {
HTTPAddrs `yaml:",inline"`
HTTPSAddrs `yaml:",inline"`
}
type HTTPAddrs struct {
HTTP ListenConfig `yaml:"http"`
}
type HTTPSAddrs struct {
HTTPS ListenConfig `yaml:"https"`
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
@ -29,20 +30,27 @@ const (
DS = dns.Type(dns.TypeDS)
)
// GetIntPort returns an port for the current testing
// GetIntPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as int
// the base port and returning it as int.
func GetIntPort(port int) int {
return port + ginkgo.GinkgoParallelProcess()
}
// GetStringPort returns an port for the current testing
// GetStringPort returns a port for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string
// the base port and returning it as string.
func GetStringPort(port int) string {
return fmt.Sprintf("%d", GetIntPort(port))
}
// GetHostPort returns a host:port string for the current testing
// process by adding the current ginkgo parallel process to
// the base port and returning it as string.
func GetHostPort(host string, port int) string {
return net.JoinHostPort(host, GetStringPort(port))
}
// TempFile creates temp file with passed data
func TempFile(data string) *os.File {
f, err := os.CreateTemp("", "prefix")

View File

@ -1,12 +1,8 @@
package metrics
import (
"github.com/0xERR0R/blocky/config"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
//nolint:gochecknoglobals
@ -17,12 +13,9 @@ func RegisterMetric(c prometheus.Collector) {
_ = reg.Register(c)
}
// Start starts prometheus endpoint
func Start(router *chi.Mux, cfg config.Metrics) {
if cfg.Enable {
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
_ = reg.Register(collectors.NewGoCollector())
router.Handle(cfg.Path, promhttp.InstrumentMetricHandler(reg,
promhttp.HandlerFor(reg, promhttp.HandlerOpts{})))
}
func StartCollection() {
_ = reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
_ = reg.Register(collectors.NewGoCollector())
registerEventListeners()
}

View File

@ -11,8 +11,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
// RegisterEventListeners registers all metric handlers by the event bus
func RegisterEventListeners() {
// registerEventListeners registers all metric handlers on the event bus
func registerEventListeners() {
registerBlockingEventListeners()
registerCachingEventListeners()
registerApplicationEventListeners()

36
metrics/service.go Normal file
View File

@ -0,0 +1,36 @@
package metrics
import (
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Service implements service.HTTPService.
type Service struct {
service.SimpleHTTP
}
func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
if !metricsCfg.Enable || len(endpoints) == 0 {
// Avoid setting up collectors and listeners
return new(Service)
}
s := &Service{
SimpleHTTP: service.NewSimpleHTTP("Metrics", endpoints),
}
s.Router().Handle(
metricsCfg.Path,
promhttp.InstrumentMetricHandler(reg, promhttp.HandlerFor(reg, promhttp.HandlerOpts{})),
)
return s
}

126
server/doh.go Normal file
View File

@ -0,0 +1,126 @@
package server
import (
"encoding/base64"
"io"
"net/http"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/miekg/dns"
)
type dohService struct {
service.SimpleHTTP
handler dnsHandler
}
func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
)
s := &dohService{
SimpleHTTP: service.NewSimpleHTTP("DoH", endpoints),
handler: handler,
}
s.Router().Route("/dns-query", func(mux chi.Router) {
// Handlers for / also handle /dns-query without trailing slash
mux.Get("/", s.handleGET)
mux.Get("/{clientID}", s.handleGET)
mux.Post("/", s.handlePOST)
mux.Post("/{clientID}", s.handlePOST)
})
return s
}
func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)
return
}
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
return
}
rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
s.handler(ctx, dnsReq, httpMsgWriter{rw})
}
type httpMsgWriter struct {
rw http.ResponseWriter
}
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}
r.rw.Header().Set("content-type", dnsContentType)
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)
_, err = r.rw.Write(b)
return err
}

119
server/http.go Normal file
View File

@ -0,0 +1,119 @@
package server
import (
"context"
"net"
"net/http"
"time"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
)
// httpMiscService implements service.HTTPService.
//
// This supports the existing single HTTP/HTTPS endpoints
// that expose everything. The goal is to split it up
// and remove it.
type httpMiscService struct {
service.SimpleHTTP
}
func newHTTPMiscService(cfg *config.Config) *httpMiscService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
)
s := &httpMiscService{
SimpleHTTP: service.NewSimpleHTTP("API", endpoints),
}
configureHTTPRouter(s.Router(), cfg)
return s
}
// httpServer implements subServer for HTTP.
type httpServer struct {
service.HTTPService
inner http.Server
}
func newHTTPServer(svc service.HTTPService) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
return &httpServer{
HTTPService: svc,
inner: http.Server{
Handler: withCommonMiddleware(svc.Router()),
ReadHeaderTimeout: readHeaderTimeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
},
}
}
func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()
s.inner.Close()
}()
return s.inner.Serve(l)
}
func withCommonMiddleware(inner http.Handler) *chi.Mux {
// Middleware must be defined before routes, so
// create a new router and mount the inner handler
mux := chi.NewMux()
mux.Use(
secureHeadersMiddleware,
newCORSMiddleware(),
)
mux.Mount("/", inner)
return mux
}
type httpMiddleware = func(http.Handler) http.Handler
func secureHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
}
next.ServeHTTP(w, r)
})
}
func newCORSMiddleware() httpMiddleware {
const corsMaxAge = 5 * time.Minute
options := cors.Options{
AllowCredentials: true,
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
AllowedMethods: []string{"GET", "POST"},
AllowedOrigins: []string{"*"},
ExposedHeaders: []string{"Link"},
MaxAge: int(corsMaxAge.Seconds()),
}
return cors.New(options).Handler
}

View File

@ -18,15 +18,20 @@ import (
"net/http"
"runtime"
"runtime/debug"
"slices"
"strings"
"time"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/service"
"golang.org/x/exp/maps"
"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror"
@ -44,14 +49,18 @@ const (
// Server controls the endpoints for DNS and HTTP
type Server struct {
dnsServers []*dns.Server
httpListeners []net.Listener
httpsListeners []net.Listener
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
httpsMux *chi.Mux
cert tls.Certificate
dnsServers []*dns.Server
queryResolver resolver.ChainedResolver
cfg *config.Config
services map[service.Listener]service.Service
}
type subServer interface {
fmt.Stringer
service.Service
Serve(context.Context, net.Listener) error
}
func logger() *logrus.Entry {
@ -71,14 +80,6 @@ func tlsCipherSuites() []uint16 {
return tlsCipherSuites
}
func getServerAddress(addr string) string {
if !strings.Contains(addr, ":") {
addr = fmt.Sprintf(":%s", addr)
}
return addr
}
type NewServerFunc func(address string) (*dns.Server, error)
func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
@ -99,39 +100,47 @@ func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
return
}
// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
var cert tls.Certificate
cert, err := retrieveCertificate(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
}
// #nosec G402 // See TLSVersion.validate
res := &tls.Config{
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{cert},
}
return res, nil
}
// NewServer creates new server instance with passed config
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
cfg.CopyPortsToServices()
var tlsCfg *tls.Config
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
cert, err = retrieveCertificate(cfg)
tlsCfg, err = newTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("can't retrieve cert: %w", err)
return nil, err
}
}
dnsServers, err := createServers(cfg, cert)
dnsServers, err := createServers(cfg, tlsCfg)
if err != nil {
return nil, fmt.Errorf("server creation failed: %w", err)
}
httpRouter := createHTTPRouter(cfg)
httpsRouter := createHTTPSRouter(cfg)
httpListeners, httpsListeners, err := createHTTPListeners(cfg)
listeners, err := createListeners(ctx, cfg, tlsCfg)
if err != nil {
return nil, err
}
if len(httpListeners) != 0 || len(httpsListeners) != 0 {
metrics.Start(httpRouter, cfg.Prometheus)
metrics.Start(httpsRouter, cfg.Prometheus)
}
metrics.RegisterEventListeners()
bootstrap, err := resolver.NewBootstrap(ctx, cfg)
if err != nil {
return nil, err
@ -151,27 +160,21 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
}
server = &Server{
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
httpMux: httpRouter,
httpsMux: httpsRouter,
cert: cert,
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,
}
server.printConfiguration()
server.registerDNSHandlers(ctx)
err = server.registerAPIEndpoints(httpRouter)
services, err := server.createServices()
if err != nil {
return nil, err
}
err = server.registerAPIEndpoints(httpsRouter)
server.services, err = service.GroupByListener(services, listeners)
if err != nil {
return nil, err
}
@ -179,14 +182,35 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return server, err
}
func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, error) {
func (s *Server) createServices() ([]service.Service, error) {
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}
res := []service.Service{
newHTTPMiscService(s.cfg),
newDoHService(s.cfg.Services.DoH, s.handleReq),
api.NewService(s.cfg.Services.API, openAPIImpl),
metrics.NewService(s.cfg.Services.Metrics, s.cfg.Prometheus),
}
// Remove services the user has not enabled
res = slices.DeleteFunc(res, func(svc service.Service) bool {
return len(svc.ExposeOn()) == 0
})
return res, nil
}
func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
var dnsServers []*dns.Server
var err *multierror.Error
addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
for _, address := range addresses {
server, err := newServer(getServerAddress(address))
server, err := newServer(address)
if err != nil {
return err
}
@ -201,52 +225,69 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
addServers(createUDPServer, cfg.Ports.DNS),
addServers(createTCPServer, cfg.Ports.DNS),
addServers(func(address string) (*dns.Server, error) {
return createTLSServer(cfg, address, cert)
return createTLSServer(address, tlsCfg)
}, cfg.Ports.TLS))
return dnsServers, err.ErrorOrNil()
}
func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
res := make(map[string]service.Listener)
listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
return service.ListenTLS(ctx, endpoint, tlsCfg)
}
httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
err := errors.Join(
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
newListeners(ctx, service.HTTPProtocol, cfg.Services.Metrics.Addrs.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Services.Metrics.Addrs.HTTPS, listenTLS, res),
)
if err != nil {
return nil, nil, err
return nil, err
}
return httpListeners, httpsListeners, nil
return maps.Values(res), nil
}
func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)
for _, address := range addresses {
listener, err := net.Listen("tcp", getServerAddress(address))
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
func newListeners[T service.Listener](
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
) error {
for _, addr := range addrs {
key := fmt.Sprintf("%s:%s", proto, addr)
if _, ok := out[key]; ok {
// Avoid "address already in use"
// We instead try to merge services, see services.GroupByListener
continue
}
listeners = append(listeners, listener)
endpoint := service.Endpoint{
Protocol: proto,
AddrConf: addr,
}
l, err := listen(ctx, endpoint)
if err != nil {
return err // already has all info
}
out[key] = l
}
return listeners, nil
return nil
}
func createTLSServer(cfg *config.Config, address string, cert tls.Certificate) (*dns.Server, error) {
func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{
Addr: address,
Net: "tcp-tls",
//nolint:gosec
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: uint16(cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
},
Handler: dns.NewServeMux(),
Addr: address,
Net: "tcp-tls",
TLSConfig: tlsCfg,
Handler: dns.NewServeMux(),
NotifyStartedFunc: func() {
logger().Infof("TLS server is up and running on address %s", address)
},
@ -470,11 +511,15 @@ func toMB(b uint64) uint64 {
return b / bytesInKB / bytesInKB
}
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)
func newSubServer(svc service.Service) (subServer, error) {
switch svc := svc.(type) {
case service.HTTPService:
return newHTTPServer(svc), nil
default:
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
}
}
// Start starts the server
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
@ -490,48 +535,22 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}()
}
for i, listener := range s.httpListeners {
listener := listener
address := s.cfg.Ports.HTTP[i]
for listener, svc := range s.services {
listener, svc := listener, svc
srv, err := newSubServer(svc)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)
return
}
go func() {
logger().Infof("http server is up and running on addr/port %s", address)
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())
srv := &http.Server{
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
Handler: s.httpsMux,
}
if err := srv.Serve(listener); err != nil {
errCh <- fmt.Errorf("start http listener failed: %w", err)
}
}()
}
for i, listener := range s.httpsListeners {
listener := listener
address := s.cfg.Ports.HTTPS[i]
go func() {
logger().Infof("https server is up and running on addr/port %s", address)
server := http.Server{
Handler: s.httpsMux,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
//nolint:gosec
TLSConfig: &tls.Config{
MinVersion: uint16(s.cfg.MinTLSServeVer),
CipherSuites: tlsCipherSuites(),
Certificates: []tls.Certificate{s.cert},
},
}
if err := server.ServeTLS(listener, "", ""); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err)
err := srv.Serve(ctx, listener)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", srv, listener.Addr(), err)
}
}()
}
@ -630,6 +649,8 @@ type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}
type dnsHandler func(context.Context, *model.Request, msgWriter)
func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {

View File

@ -2,13 +2,10 @@ package server
import (
"context"
"encoding/base64"
"fmt"
"html/template"
"io"
"net"
"net/http"
"time"
"github.com/0xERR0R/blocky/resolver"
@ -22,7 +19,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/miekg/dns"
)
@ -32,19 +28,8 @@ const (
dnsContentType = "application/dns-message"
htmlContentType = "text/html; charset=UTF-8"
yamlContentType = "text/yaml"
corsMaxAge = 5 * time.Minute
)
func secureHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
next.ServeHTTP(w, r)
})
}
func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) {
bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver)
if err != nil {
@ -64,108 +49,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}
func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
const pathDohQuery = "/dns-query"
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return err
}
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(pathDohQuery, s.dohPostRequestHandler)
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
return nil
}
func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)
return
}
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
return
}
rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
return
}
s.processDohMessage(rawMsg, rw, req)
}
func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
s.handleReq(ctx, dnsReq, httpMsgWriter{rw})
}
type httpMsgWriter struct {
rw http.ResponseWriter
}
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}
r.rw.Header().Set("content-type", dnsContentType)
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)
_, err = r.rw.Write(b)
return err
}
func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
@ -177,27 +60,7 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()
configureSecureHeaderHandler(router)
registerHandlers(cfg, router)
return router
}
func createHTTPRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()
registerHandlers(cfg, router)
return router
}
func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureCorsHandler(router)
func configureHTTPRouter(router chi.Router, cfg *config.Config) {
configureDebugHandler(router)
configureDocsHandler(router)
@ -207,7 +70,7 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureRootHandler(cfg, router)
}
func configureDocsHandler(router *chi.Mux) {
func configureDocsHandler(router chi.Router) {
router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set(contentTypeHeader, yamlContentType)
_, err := writer.Write([]byte(docs.OpenAPI))
@ -215,7 +78,7 @@ func configureDocsHandler(router *chi.Mux) {
})
}
func configureStaticAssetsHandler(router *chi.Mux) {
func configureStaticAssetsHandler(router chi.Router) {
assets, err := web.Assets()
util.FatalOnError("unable to load static asset files", err)
@ -223,7 +86,7 @@ func configureStaticAssetsHandler(router *chi.Mux) {
router.Handle("/static/*", http.StripPrefix("/static/", fs))
}
func configureRootHandler(cfg *config.Config, router *chi.Mux) {
func configureRootHandler(cfg *config.Config, router chi.Router) {
router.Get("/", func(writer http.ResponseWriter, request *http.Request) {
writer.Header().Set(contentTypeHeader, htmlContentType)
@ -282,22 +145,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit
}
}
func configureSecureHeaderHandler(router *chi.Mux) {
router.Use(secureHeader)
}
func configureDebugHandler(router *chi.Mux) {
func configureDebugHandler(router chi.Router) {
router.Mount("/debug", middleware.Profiler())
}
func configureCorsHandler(router *chi.Mux) {
crs := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: int(corsMaxAge.Seconds()),
})
router.Use(crs.Handler)
}

View File

@ -44,7 +44,7 @@ var (
)
var _ = BeforeSuite(func() {
baseURL = "http://localhost:" + GetStringPort(httpBasePort) + "/"
baseURL = fmt.Sprintf("http://%s/", GetHostPort("localhost", httpBasePort))
queryURL = baseURL + "dns-query"
var upstreamGoogle, upstreamFritzbox, upstreamClient config.Upstream
ctx, cancelFn := context.WithCancel(context.Background())
@ -147,10 +147,10 @@ var _ = BeforeSuite(func() {
},
Ports: config.Ports{
DNS: config.ListenConfig{GetStringPort(dnsBasePort)},
TLS: config.ListenConfig{GetStringPort(tlsBasePort)},
HTTP: config.ListenConfig{GetStringPort(httpBasePort)},
HTTPS: config.ListenConfig{GetStringPort(httpsBasePort)},
DNS: config.ListenConfig{GetHostPort("", dnsBasePort)},
TLS: config.ListenConfig{GetHostPort("", tlsBasePort)},
HTTP: config.ListenConfig{GetHostPort("", httpBasePort)},
HTTPS: config.ListenConfig{GetHostPort("", httpsBasePort)},
},
CertFile: certPem.Path,
KeyFile: keyPem.Path,
@ -634,7 +634,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -678,7 +678,7 @@ var _ = Describe("Running DNS server", func() {
},
Blocking: config.Blocking{BlockType: "zeroIp"},
Ports: config.Ports{
DNS: config.ListenConfig{"127.0.0.1:" + GetStringPort(dnsBasePort2)},
DNS: config.ListenConfig{GetHostPort("127.0.0.1", dnsBasePort2)},
},
})
@ -741,17 +741,18 @@ var _ = Describe("Running DNS server", func() {
cfg.KeyFile = ""
cfg.CertFile = ""
cfg.Ports = config.Ports{
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
HTTPS: []string{":0"},
}
sut, err := NewServer(ctx, &cfg)
sut, err := newTLSConfig(&cfg)
Expect(err).Should(Succeed())
Expect(sut.cert.Certificate).ShouldNot(BeNil())
Expect(sut.Certificates).ShouldNot(BeEmpty())
})
})
})
func requestServer(request *dns.Msg) *dns.Msg {
conn, err := net.Dial("udp", ":"+GetStringPort(dnsBasePort))
conn, err := net.Dial("udp", GetHostPort("", dnsBasePort))
if err != nil {
Log().Fatal("could not connect to server: ", err)
}

51
service/endpoint.go Normal file
View File

@ -0,0 +1,51 @@
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
"golang.org/x/exp/maps"
)
// Endpoint is a network endpoint on which to expose a service.
type Endpoint struct {
// Protocol is the protocol to be exposed on this endpoint.
Protocol string
// AddrConf is the network address as configured by the user.
AddrConf string
}
func EndpointsFromAddrs(proto string, addrs []string) []Endpoint {
return util.ForEach(addrs, func(addr string) Endpoint {
return Endpoint{
Protocol: proto,
AddrConf: addr,
}
})
}
func (e Endpoint) String() string {
addr := e.AddrConf
if strings.HasPrefix(addr, ":") {
addr = "*" + addr
}
return fmt.Sprintf("%s://%s", e.Protocol, addr)
}
type endpointSet map[Endpoint]struct{}
func (s endpointSet) ToSlice() []Endpoint {
return maps.Keys(s)
}
func (s endpointSet) IntersectSlice(others []Endpoint) {
for endpoint := range s {
if !slices.Contains(others, endpoint) {
delete(s, endpoint)
}
}
}

123
service/http.go Normal file
View File

@ -0,0 +1,123 @@
package service
import (
"errors"
"net/http"
"strings"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
)
const (
HTTPProtocol = "http"
HTTPSProtocol = "https"
)
// HTTPService is a Service using a HTTP router.
type HTTPService interface {
Service
Merger
// Router returns the service's router.
Router() chi.Router
}
// HTTPInfo can be embedded in structs to help implement HTTPService.
type HTTPInfo struct {
Info
mux *chi.Mux
}
func NewHTTPInfo(name string, endpoints []Endpoint) HTTPInfo {
return HTTPInfo{
Info: NewInfo(name, endpoints),
mux: chi.NewMux(),
}
}
func (i *HTTPInfo) Router() chi.Router { return i.mux }
var _ HTTPService = (*SimpleHTTP)(nil)
// SimpleHTTP implements HTTPService usinig the default HTTP merger.
type SimpleHTTP struct{ HTTPInfo }
func NewSimpleHTTP(name string, endpoints []Endpoint) SimpleHTTP {
return SimpleHTTP{HTTPInfo: NewHTTPInfo(name, endpoints)}
}
func (s *SimpleHTTP) Merge(other Service) (Merger, error) {
return MergeHTTP(s, other)
}
// MergeHTTP merges two compatible HTTPServices.
//
// The second parameter is of type `Service` to make it easy to call
// from a `Merger.Merge` implementation.
func MergeHTTP(a HTTPService, b Service) (Merger, error) {
return newHTTPMerger(a).Merge(b)
}
var _ HTTPService = (*httpMerger)(nil)
// httpMerger can merge HTTPServices by combining their routes.
type httpMerger struct {
inner []HTTPService
router chi.Router
endpoints endpointSet
}
func newHTTPMerger(first HTTPService) *httpMerger {
return &httpMerger{
inner: []HTTPService{first},
router: first.Router(),
}
}
func (m *httpMerger) String() string { return svcString(m) }
func (m *httpMerger) ServiceName() string {
names := util.ForEach(m.inner, func(svc HTTPService) string {
return svc.ServiceName()
})
return strings.Join(names, " & ")
}
func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() }
func (m *httpMerger) Router() chi.Router { return m.router }
func (m *httpMerger) Merge(other Service) (Merger, error) {
httpSvc, ok := other.(HTTPService)
if !ok {
return nil, errors.New("not an HTTPService")
}
type middleware = func(http.Handler) http.Handler
// Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
m.router.With(middlewares...).Method(method, route, handler)
// Expose /example/ as /example too
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
// The main point of this is for DoH (/dns-query).
if strings.HasSuffix(route, "/") {
route := strings.TrimSuffix(route, "/")
m.router.With(middlewares...).Method(method, route, handler)
}
return nil
})
m.inner = append(m.inner, httpSvc)
// Don't expose any service more than it expects
m.endpoints.IntersectSlice(other.ExposeOn())
return m, nil
}

74
service/listener.go Normal file
View File

@ -0,0 +1,74 @@
package service
import (
"context"
"crypto/tls"
"fmt"
"net"
)
// Listener is a net.Listener that provides information about
// what protocol and address it is configured for.
type Listener interface {
fmt.Stringer
net.Listener
// Exposes returns the endpoint for this listener.
//
// It can be used to find service(s) with matching configuration.
Exposes() Endpoint
}
// ListenerInfo can be embedded in structs to help implement Listener.
type ListenerInfo struct {
Endpoint
}
func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint }
// NetListener implements Listener using an existing net.Listener.
type NetListener struct {
net.Listener
ListenerInfo
}
func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener {
return &NetListener{
Listener: inner,
ListenerInfo: ListenerInfo{endpoint},
}
}
// TCPListener is a Listener for a TCP socket.
type TCPListener struct{ NetListener }
// ListenTCP creates a new TCPListener.
func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) {
var lc net.ListenConfig
l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf)
if err != nil {
return nil, err // err already has all the info we could add
}
inner := NewNetListener(endpoint, l)
return &TCPListener{*inner}, nil
}
// TLSListener is a Listener using TLS over TCP.
type TLSListener struct{ NetListener }
// ListenTLS creates a new TLSListener.
func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) {
tcp, err := ListenTCP(ctx, endpoint)
if err != nil {
return nil, err
}
inner := tcp.NetListener
inner.Listener = tls.NewListener(inner.Listener, cfg)
return &TLSListener{inner}, nil
}

57
service/merge.go Normal file
View File

@ -0,0 +1,57 @@
package service
import "errors"
// Merger is a Service that can be merged with another compatible one.
type Merger interface {
Service
// Merge returns the result of merging the receiver with the other Service.
//
// Neither the receiver, nor the other Service should be used directly after
// calling this method.
Merge(other Service) (Merger, error)
}
// MergeAll merges the given services, if they are compatible.
//
// This allows using multiple compatible services with a single listener.
//
// All passed-in services must not be re-used.
func MergeAll(services []Service) (Service, error) {
switch len(services) {
case 0:
return nil, errors.New("no services given")
case 1:
return services[0], nil
}
merger, err := firstMerger(services)
if err != nil {
return nil, err
}
for _, svc := range services {
if svc == merger {
continue
}
merger, err = merger.Merge(svc)
if err != nil {
return nil, err
}
}
return merger, nil
}
func firstMerger(services []Service) (Merger, error) {
for _, t := range services {
if svc, ok := t.(Merger); ok {
return svc, nil
}
}
return nil, errors.New("no merger found")
}

113
service/service.go Normal file
View File

@ -0,0 +1,113 @@
// Package service exposes types to abstract services from the networking.
//
// The idea is that we build a set of services and a set of network endpoints (Listener).
// The services are then assigned to endpoints based on the address(es) they were configured for.
//
// Actual service to endpoint binding is not handled by the abstractions in this package as it is
// protocol specific.
// The general pattern is to make a "server" that wraps a service, and can then be started on an
// endpoint using a `Serve` method, similar to `http.Server`.
//
// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port),
// services can implement `Merger`.
package service
import (
"fmt"
"slices"
"strings"
"github.com/0xERR0R/blocky/util"
)
// Service is a network exposed service.
//
// It contains only the logic and user configured addresses it should be exposed on.
// Is is meant to be associated to one or more sockets via those addresses.
// Actual association with a socket is protocol specific.
type Service interface {
fmt.Stringer
// ServiceName returns the user friendly name of the service.
ServiceName() string
// ExposeOn returns the set of endpoints the service should be exposed on.
//
// They can be used to find listener(s) with matching configuration.
ExposeOn() []Endpoint
}
func svcString(s Service) string {
endpoints := util.ForEach(s.ExposeOn(), func(e Endpoint) string { return e.String() })
return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", "))
}
// Info can be embedded in structs to help implement Service.
type Info struct {
name string
endpoints []Endpoint
}
func NewInfo(name string, endpoints []Endpoint) Info {
return Info{
name: name,
endpoints: endpoints,
}
}
func (i *Info) ServiceName() string { return i.name }
func (i *Info) ExposeOn() []Endpoint { return i.endpoints }
func (i *Info) String() string { return svcString(i) }
// GroupByListener returns a map of listener and services grouped by configured address.
//
// Each input listener is a key in the map. The corresponding value is a service
// merged from all services with a matching address.
func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) {
res := make(map[Listener]Service, len(listeners))
unused := slices.Clone(services)
for _, listener := range listeners {
services := findAllCompatible(services, listener.Exposes())
if len(services) == 0 {
return nil, fmt.Errorf("found no compatible services for listener %s", listener)
}
svc, err := MergeAll(services)
if err != nil {
return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err)
}
res[listener] = svc
for _, svc := range services {
if i := slices.Index(unused, svc); i != -1 {
unused = slices.Delete(unused, i, i+1)
}
}
}
if len(unused) != 0 {
return nil, fmt.Errorf("found no compatible listener for services: %v", unused)
}
return res, nil
}
// findAllCompatible returns the subset of services that use the given Listener.
func findAllCompatible(services []Service, endpoint Endpoint) []Service {
res := make([]Service, 0, len(services))
for _, svc := range services {
if isExposedOn(svc, endpoint) {
res = append(res, svc)
}
}
return res
}
func isExposedOn(svc Service, endpoint Endpoint) bool {
return slices.Index(svc.ExposeOn(), endpoint) != -1
}

33
util/slices.go Normal file
View File

@ -0,0 +1,33 @@
package util
// ForEach implements the functional map operation, under a different
// name to avoid confusion with Go's map type.
func ForEach[T, U any](slice []T, convert func(T) U) []U {
res := make([]U, 0, len(slice))
for _, t := range slice {
u := convert(t)
res = append(res, u)
}
return res
}
// ConcatSlices returns a new slice with contents of all the inputs concatenated.
func ConcatSlices[T any](slices ...[]T) []T {
// Allocation is usually the bottleneck, so do it all at once
totalLen := 0
for _, slice := range slices {
totalLen += len(slice)
}
res := make([]T, 0, totalLen)
for _, slice := range slices {
res = append(res, slice...)
}
return res
}