blocky/server/server.go

335 lines
8.5 KiB
Go

package server
import (
"fmt"
"net"
"net/http"
"runtime"
"runtime/debug"
"strings"
"time"
"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/log"
"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// Server controls the endpoints for DNS and HTTP
type Server struct {
	// DNS endpoints; both are created for the same address (cfg.Port).
	udpServer *dns.Server
	tcpServer *dns.Server

	// Optional HTTP/HTTPS listeners; nil when the corresponding port is not configured.
	httpListener  net.Listener
	httpsListener net.Listener

	// Resolver chain every incoming DNS query is resolved with.
	queryResolver resolver.Resolver

	// Configuration the server was created from.
	cfg *config.Config

	// Router served by both the HTTP and the HTTPS listener (API and metrics endpoints).
	httpMux *chi.Mux
}
// logger returns the log entry used by this package, obtained from
// log.PrefixedLog with the "server" prefix.
func logger() *logrus.Entry {
	const prefix = "server"

	return log.PrefixedLog(prefix)
}
// getServerAddress turns a bare port ("53") into a listen address (":53").
// Values that already contain a colon (e.g. "127.0.0.1:53") are returned as-is.
func getServerAddress(addr string) string {
	if strings.Contains(addr, ":") {
		return addr
	}

	return fmt.Sprintf(":%s", addr)
}
// NewServer creates a new server instance for the passed configuration.
//
// It configures logging, prepares the UDP and TCP DNS servers for cfg.Port,
// opens the optional HTTP (cfg.HTTPPort) and HTTPS (cfg.HTTPSPort) listeners,
// builds the resolver chain and registers the DNS handlers and HTTP API
// endpoints. The DNS and HTTP servers are not started here; see Start.
//
// An error is returned if HTTPS is requested without both cfg.CertFile and
// cfg.KeyFile, or if one of the HTTP(S) listeners cannot be opened. On error,
// no listener opened by this call is left open.
func NewServer(cfg *config.Config) (server *Server, err error) {
	address := getServerAddress(cfg.Port)

	log.ConfigureLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogTimestamp)

	udpServer := createUDPServer(address)
	tcpServer := createTCPServer(address)

	var httpListener, httpsListener net.Listener

	router := createRouter(cfg)

	// Reject an incomplete HTTPS configuration before any port is bound.
	if cfg.HTTPSPort != "" && (cfg.CertFile == "" || cfg.KeyFile == "") {
		return nil, fmt.Errorf("httpsCertFile and httpsKeyFile parameters are mandatory for HTTPS")
	}

	if cfg.HTTPPort != "" {
		if httpListener, err = net.Listen("tcp", getServerAddress(cfg.HTTPPort)); err != nil {
			return nil, fmt.Errorf("start http listener on %s failed: %w", cfg.HTTPPort, err)
		}
	}

	if cfg.HTTPSPort != "" {
		if httpsListener, err = net.Listen("tcp", getServerAddress(cfg.HTTPSPort)); err != nil {
			if httpListener != nil {
				// Release the already bound HTTP port. A close error is ignored on
				// purpose: the listen error below is what the caller needs to see.
				_ = httpListener.Close()
			}

			return nil, fmt.Errorf("start https listener on port %s failed: %w", cfg.HTTPSPort, err)
		}
	}

	// Both listeners share the same router, so the metrics endpoint only needs
	// to be registered once when either of them is configured.
	if httpListener != nil || httpsListener != nil {
		metrics.Start(router, cfg.Prometheus)
	}

	metrics.RegisterEventListeners()

	queryResolver := createQueryResolver(cfg)

	server = &Server{
		udpServer:     udpServer,
		tcpServer:     tcpServer,
		queryResolver: queryResolver,
		cfg:           cfg,
		httpListener:  httpListener,
		httpsListener: httpsListener,
		httpMux:       router,
	}

	server.printConfiguration()

	server.registerDNSHandlers(udpServer)
	server.registerDNSHandlers(tcpServer)
	server.registerAPIEndpoints(router)
	registerResolverAPIEndpoints(router, queryResolver)

	return server, nil
}
// registerResolverAPIEndpoints walks the resolver chain starting at res and
// calls api.RegisterEndpoint for every resolver in it. The walk follows
// GetNext() as long as the current resolver is a resolver.ChainedResolver and
// stops at the first one that is not, or at a nil resolver.
func registerResolverAPIEndpoints(router chi.Router, res resolver.Resolver) {
	for current := res; current != nil; {
		api.RegisterEndpoint(router, current)

		chained, isChained := current.(resolver.ChainedResolver)
		if !isChained {
			return
		}

		current = chained.GetNext()
	}
}
// createTCPServer returns an unstarted TCP DNS server bound to address with
// an empty dns.ServeMux as handler. A message is logged once it is running.
func createTCPServer(address string) *dns.Server {
	srv := &dns.Server{
		Addr:    address,
		Net:     "tcp",
		Handler: dns.NewServeMux(),
	}

	srv.NotifyStartedFunc = func() {
		logger().Infof("tcp server is up and running on address %s", address)
	}

	return srv
}
// createUDPServer returns an unstarted UDP DNS server bound to address with an
// empty dns.ServeMux as handler and a 65535 byte UDP buffer size. A message is
// logged once it is running.
func createUDPServer(address string) *dns.Server {
	srv := &dns.Server{
		Addr:    address,
		Net:     "udp",
		Handler: dns.NewServeMux(),
		UDPSize: 65535,
	}

	srv.NotifyStartedFunc = func() {
		logger().Infof("udp server is up and running on address %s", address)
	}

	return srv
}
// createQueryResolver builds the resolver chain from cfg. The order of the
// stages is the order in which they are passed to resolver.Chain; every stage
// is constructed in that same order.
func createQueryResolver(cfg *config.Config) resolver.Resolver {
	ipv6Checker := resolver.NewIPv6Checker(cfg.DisableIPv6)
	clientNames := resolver.NewClientNamesResolver(cfg.ClientLookup)
	queryLogging := resolver.NewQueryLoggingResolver(cfg.QueryLog)
	stats := resolver.NewStatsResolver()
	metricsStage := resolver.NewMetricsResolver(cfg.Prometheus)
	customDNS := resolver.NewCustomDNSResolver(cfg.CustomDNS)
	blocking := resolver.NewBlockingResolver(cfg.Blocking)
	caching := resolver.NewCachingResolver(cfg.Caching)
	conditional := resolver.NewConditionalUpstreamResolver(cfg.Conditional)
	upstream := resolver.NewParallelBestResolver(cfg.Upstream.ExternalResolvers)

	return resolver.Chain(
		ipv6Checker,
		clientNames,
		queryLogging,
		stats,
		metricsStage,
		customDNS,
		blocking,
		caching,
		conditional,
		upstream,
	)
}
// registerDNSHandlers routes queries received by server: "healthcheck.blocky"
// goes to s.OnHealthCheck, every other name (".") to s.OnRequest.
// server.Handler must be a *dns.ServeMux (as set by createUDPServer and
// createTCPServer); otherwise the type assertion panics.
func (s *Server) registerDNSHandlers(server *dns.Server) {
	mux := server.Handler.(*dns.ServeMux)

	routes := []struct {
		pattern string
		handle  func(dns.ResponseWriter, *dns.Msg)
	}{
		{pattern: ".", handle: s.OnRequest},
		{pattern: "healthcheck.blocky", handle: s.OnHealthCheck},
	}

	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handle)
	}
}
// printConfiguration logs the configuration of every resolver in the chain,
// the configured DNS, HTTP and HTTPS ports, and a snapshot of runtime
// statistics (memory usage, GC count, CPU count, goroutine count).
func (s *Server) printConfiguration() {
	logger().Info("current configuration:")

	res := s.queryResolver
	for res != nil {
		logger().Infof("-> resolver: '%s'", resolver.Name(res))

		for _, c := range res.Configuration() {
			logger().Infof(" %s", c)
		}

		next, ok := res.(resolver.ChainedResolver)
		if !ok {
			break
		}

		res = next.GetNext()
	}

	logger().Infof("- DNS listening port: '%s'", s.cfg.Port)
	logger().Infof("- HTTP listening on addr/port: %s", s.cfg.HTTPPort)
	logger().Infof("- HTTPS listening on addr/port: %s", s.cfg.HTTPSPort)

	logger().Info("runtime information:")

	// debug.FreeOSMemory forces a garbage collection itself before returning
	// memory to the OS, so a preceding runtime.GC() would only add a second,
	// redundant collection.
	debug.FreeOSMemory()

	// gather memory stats
	var m runtime.MemStats

	runtime.ReadMemStats(&m)

	logger().Infof("MEM Alloc = %10v MB", toMB(m.Alloc))
	logger().Infof("MEM HeapAlloc = %10v MB", toMB(m.HeapAlloc))
	logger().Infof("MEM Sys = %10v MB", toMB(m.Sys))
	logger().Infof("MEM NumGC = %10v", m.NumGC)
	logger().Infof("RUN NumCPU = %10d", runtime.NumCPU())
	logger().Infof("RUN NumGoroutine = %10d", runtime.NumGoroutine())
}
// toMB converts a byte count to whole mebibytes (1 MB = 1024*1024 bytes),
// rounding down.
func toMB(b uint64) uint64 {
	const bytesPerMB = 1 << 20

	return b / bytesPerMB
}
// Start starts the server: the UDP and TCP DNS servers and, if their
// listeners were opened, the HTTP and HTTPS servers each run in their own
// goroutine, so Start returns without waiting for them. Serve errors are
// reported through logger().Fatalf / util.FatalOnError. Finally the trigger
// for printing the configuration is registered.
func (s *Server) Start() {
	logger().Info("Starting server")

	for _, dnsServer := range []*dns.Server{s.udpServer, s.tcpServer} {
		dnsServer := dnsServer // per-iteration copy for the goroutine (pre Go 1.22 loop semantics)

		go func() {
			if err := dnsServer.ListenAndServe(); err != nil {
				logger().Fatalf("start %s listener failed: %v", dnsServer.Net, err)
			}
		}()
	}

	if s.httpListener != nil {
		go func() {
			logger().Infof("http server is up and running on addr/port %s", s.cfg.HTTPPort)

			serveErr := http.Serve(s.httpListener, s.httpMux)
			util.FatalOnError("start http listener failed: ", serveErr)
		}()
	}

	if s.httpsListener != nil {
		go func() {
			logger().Infof("https server is up and running on addr/port %s", s.cfg.HTTPSPort)

			serveErr := http.ServeTLS(s.httpsListener, s.httpMux, s.cfg.CertFile, s.cfg.KeyFile)
			util.FatalOnError("start https listener failed: ", serveErr)
		}()
	}

	registerPrintConfigurationTrigger(s)
}
// Stop stops the server by shutting down the UDP DNS server and then the TCP
// DNS server. A shutdown error is logged with logger().Fatalf.
func (s *Server) Stop() {
	logger().Info("Stopping server")

	for _, dnsServer := range []*dns.Server{s.udpServer, s.tcpServer} {
		if err := dnsServer.Shutdown(); err != nil {
			logger().Fatalf("stop %s listener failed: %v", dnsServer.Net, err)
		}
	}
}
// createResolverRequest wraps request into a resolver.Request, deriving the
// client IP and transport protocol from the remote address of the query.
func createResolverRequest(remoteAddress net.Addr, request *dns.Msg) *resolver.Request {
	ip, proto := resolveClientIPAndProtocol(remoteAddress)

	return newRequest(ip, proto, request)
}
// newRequest builds a resolver.Request for a DNS message, stamps it with the
// current time and attaches a log entry carrying the question and client IP.
func newRequest(clientIP net.IP, protocol resolver.RequestProtocol, request *dns.Msg) *resolver.Request {
	req := new(resolver.Request)
	req.ClientIP = clientIP
	req.Protocol = protocol
	req.Req = request
	req.RequestTS = time.Now()
	req.Log = log.Log().WithFields(logrus.Fields{
		"question":  util.QuestionToString(request.Question),
		"client_ip": clientIP,
	})

	return req
}
// OnRequest will be executed if a new DNS request is received.
//
// The query is resolved through s.queryResolver. On a resolver error, a
// SERVFAIL reply is written. Otherwise, the response's RecursionAvailable flag
// mirrors the request's RecursionDesired flag. The response is then truncated
// to the size allowed for the connection and written with compression enabled.
func (s *Server) OnRequest(w dns.ResponseWriter, request *dns.Msg) {
	logger().Debug("new request")

	response, err := s.queryResolver.Resolve(createResolverRequest(w.RemoteAddr(), request))
	if err != nil {
		logger().Errorf("error on processing request: %v", err)

		failure := new(dns.Msg)
		failure.SetRcode(request, dns.RcodeServerFailure)

		writeErr := w.WriteMsg(failure)
		util.LogOnError("can't write message: ", writeErr)

		return
	}

	reply := response.Res
	reply.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired
	reply.Truncate(getMaxResponseSize(w.LocalAddr().Network(), request))
	reply.Compress = true

	writeErr := w.WriteMsg(reply)
	util.LogOnError("can't write message: ", writeErr)
}
// getMaxResponseSize returns the maximum size a response to request may have
// on the given network.
//
// For TCP connections this is dns.MaxMsgSize (64 KiB). The EDNS0 UDP payload
// size only limits UDP responses (RFC 6891), so applying it to TCP would
// needlessly truncate answers and set the TC bit. For UDP, it is the EDNS0
// payload size advertised by the client, or dns.MinMsgSize (512 bytes) if the
// request carries no (or a zero) EDNS0 size.
func getMaxResponseSize(network string, request *dns.Msg) int {
	// HasPrefix also covers address families such as "tcp4"/"tcp6".
	if strings.HasPrefix(network, "tcp") {
		return dns.MaxMsgSize
	}

	if edns := request.IsEdns0(); edns != nil && edns.UDPSize() > 0 {
		return int(edns.UDPSize())
	}

	return dns.MinMsgSize
}
// OnHealthCheck handles the docker health check query. It replies with a
// success rcode directly, without passing the query through the resolver chain.
func (s *Server) OnHealthCheck(w dns.ResponseWriter, request *dns.Msg) {
	reply := new(dns.Msg).SetReply(request)
	reply.Rcode = dns.RcodeSuccess

	writeErr := w.WriteMsg(reply)
	util.LogOnError("can't write message: ", writeErr)
}
// resolveClientIPAndProtocol extracts the client IP and transport protocol from
// addr. A *net.UDPAddr yields its IP and resolver.UDP, and a *net.TCPAddr yields
// its IP and resolver.TCP. Any other address (including nil) yields a nil IP
// and resolver.TCP.
func resolveClientIPAndProtocol(addr net.Addr) (ip net.IP, protocol resolver.RequestProtocol) {
	switch a := addr.(type) {
	case *net.UDPAddr:
		return a.IP, resolver.UDP
	case *net.TCPAddr:
		return a.IP, resolver.TCP
	default:
		return nil, resolver.TCP
	}
}