mirror of https://github.com/0xERR0R/blocky.git
refactor(server): add `resolve` for common query code
Ensure all queries go through that common code path so we always enable compression, truncate if required, etc.
This commit is contained in:
parent
e9a1e8974d
commit
f0ad412d8d
|
@ -417,7 +417,9 @@ func createQueryResolver(
|
|||
|
||||
func (s *Server) registerDNSHandlers(ctx context.Context) {
|
||||
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
|
||||
s.OnRequest(ctx, w, request)
|
||||
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())
|
||||
|
||||
s.OnRequest(ctx, w, ip, proto, request)
|
||||
}
|
||||
|
||||
for _, server := range s.dnsServers {
|
||||
|
@ -550,25 +552,6 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request {
|
||||
var hostName string
|
||||
|
||||
var remoteAddr net.Addr
|
||||
|
||||
if rw != nil {
|
||||
remoteAddr = rw.RemoteAddr()
|
||||
}
|
||||
|
||||
clientIP, protocol := resolveClientIPAndProtocol(remoteAddr)
|
||||
con, ok := rw.(dns.ConnectionStater)
|
||||
|
||||
if ok && con.ConnectionState() != nil {
|
||||
hostName = con.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request)
|
||||
}
|
||||
|
||||
func extractClientIDFromHost(hostName string) string {
|
||||
const clientIDPrefix = "id-"
|
||||
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
|
||||
|
@ -578,12 +561,13 @@ func extractClientIDFromHost(hostName string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func newRequest(clientIP net.IP, protocol model.RequestProtocol,
|
||||
requestClientID string, request *dns.Msg,
|
||||
func newRequest(
|
||||
clientIP net.IP, clientID string,
|
||||
protocol model.RequestProtocol, request *dns.Msg,
|
||||
) *model.Request {
|
||||
return &model.Request{
|
||||
ClientIP: clientIP,
|
||||
RequestClientID: requestClientID,
|
||||
RequestClientID: clientID,
|
||||
Protocol: protocol,
|
||||
Req: request,
|
||||
Log: log.Log().WithFields(logrus.Fields{
|
||||
|
@ -595,13 +579,23 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol,
|
|||
}
|
||||
|
||||
// OnRequest will be executed if a new DNS request is received
|
||||
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
|
||||
func (s *Server) OnRequest(
|
||||
ctx context.Context, w dns.ResponseWriter,
|
||||
clientIP net.IP, protocol model.RequestProtocol,
|
||||
request *dns.Msg,
|
||||
) {
|
||||
logger().Debug("new request")
|
||||
|
||||
r := createResolverRequest(w, request)
|
||||
var hostName string
|
||||
|
||||
response, err := s.queryResolver.Resolve(ctx, r)
|
||||
con, ok := w.(dns.ConnectionStater)
|
||||
if ok && con.ConnectionState() != nil {
|
||||
hostName = con.ConnectionState().ServerName
|
||||
}
|
||||
|
||||
req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)
|
||||
|
||||
response, err := s.resolve(ctx, req)
|
||||
if err != nil {
|
||||
logger().Error("error on processing request:", err)
|
||||
|
||||
|
@ -610,19 +604,42 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
|
|||
err := w.WriteMsg(m)
|
||||
util.LogOnError("can't write message: ", err)
|
||||
} else {
|
||||
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired
|
||||
|
||||
// truncate if necessary
|
||||
response.Res.Truncate(getMaxResponseSize(r))
|
||||
|
||||
// enable compression
|
||||
response.Res.Compress = true
|
||||
|
||||
err := w.WriteMsg(response.Res)
|
||||
util.LogOnError("can't write message: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||
var response *model.Response
|
||||
|
||||
switch {
|
||||
case len(request.Req.Question) == 0:
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(request.Req, dns.RcodeFormatError)
|
||||
|
||||
request.Log.Error("query has no questions")
|
||||
|
||||
response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||
default:
|
||||
var err error
|
||||
|
||||
response, err = s.queryResolver.Resolve(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
response.Res.MsgHdr.RecursionAvailable = request.Req.MsgHdr.RecursionDesired
|
||||
|
||||
// truncate if necessary
|
||||
response.Res.Truncate(getMaxResponseSize(request))
|
||||
|
||||
// enable compression
|
||||
response.Res.Compress = true
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
|
||||
func getMaxResponseSize(req *model.Request) int {
|
||||
edns := req.Req.IsEdns0()
|
||||
|
|
|
@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
|||
clientID = extractClientIDFromHost(req.Host)
|
||||
}
|
||||
|
||||
r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg)
|
||||
r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
|
||||
|
||||
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
|
||||
resResponse, err := s.resolve(req.Context(), r)
|
||||
if err != nil {
|
||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||
|
||||
|
@ -173,9 +173,9 @@ func (s *Server) Query(
|
|||
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
||||
) (*model.Response, error) {
|
||||
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
||||
r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest)
|
||||
r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
|
||||
|
||||
return s.queryResolver.Resolve(ctx, r)
|
||||
return s.resolve(ctx, r)
|
||||
}
|
||||
|
||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
||||
|
|
Loading…
Reference in New Issue