mirror of https://github.com/0xERR0R/blocky.git
fix: use proxy env vars via Go default HTTP Transport values
Don't build `http.Transport` instances from scratch, but start from `http.DefaultTransport` and override what is needed.
This commit is contained in:
parent
5040ed8216
commit
d5b6ee93b5
|
@ -0,0 +1,71 @@
|
|||
package helpertest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/onsi/ginkgo/v2"
|
||||
)
|
||||
|
||||
type HTTPProxy struct {
|
||||
Addr net.Addr
|
||||
requestTarget atomic.Value // string: HTTP Host of latest request
|
||||
}
|
||||
|
||||
// TestHTTPProxy returns a new HTTPProxy server.
|
||||
//
|
||||
// All requests return http.StatusNotImplemented.
|
||||
func TestHTTPProxy() *HTTPProxy {
|
||||
proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
if err != nil {
|
||||
ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err))
|
||||
}
|
||||
|
||||
proxy := &HTTPProxy{
|
||||
Addr: proxyListener.Addr(),
|
||||
}
|
||||
|
||||
proxySrv := http.Server{ //nolint:gosec
|
||||
Addr: "127.0.0.1:0",
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
go func() { _ = proxySrv.Serve(proxyListener) }()
|
||||
ginkgo.DeferCleanup(proxySrv.Close)
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
// URL returns the proxy's URL for use by clients.
|
||||
func (p *HTTPProxy) URL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: p.Addr.String(),
|
||||
}
|
||||
}
|
||||
|
||||
// Check ReqURL has the right type signature for http.Transport.Proxy
|
||||
var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL}
|
||||
|
||||
func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) {
|
||||
return p.URL(), nil
|
||||
}
|
||||
|
||||
// RequestTarget returns the target of the last request.
|
||||
func (p *HTTPProxy) RequestTarget() string {
|
||||
val := p.requestTarget.Load()
|
||||
if val == nil {
|
||||
ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr))
|
||||
}
|
||||
|
||||
return val.(string)
|
||||
}
|
||||
|
||||
func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
p.requestTarget.Store(req.Host)
|
||||
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
}
|
|
@ -54,7 +54,7 @@ var _ = Describe("Downloader", func() {
|
|||
|
||||
Describe("NewDownloader", func() {
|
||||
It("Should use provided parameters", func() {
|
||||
transport := &http.Transport{}
|
||||
transport := new(http.Transport)
|
||||
|
||||
sut = NewDownloader(
|
||||
config.Downloader{
|
||||
|
@ -96,6 +96,7 @@ var _ = Describe("Downloader", func() {
|
|||
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
DeferCleanup(server.Close)
|
||||
|
||||
sutConfig.Attempts = 3
|
||||
})
|
||||
|
@ -212,5 +213,17 @@ var _ = Describe("Downloader", func() {
|
|||
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
|
||||
})
|
||||
})
|
||||
When("a proxy is configured", func() {
|
||||
It("should be used", func(ctx context.Context) {
|
||||
proxy := TestHTTPProxy()
|
||||
|
||||
sut.client.Transport = &http.Transport{Proxy: proxy.ReqURL}
|
||||
|
||||
_, err := sut.DownloadFile(ctx, "http://example.com")
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
||||
Expect(proxy.RequestTarget()).Should(Equal("example.com"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -156,18 +156,17 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string
|
|||
|
||||
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
|
||||
func (b *Bootstrap) NewHTTPTransport() *http.Transport {
|
||||
if b.resolver == nil {
|
||||
return &http.Transport{
|
||||
DialContext: b.dialer.DialContext,
|
||||
}
|
||||
}
|
||||
transport := util.DefaultHTTPTransport()
|
||||
transport.DialContext = b.dialContext
|
||||
|
||||
return &http.Transport{
|
||||
DialContext: b.dialContext,
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if b.resolver == nil {
|
||||
return b.dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/0xERR0R/blocky/config"
|
||||
|
@ -77,10 +78,15 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
|
|||
})
|
||||
|
||||
Describe("HTTP transport", func() {
|
||||
It("should use the system resolver", func() {
|
||||
It("should use Go default values", func() {
|
||||
transport := sut.NewHTTPTransport()
|
||||
|
||||
Expect(transport).ShouldNot(BeNil())
|
||||
|
||||
Expect(
|
||||
reflect.ValueOf(transport.Proxy).Pointer(),
|
||||
).Should(Equal(
|
||||
reflect.ValueOf(http.ProxyFromEnvironment).Pointer(),
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -98,13 +98,13 @@ func createUpstreamClient(cfg upstreamConfig) upstreamClient {
|
|||
|
||||
switch cfg.Net {
|
||||
case config.NetProtocolHttps:
|
||||
transport := util.DefaultHTTPTransport()
|
||||
transport.TLSClientConfig = &tlsConfig
|
||||
|
||||
return &httpUpstreamClient{
|
||||
userAgent: cfg.UserAgent,
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tlsConfig,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
Transport: transport,
|
||||
},
|
||||
host: cfg.Host,
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@ package resolver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -195,7 +195,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
})
|
||||
})
|
||||
|
||||
Describe("Using Dns over HTTP (DOH) upstream", func() {
|
||||
Describe("Using DNS over HTTPS (DoH) upstream", func() {
|
||||
var (
|
||||
respFn func(request *dns.Msg) (response *dns.Msg)
|
||||
modifyHTTPRespFn func(w http.ResponseWriter)
|
||||
|
@ -211,18 +211,34 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
}
|
||||
})
|
||||
|
||||
transport := func() *http.Transport {
|
||||
upstreamClient := sut.upstreamClient.(*httpUpstreamClient)
|
||||
|
||||
return upstreamClient.client.Transport.(*http.Transport)
|
||||
}
|
||||
|
||||
JustBeforeEach(func() {
|
||||
sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn)
|
||||
sut = newUpstreamResolverUnchecked(sutConfig, nil)
|
||||
|
||||
// use insecure certificates for test doh upstream
|
||||
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
// use insecure certificates for test DoH upstream
|
||||
transport().TLSClientConfig.InsecureSkipVerify = true
|
||||
})
|
||||
When("Configured DOH resolver can resolve query", func() {
|
||||
|
||||
When("a proxy is configured", func() {
|
||||
It("should use it", func() {
|
||||
proxy := TestHTTPProxy()
|
||||
|
||||
transport().Proxy = proxy.ReqURL
|
||||
|
||||
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
|
||||
Expect(err).Should(HaveOccurred())
|
||||
|
||||
upstreamHostPort := net.JoinHostPort(sutConfig.Upstream.Host, fmt.Sprint(sutConfig.Port))
|
||||
Expect(proxy.RequestTarget()).Should(Equal(upstreamHostPort))
|
||||
})
|
||||
})
|
||||
When("Configured DoH resolver can resolve query", func() {
|
||||
It("should return answer from DNS upstream", func() {
|
||||
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
|
||||
Should(
|
||||
|
@ -235,7 +251,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
))
|
||||
})
|
||||
})
|
||||
When("Configured DOH resolver returns wrong http status code", func() {
|
||||
When("Configured DoH resolver returns wrong http status code", func() {
|
||||
BeforeEach(func() {
|
||||
modifyHTTPRespFn = func(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -247,7 +263,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
|
||||
})
|
||||
})
|
||||
When("Configured DOH resolver returns wrong content type", func() {
|
||||
When("Configured DoH resolver returns wrong content type", func() {
|
||||
BeforeEach(func() {
|
||||
modifyHTTPRespFn = func(w http.ResponseWriter) {
|
||||
w.Header().Set("content-type", "text")
|
||||
|
@ -260,7 +276,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
|
||||
})
|
||||
})
|
||||
When("Configured DOH resolver returns wrong content", func() {
|
||||
When("Configured DoH resolver returns wrong content", func() {
|
||||
BeforeEach(func() {
|
||||
modifyHTTPRespFn = func(w http.ResponseWriter) {
|
||||
_, _ = w.Write([]byte("wrongcontent"))
|
||||
|
@ -272,7 +288,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
|
|||
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
|
||||
})
|
||||
})
|
||||
When("Configured DOH resolver does not respond", func() {
|
||||
When("Configured DoH resolver does not respond", func() {
|
||||
JustBeforeEach(func() {
|
||||
sutConfig.Upstream = config.Upstream{
|
||||
Net: config.NetProtocolHttps,
|
||||
|
|
46
util/http.go
46
util/http.go
|
@ -1,10 +1,56 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var baseTransport *http.Transport
|
||||
|
||||
//nolint:gochecknoinits
|
||||
func init() {
|
||||
base, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
panic(fmt.Errorf(
|
||||
"unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T",
|
||||
http.DefaultTransport,
|
||||
))
|
||||
}
|
||||
|
||||
baseTransport = base
|
||||
}
|
||||
|
||||
// DefaultHTTPTransport returns a new Transport with the same defaults as net/http.
|
||||
func DefaultHTTPTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
Dial: baseTransport.Dial, //nolint:staticcheck
|
||||
DialContext: baseTransport.DialContext,
|
||||
DialTLS: baseTransport.DialTLS, //nolint:staticcheck
|
||||
DialTLSContext: baseTransport.DialTLSContext,
|
||||
DisableCompression: baseTransport.DisableCompression,
|
||||
DisableKeepAlives: baseTransport.DisableKeepAlives,
|
||||
ExpectContinueTimeout: baseTransport.ExpectContinueTimeout,
|
||||
ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2,
|
||||
GetProxyConnectHeader: baseTransport.GetProxyConnectHeader,
|
||||
IdleConnTimeout: baseTransport.IdleConnTimeout,
|
||||
MaxConnsPerHost: baseTransport.MaxConnsPerHost,
|
||||
MaxIdleConns: baseTransport.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: baseTransport.MaxConnsPerHost,
|
||||
MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes,
|
||||
OnProxyConnectResponse: baseTransport.OnProxyConnectResponse,
|
||||
Proxy: baseTransport.Proxy,
|
||||
ProxyConnectHeader: baseTransport.ProxyConnectHeader,
|
||||
ReadBufferSize: baseTransport.ReadBufferSize,
|
||||
ResponseHeaderTimeout: baseTransport.ResponseHeaderTimeout,
|
||||
TLSClientConfig: baseTransport.TLSClientConfig,
|
||||
TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout,
|
||||
TLSNextProto: baseTransport.TLSNextProto,
|
||||
WriteBufferSize: baseTransport.WriteBufferSize,
|
||||
}
|
||||
}
|
||||
|
||||
func HTTPClientIP(r *http.Request) net.IP {
|
||||
addr := r.Header.Get("X-FORWARDED-FOR")
|
||||
if addr == "" {
|
||||
|
|
|
@ -1,14 +1,39 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("HTTP Util", func() {
|
||||
Describe("DefaultHTTPTransport", func() {
|
||||
It("returns a new transport", func() {
|
||||
a := DefaultHTTPTransport()
|
||||
Expect(a).Should(BeIdenticalTo(a))
|
||||
|
||||
b := DefaultHTTPTransport()
|
||||
Expect(a).ShouldNot(BeIdenticalTo(b))
|
||||
})
|
||||
|
||||
It("returns a copy of http.DefaultTransport", func() {
|
||||
Expect(cmp.Diff(
|
||||
DefaultHTTPTransport(), http.DefaultTransport,
|
||||
cmpopts.IgnoreUnexported(http.Transport{}),
|
||||
// Non nil func field comparers
|
||||
cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]),
|
||||
cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]),
|
||||
)).Should(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("HTTPClientIP", func() {
|
||||
It("extracts the IP from RemoteAddr", func() {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
@ -43,3 +68,10 @@ var _ = Describe("HTTP Util", func() {
|
|||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Go and cmp don't define func comparisons, besides with nil.
|
||||
// In practice we can just compare them as pointers.
|
||||
// See https://github.com/google/go-cmp/issues/162
|
||||
func cmpAsPtrs[T any](x, y T) bool {
|
||||
return reflect.ValueOf(x).Pointer() == reflect.ValueOf(y).Pointer()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue