mirror of https://github.com/0xERR0R/blocky.git
fix(rewrite): support the case where upstream doesn't echo the question
Apparently Tailscale's magic DNS does this.
This commit is contained in:
parent
c6304e9d7f
commit
dece894bd6
|
@ -95,13 +95,25 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request)
|
|||
return r.next.Resolve(ctx, request)
|
||||
}
|
||||
|
||||
// Revert the rewrite in r.inner's response
|
||||
if rewritten != nil {
|
||||
for i := range originalNames {
|
||||
response.Res.Question[i].Name = originalNames[i]
|
||||
if rewritten == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
if i < len(response.Res.Answer) {
|
||||
response.Res.Answer[i].Header().Name = originalNames[i]
|
||||
// Revert the rewrite in r.inner's response
|
||||
|
||||
n := max(len(response.Res.Question), len(response.Res.Answer))
|
||||
for i := 0; i < n; i++ {
|
||||
if i < len(response.Res.Question) {
|
||||
original, ok := originalNames[response.Res.Question[i].Name]
|
||||
if ok {
|
||||
response.Res.Question[i].Name = original
|
||||
}
|
||||
}
|
||||
|
||||
if i < len(response.Res.Answer) {
|
||||
original, ok := originalNames[response.Res.Answer[i].Header().Name]
|
||||
if ok {
|
||||
response.Res.Answer[i].Header().Name = original
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -109,22 +121,27 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request)
|
|||
return response, nil
|
||||
}
|
||||
|
||||
func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg) (rewritten *dns.Msg, originalNames []string) { //nolint: lll
|
||||
originalNames = make([]string, len(request.Question))
|
||||
func (r *RewriterResolver) rewriteRequest(
|
||||
logger *logrus.Entry, request *dns.Msg,
|
||||
) (rewritten *dns.Msg, originalNames map[string]string) {
|
||||
originalNames = make(map[string]string, len(request.Question))
|
||||
|
||||
for i := range request.Question {
|
||||
nameOriginal := request.Question[i].Name
|
||||
originalNames[i] = nameOriginal
|
||||
|
||||
domainOriginal := util.ExtractDomainOnly(nameOriginal)
|
||||
domainRewritten, rewriteKey := r.rewriteDomain(domainOriginal)
|
||||
|
||||
if domainRewritten != domainOriginal {
|
||||
rewrittenFQDN := dns.Fqdn(domainRewritten)
|
||||
|
||||
originalNames[rewrittenFQDN] = nameOriginal
|
||||
|
||||
if rewritten == nil {
|
||||
rewritten = request.Copy()
|
||||
}
|
||||
|
||||
rewritten.Question[i].Name = dns.Fqdn(domainRewritten)
|
||||
rewritten.Question[i].Name = rewrittenFQDN
|
||||
|
||||
logger.WithFields(logrus.Fields{
|
||||
"rewrite": util.Obfuscate(rewriteKey) + ":" + util.Obfuscate(r.cfg.Rewrite[rewriteKey]),
|
||||
|
|
|
@ -67,17 +67,14 @@ var _ = Describe("RewriterResolver", func() {
|
|||
})
|
||||
|
||||
When("has rewrite", func() {
|
||||
var request *model.Request
|
||||
var expectNilAnswer bool
|
||||
var (
|
||||
request *model.Request
|
||||
expectNilAnswer bool
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
expectNilAnswer = false
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))
|
||||
|
||||
mInner.On("Resolve", mock.Anything)
|
||||
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
|
||||
Expect(req).Should(Equal(request.Req))
|
||||
|
||||
|
@ -95,11 +92,19 @@ var _ = Describe("RewriterResolver", func() {
|
|||
|
||||
return res
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))
|
||||
|
||||
mInner.On("Resolve", mock.Anything)
|
||||
|
||||
resp, err := sut.Resolve(context.Background(), request)
|
||||
Expect(err).Should(Succeed())
|
||||
if resp != mNextResponse {
|
||||
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
if len(resp.Res.Question) != 0 {
|
||||
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
|
||||
}
|
||||
if expectNilAnswer {
|
||||
Expect(resp.Res.Answer).Should(BeEmpty())
|
||||
} else {
|
||||
|
@ -128,6 +133,19 @@ var _ = Describe("RewriterResolver", func() {
|
|||
fqdnRewritten = fqdnOriginal
|
||||
})
|
||||
|
||||
It("should support a reply without the question", func() {
|
||||
fqdnOriginal = sampleOriginal
|
||||
fqdnRewritten = sampleRewritten
|
||||
|
||||
origResponseFn := mInner.ResponseFn
|
||||
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
|
||||
res := origResponseFn(req)
|
||||
res.Question = nil
|
||||
|
||||
return res
|
||||
}
|
||||
})
|
||||
|
||||
It("should call next resolver", func() {
|
||||
fqdnOriginal = sampleOriginal
|
||||
fqdnRewritten = sampleRewritten
|
||||
|
|
Loading…
Reference in New Issue