mirror of https://github.com/0xERR0R/blocky.git
refactor(server): add `resolve` for common query code
Ensure all queries go through that common code path so we always enable compression, truncate if required, etc.
This commit is contained in:
parent
e9a1e8974d
commit
f0ad412d8d
|
@ -417,7 +417,9 @@ func createQueryResolver(
|
||||||
|
|
||||||
func (s *Server) registerDNSHandlers(ctx context.Context) {
|
func (s *Server) registerDNSHandlers(ctx context.Context) {
|
||||||
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
|
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
|
||||||
s.OnRequest(ctx, w, request)
|
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())
|
||||||
|
|
||||||
|
s.OnRequest(ctx, w, ip, proto, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, server := range s.dnsServers {
|
for _, server := range s.dnsServers {
|
||||||
|
@ -550,25 +552,6 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request {
|
|
||||||
var hostName string
|
|
||||||
|
|
||||||
var remoteAddr net.Addr
|
|
||||||
|
|
||||||
if rw != nil {
|
|
||||||
remoteAddr = rw.RemoteAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
clientIP, protocol := resolveClientIPAndProtocol(remoteAddr)
|
|
||||||
con, ok := rw.(dns.ConnectionStater)
|
|
||||||
|
|
||||||
if ok && con.ConnectionState() != nil {
|
|
||||||
hostName = con.ConnectionState().ServerName
|
|
||||||
}
|
|
||||||
|
|
||||||
return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractClientIDFromHost(hostName string) string {
|
func extractClientIDFromHost(hostName string) string {
|
||||||
const clientIDPrefix = "id-"
|
const clientIDPrefix = "id-"
|
||||||
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
|
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
|
||||||
|
@ -578,12 +561,13 @@ func extractClientIDFromHost(hostName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequest(clientIP net.IP, protocol model.RequestProtocol,
|
func newRequest(
|
||||||
requestClientID string, request *dns.Msg,
|
clientIP net.IP, clientID string,
|
||||||
|
protocol model.RequestProtocol, request *dns.Msg,
|
||||||
) *model.Request {
|
) *model.Request {
|
||||||
return &model.Request{
|
return &model.Request{
|
||||||
ClientIP: clientIP,
|
ClientIP: clientIP,
|
||||||
RequestClientID: requestClientID,
|
RequestClientID: clientID,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
Req: request,
|
Req: request,
|
||||||
Log: log.Log().WithFields(logrus.Fields{
|
Log: log.Log().WithFields(logrus.Fields{
|
||||||
|
@ -595,13 +579,23 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol,
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRequest will be executed if a new DNS request is received
|
// OnRequest will be executed if a new DNS request is received
|
||||||
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
|
func (s *Server) OnRequest(
|
||||||
|
ctx context.Context, w dns.ResponseWriter,
|
||||||
|
clientIP net.IP, protocol model.RequestProtocol,
|
||||||
|
request *dns.Msg,
|
||||||
|
) {
|
||||||
logger().Debug("new request")
|
logger().Debug("new request")
|
||||||
|
|
||||||
r := createResolverRequest(w, request)
|
var hostName string
|
||||||
|
|
||||||
response, err := s.queryResolver.Resolve(ctx, r)
|
con, ok := w.(dns.ConnectionStater)
|
||||||
|
if ok && con.ConnectionState() != nil {
|
||||||
|
hostName = con.ConnectionState().ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)
|
||||||
|
|
||||||
|
response, err := s.resolve(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger().Error("error on processing request:", err)
|
logger().Error("error on processing request:", err)
|
||||||
|
|
||||||
|
@ -610,17 +604,40 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
|
||||||
err := w.WriteMsg(m)
|
err := w.WriteMsg(m)
|
||||||
util.LogOnError("can't write message: ", err)
|
util.LogOnError("can't write message: ", err)
|
||||||
} else {
|
} else {
|
||||||
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired
|
err := w.WriteMsg(response.Res)
|
||||||
|
util.LogOnError("can't write message: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
|
||||||
|
var response *model.Response
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(request.Req.Question) == 0:
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(request.Req, dns.RcodeFormatError)
|
||||||
|
|
||||||
|
request.Log.Error("query has no questions")
|
||||||
|
|
||||||
|
response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
|
||||||
|
default:
|
||||||
|
var err error
|
||||||
|
|
||||||
|
response, err = s.queryResolver.Resolve(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Res.MsgHdr.RecursionAvailable = request.Req.MsgHdr.RecursionDesired
|
||||||
|
|
||||||
// truncate if necessary
|
// truncate if necessary
|
||||||
response.Res.Truncate(getMaxResponseSize(r))
|
response.Res.Truncate(getMaxResponseSize(request))
|
||||||
|
|
||||||
// enable compression
|
// enable compression
|
||||||
response.Res.Compress = true
|
response.Res.Compress = true
|
||||||
|
|
||||||
err := w.WriteMsg(response.Res)
|
return response, nil
|
||||||
util.LogOnError("can't write message: ", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
|
// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
|
||||||
|
|
|
@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
|
||||||
clientID = extractClientIDFromHost(req.Host)
|
clientID = extractClientIDFromHost(req.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg)
|
r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)
|
||||||
|
|
||||||
resResponse, err := s.queryResolver.Resolve(req.Context(), r)
|
resResponse, err := s.resolve(req.Context(), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logAndResponseWithError(err, "unable to process query: ", rw)
|
logAndResponseWithError(err, "unable to process query: ", rw)
|
||||||
|
|
||||||
|
@ -173,9 +173,9 @@ func (s *Server) Query(
|
||||||
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
|
||||||
) (*model.Response, error) {
|
) (*model.Response, error) {
|
||||||
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
dnsRequest := util.NewMsgWithQuestion(question, qType)
|
||||||
r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest)
|
r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)
|
||||||
|
|
||||||
return s.queryResolver.Resolve(ctx, r)
|
return s.resolve(ctx, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
func createHTTPSRouter(cfg *config.Config) *chi.Mux {
|
||||||
|
|
Loading…
Reference in New Issue