refactor(server): add `resolve` for common query code

Ensure all queries go through that common code path so we always enable
compression, truncate if required, etc.
This commit is contained in:
ThinkChaos 2023-12-19 19:40:34 -05:00
parent e9a1e8974d
commit f0ad412d8d
2 changed files with 55 additions and 38 deletions

View File

@ -417,7 +417,9 @@ func createQueryResolver(
func (s *Server) registerDNSHandlers(ctx context.Context) {
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
s.OnRequest(ctx, w, request)
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())
s.OnRequest(ctx, w, ip, proto, request)
}
for _, server := range s.dnsServers {
@ -550,25 +552,6 @@ func (s *Server) Stop(ctx context.Context) error {
return nil
}
func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request {
var hostName string
var remoteAddr net.Addr
if rw != nil {
remoteAddr = rw.RemoteAddr()
}
clientIP, protocol := resolveClientIPAndProtocol(remoteAddr)
con, ok := rw.(dns.ConnectionStater)
if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}
return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request)
}
func extractClientIDFromHost(hostName string) string {
const clientIDPrefix = "id-"
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
@ -578,12 +561,13 @@ func extractClientIDFromHost(hostName string) string {
return ""
}
func newRequest(clientIP net.IP, protocol model.RequestProtocol,
requestClientID string, request *dns.Msg,
func newRequest(
clientIP net.IP, clientID string,
protocol model.RequestProtocol, request *dns.Msg,
) *model.Request {
return &model.Request{
ClientIP: clientIP,
RequestClientID: requestClientID,
RequestClientID: clientID,
Protocol: protocol,
Req: request,
Log: log.Log().WithFields(logrus.Fields{
@ -595,13 +579,23 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol,
}
// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
func (s *Server) OnRequest(
ctx context.Context, w dns.ResponseWriter,
clientIP net.IP, protocol model.RequestProtocol,
request *dns.Msg,
) {
logger().Debug("new request")
r := createResolverRequest(w, request)
var hostName string
response, err := s.queryResolver.Resolve(ctx, r)
con, ok := w.(dns.ConnectionStater)
if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}
req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)
response, err := s.resolve(ctx, req)
if err != nil {
logger().Error("error on processing request:", err)
@ -610,19 +604,42 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
err := w.WriteMsg(m)
util.LogOnError("can't write message: ", err)
} else {
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired
// truncate if necessary
response.Res.Truncate(getMaxResponseSize(r))
// enable compression
response.Res.Compress = true
err := w.WriteMsg(response.Res)
util.LogOnError("can't write message: ", err)
}
}
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
var response *model.Response
switch {
case len(request.Req.Question) == 0:
m := new(dns.Msg)
m.SetRcode(request.Req, dns.RcodeFormatError)
request.Log.Error("query has no questions")
response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
default:
var err error
response, err = s.queryResolver.Resolve(ctx, request)
if err != nil {
return nil, err
}
}
response.Res.MsgHdr.RecursionAvailable = request.Req.MsgHdr.RecursionDesired
// truncate if necessary
response.Res.Truncate(getMaxResponseSize(request))
// enable compression
response.Res.Compress = true
return response, nil
}
// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
func getMaxResponseSize(req *model.Request) int {
edns := req.Req.IsEdns0()

View File

@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
clientID = extractClientIDFromHost(req.Host)
}
r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg)
r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
resResponse, err := s.resolve(req.Context(), r)
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)
@ -173,9 +173,9 @@ func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest)
r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
return s.queryResolver.Resolve(ctx, r)
return s.resolve(ctx, r)
}
func createHTTPSRouter(cfg *config.Config) *chi.Mux {