mirror of https://github.com/0xERR0R/blocky.git
refactor: switch DoH to Service pattern
This commit is contained in:
parent
c11f9a1c98
commit
d7a2952b1d
|
@ -233,6 +233,7 @@ type Config struct {
|
|||
Redis Redis `yaml:"redis"`
|
||||
Log log.Config `yaml:"log"`
|
||||
Ports Ports `yaml:"ports"`
|
||||
Services Services `yaml:"-"` // not user exposed yet
|
||||
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
|
||||
CertFile string `yaml:"certFile"`
|
||||
KeyFile string `yaml:"keyFile"`
|
||||
|
@ -262,6 +263,17 @@ type Config struct {
|
|||
} `yaml:",inline"`
|
||||
}
|
||||
|
||||
// Services holds network service related configuration.
|
||||
//
|
||||
// The actual config layout is not decided yet.
|
||||
// See https://github.com/0xERR0R/blocky/issues/1206
|
||||
//
|
||||
// The `yaml` struct tags are just for manual testing,
|
||||
// and require replacing `yaml:"-"` in Config to work.
|
||||
type Services struct {
|
||||
DoH DoHService `yaml:"dns-over-https"`
|
||||
}
|
||||
|
||||
type Ports struct {
|
||||
DNS ListenConfig `yaml:"dns" default:"53"`
|
||||
HTTP ListenConfig `yaml:"http"`
|
||||
|
@ -601,6 +613,19 @@ func (cfg *Config) validate(logger *logrus.Entry) {
|
|||
cfg.Upstreams.validate(logger)
|
||||
}
|
||||
|
||||
// CopyPortsToServices sets Services values to match Ports.
|
||||
//
|
||||
// This should be replaced with a migration once everything from Ports is supported in Services.
|
||||
// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once.
|
||||
func (cfg *Config) CopyPortsToServices() {
|
||||
cfg.Services = Services{
|
||||
DoH: DoHService{Addrs: DoHAddrs{
|
||||
HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP},
|
||||
HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertPort converts string representation into a valid port (0 - 65535)
|
||||
func ConvertPort(in string) (uint16, error) {
|
||||
const (
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
package config
|
||||
|
||||
type DoHService struct {
|
||||
Addrs DoHAddrs `yaml:"addrs"`
|
||||
}
|
||||
|
||||
type DoHAddrs struct {
|
||||
HTTPAddrs `yaml:",inline"`
|
||||
HTTPSAddrs `yaml:",inline"`
|
||||
}
|
||||
|
||||
type HTTPAddrs struct {
|
||||
HTTP ListenConfig `yaml:"http"`
|
||||
}
|
||||
|
||||
type HTTPSAddrs struct {
|
||||
HTTPS ListenConfig `yaml:"https"`
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
"github.com/0xERR0R/blocky/service"
|
||||
"github.com/0xERR0R/blocky/util"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type dohService struct {
|
||||
service.HTTPInfo
|
||||
|
||||
handler dnsHandler
|
||||
}
|
||||
|
||||
func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS),
|
||||
)
|
||||
|
||||
s := &dohService{
|
||||
HTTPInfo: service.HTTPInfo{
|
||||
Info: service.Info{
|
||||
Name: "DoH",
|
||||
Endpoints: endpoints,
|
||||
},
|
||||
|
||||
Mux: chi.NewMux(),
|
||||
},
|
||||
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
s.Mux.Route("/dns-query", func(mux chi.Router) {
|
||||
// Handlers for / also handle /dns-query without trailing slash
|
||||
|
||||
mux.Get("/", s.handleGET)
|
||||
mux.Get("/{clientID}", s.handleGET)
|
||||
|
||||
mux.Post("/", s.handlePOST)
|
||||
mux.Post("/{clientID}", s.handlePOST)
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *dohService) Merge(other service.Service) (service.Merger, error) {
|
||||
return service.MergeHTTP(s, other)
|
||||
}
|
||||
|
||||
func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
|
||||
dnsParam, ok := req.URL.Query()["dns"]
|
||||
if !ok || len(dnsParam[0]) < 1 {
|
||||
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
||||
if err != nil {
|
||||
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-type")
|
||||
if contentType != dnsContentType {
|
||||
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(rawMsg); err != nil {
|
||||
logger().Error("can't deserialize message: ", err)
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
||||
|
||||
s.handler(ctx, dnsReq, httpMsgWriter{rw})
|
||||
}
|
||||
|
||||
type httpMsgWriter struct {
|
||||
rw http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.rw.Header().Set("content-type", dnsContentType)
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
||||
r.rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err = r.rw.Write(b)
|
||||
|
||||
return err
|
||||
}
|
|
@ -23,9 +23,7 @@ type httpMiscService struct {
|
|||
service.HTTPInfo
|
||||
}
|
||||
|
||||
func newHTTPMiscService(
|
||||
cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler,
|
||||
) *httpMiscService {
|
||||
func newHTTPMiscService(cfg *config.Config, openAPIImpl api.StrictServerInterface) *httpMiscService {
|
||||
endpoints := util.ConcatSlices(
|
||||
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
|
||||
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
|
||||
|
@ -38,7 +36,7 @@ func newHTTPMiscService(
|
|||
Endpoints: endpoints,
|
||||
},
|
||||
|
||||
Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler),
|
||||
Mux: createHTTPRouter(cfg, openAPIImpl),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -118,7 +118,11 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
|
|||
}
|
||||
|
||||
// NewServer creates new server instance with passed config
|
||||
//
|
||||
//nolint:funlen
|
||||
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
|
||||
cfg.CopyPortsToServices()
|
||||
|
||||
var tlsCfg *tls.Config
|
||||
|
||||
if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 {
|
||||
|
@ -188,7 +192,8 @@ func (s *Server) createServices() ([]service.Service, error) {
|
|||
}
|
||||
|
||||
res := []service.Service{
|
||||
newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq),
|
||||
newHTTPMiscService(s.cfg, openAPIImpl),
|
||||
newDoHService(s.cfg.Services.DoH, s.handleReq),
|
||||
}
|
||||
|
||||
// Remove services the user has not enabled
|
||||
|
@ -237,6 +242,8 @@ func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config
|
|||
err := errors.Join(
|
||||
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
|
||||
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
|
||||
newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
|
||||
newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -2,10 +2,8 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
|
@ -52,103 +50,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
|
|||
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
|
||||
}
|
||||
|
||||
func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) {
|
||||
const pathDohQuery = "/dns-query"
|
||||
|
||||
s := &dohServer{dnsHandler}
|
||||
|
||||
router.Get(pathDohQuery, s.dohGetRequestHandler)
|
||||
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
|
||||
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
|
||||
router.Post(pathDohQuery, s.dohPostRequestHandler)
|
||||
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
|
||||
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
|
||||
}
|
||||
|
||||
type dohServer struct{ handler dnsHandler }
|
||||
|
||||
func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
dnsParam, ok := req.URL.Query()["dns"]
|
||||
if !ok || len(dnsParam[0]) < 1 {
|
||||
http.Error(rw, "dns param is missing", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
|
||||
if err != nil {
|
||||
http.Error(rw, "wrong message format", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-type")
|
||||
if contentType != dnsContentType {
|
||||
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(rawMsg) > dohMessageLimit {
|
||||
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.processDohMessage(rawMsg, rw, req)
|
||||
}
|
||||
|
||||
func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(rawMsg); err != nil {
|
||||
logger().Error("can't deserialize message: ", err)
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)
|
||||
|
||||
s.handler(ctx, dnsReq, httpMsgWriter{rw})
|
||||
}
|
||||
|
||||
type httpMsgWriter struct {
|
||||
rw http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.rw.Header().Set("content-type", dnsContentType)
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
|
||||
r.rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err = r.rw.Write(b)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) Query(
|
||||
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
||||
) (*model.Response, error) {
|
||||
|
@ -160,7 +61,7 @@ func (s *Server) Query(
|
|||
return s.resolve(ctx, req)
|
||||
}
|
||||
|
||||
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux {
|
||||
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux {
|
||||
router := chi.NewRouter()
|
||||
|
||||
api.RegisterOpenAPIEndpoints(router, openAPIImpl)
|
||||
|
@ -173,8 +74,6 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface,
|
|||
|
||||
configureRootHandler(cfg, router)
|
||||
|
||||
registerDoHEndpoints(router, dnsHandler)
|
||||
|
||||
metrics.Start(router, cfg.Prometheus)
|
||||
|
||||
return router
|
||||
|
|
|
@ -81,6 +81,15 @@ func (m *httpMerger) Merge(other Service) (Merger, error) {
|
|||
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
|
||||
m.router.With(middlewares...).Method(method, route, handler)
|
||||
|
||||
// Expose /example/ as /example too
|
||||
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
|
||||
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
|
||||
// The main point of this is for DoH (/dns-query).
|
||||
if strings.HasSuffix(route, "/") {
|
||||
route := strings.TrimSuffix(route, "/")
|
||||
m.router.With(middlewares...).Method(method, route, handler)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
|
|
Loading…
Reference in New Issue