fix: use proxy env vars via Go default HTTP Transport values

Don't build `http.Transport` instances from scratch, but start from
`http.DefaultTransport` and override what is needed.
This commit is contained in:
ThinkChaos 2024-03-23 15:45:08 -04:00
parent 5040ed8216
commit d5b6ee93b5
8 changed files with 211 additions and 28 deletions

71
helpertest/http.go Normal file
View File

@ -0,0 +1,71 @@
package helpertest
import (
"fmt"
"net"
"net/http"
"net/url"
"sync/atomic"
"github.com/onsi/ginkgo/v2"
)
type HTTPProxy struct {
Addr net.Addr
requestTarget atomic.Value // string: HTTP Host of latest request
}
// TestHTTPProxy returns a new HTTPProxy server.
//
// All requests return http.StatusNotImplemented.
func TestHTTPProxy() *HTTPProxy {
proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err))
}
proxy := &HTTPProxy{
Addr: proxyListener.Addr(),
}
proxySrv := http.Server{ //nolint:gosec
Addr: "127.0.0.1:0",
Handler: proxy,
}
go func() { _ = proxySrv.Serve(proxyListener) }()
ginkgo.DeferCleanup(proxySrv.Close)
return proxy
}
// URL returns the proxy's URL for use by clients.
func (p *HTTPProxy) URL() *url.URL {
return &url.URL{
Scheme: "http",
Host: p.Addr.String(),
}
}
// Check ReqURL has the right type signature for http.Transport.Proxy
var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL}
func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) {
return p.URL(), nil
}
// RequestTarget returns the target of the last request.
func (p *HTTPProxy) RequestTarget() string {
val := p.requestTarget.Load()
if val == nil {
ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr))
}
return val.(string)
}
func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
p.requestTarget.Store(req.Host)
w.WriteHeader(http.StatusNotImplemented)
}

View File

@ -54,7 +54,7 @@ var _ = Describe("Downloader", func() {
Describe("NewDownloader", func() {
It("Should use provided parameters", func() {
transport := &http.Transport{}
transport := new(http.Transport)
sut = NewDownloader(
config.Downloader{
@ -96,6 +96,7 @@ var _ = Describe("Downloader", func() {
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusNotFound)
}))
DeferCleanup(server.Close)
sutConfig.Attempts = 3
})
@ -212,5 +213,17 @@ var _ = Describe("Downloader", func() {
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
})
})
When("a proxy is configured", func() {
It("should be used", func(ctx context.Context) {
proxy := TestHTTPProxy()
sut.client.Transport = &http.Transport{Proxy: proxy.ReqURL}
_, err := sut.DownloadFile(ctx, "http://example.com")
Expect(err).Should(HaveOccurred())
Expect(proxy.RequestTarget()).Should(Equal("example.com"))
})
})
})
})

View File

@ -156,18 +156,17 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string
// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
func (b *Bootstrap) NewHTTPTransport() *http.Transport {
if b.resolver == nil {
return &http.Transport{
DialContext: b.dialer.DialContext,
}
}
transport := util.DefaultHTTPTransport()
transport.DialContext = b.dialContext
return &http.Transport{
DialContext: b.dialContext,
}
return transport
}
func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if b.resolver == nil {
return b.dialer.DialContext(ctx, network, addr)
}
ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})
host, port, err := net.SplitHostPort(addr)

View File

@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync/atomic"
"github.com/0xERR0R/blocky/config"
@ -77,10 +78,15 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
})
Describe("HTTP transport", func() {
It("should use the system resolver", func() {
It("should use Go default values", func() {
transport := sut.NewHTTPTransport()
Expect(transport).ShouldNot(BeNil())
Expect(
reflect.ValueOf(transport.Proxy).Pointer(),
).Should(Equal(
reflect.ValueOf(http.ProxyFromEnvironment).Pointer(),
))
})
})

View File

@ -98,13 +98,13 @@ func createUpstreamClient(cfg upstreamConfig) upstreamClient {
switch cfg.Net {
case config.NetProtocolHttps:
transport := util.DefaultHTTPTransport()
transport.TLSClientConfig = &tlsConfig
return &httpUpstreamClient{
userAgent: cfg.UserAgent,
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
ForceAttemptHTTP2: true,
},
Transport: transport,
},
host: cfg.Host,
}

View File

@ -2,9 +2,9 @@ package resolver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync/atomic"
"time"
@ -195,7 +195,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
})
Describe("Using Dns over HTTP (DOH) upstream", func() {
Describe("Using DNS over HTTPS (DoH) upstream", func() {
var (
respFn func(request *dns.Msg) (response *dns.Msg)
modifyHTTPRespFn func(w http.ResponseWriter)
@ -211,18 +211,34 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})
transport := func() *http.Transport {
upstreamClient := sut.upstreamClient.(*httpUpstreamClient)
return upstreamClient.client.Transport.(*http.Transport)
}
JustBeforeEach(func() {
sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn)
sut = newUpstreamResolverUnchecked(sutConfig, nil)
// use insecure certificates for test doh upstream
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
// use insecure certificates for test DoH upstream
transport().TLSClientConfig.InsecureSkipVerify = true
})
When("Configured DOH resolver can resolve query", func() {
When("a proxy is configured", func() {
It("should use it", func() {
proxy := TestHTTPProxy()
transport().Proxy = proxy.ReqURL
_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())
upstreamHostPort := net.JoinHostPort(sutConfig.Upstream.Host, fmt.Sprint(sutConfig.Port))
Expect(proxy.RequestTarget()).Should(Equal(upstreamHostPort))
})
})
When("Configured DoH resolver can resolve query", func() {
It("should return answer from DNS upstream", func() {
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
@ -235,7 +251,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
))
})
})
When("Configured DOH resolver returns wrong http status code", func() {
When("Configured DoH resolver returns wrong http status code", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError)
@ -247,7 +263,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
})
})
When("Configured DOH resolver returns wrong content type", func() {
When("Configured DoH resolver returns wrong content type", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
w.Header().Set("content-type", "text")
@ -260,7 +276,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
})
})
When("Configured DOH resolver returns wrong content", func() {
When("Configured DoH resolver returns wrong content", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
_, _ = w.Write([]byte("wrongcontent"))
@ -272,7 +288,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
})
})
When("Configured DOH resolver does not respond", func() {
When("Configured DoH resolver does not respond", func() {
JustBeforeEach(func() {
sutConfig.Upstream = config.Upstream{
Net: config.NetProtocolHttps,

View File

@ -1,10 +1,56 @@
package util
import (
"fmt"
"net"
"net/http"
)
//nolint:gochecknoglobals
var baseTransport *http.Transport
//nolint:gochecknoinits
func init() {
base, ok := http.DefaultTransport.(*http.Transport)
if !ok {
panic(fmt.Errorf(
"unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T",
http.DefaultTransport,
))
}
baseTransport = base
}
// DefaultHTTPTransport returns a new Transport with the same defaults as net/http.
func DefaultHTTPTransport() *http.Transport {
return &http.Transport{
Dial: baseTransport.Dial, //nolint:staticcheck
DialContext: baseTransport.DialContext,
DialTLS: baseTransport.DialTLS, //nolint:staticcheck
DialTLSContext: baseTransport.DialTLSContext,
DisableCompression: baseTransport.DisableCompression,
DisableKeepAlives: baseTransport.DisableKeepAlives,
ExpectContinueTimeout: baseTransport.ExpectContinueTimeout,
ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2,
GetProxyConnectHeader: baseTransport.GetProxyConnectHeader,
IdleConnTimeout: baseTransport.IdleConnTimeout,
MaxConnsPerHost: baseTransport.MaxConnsPerHost,
MaxIdleConns: baseTransport.MaxIdleConns,
MaxIdleConnsPerHost: baseTransport.MaxConnsPerHost,
MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes,
OnProxyConnectResponse: baseTransport.OnProxyConnectResponse,
Proxy: baseTransport.Proxy,
ProxyConnectHeader: baseTransport.ProxyConnectHeader,
ReadBufferSize: baseTransport.ReadBufferSize,
ResponseHeaderTimeout: baseTransport.ResponseHeaderTimeout,
TLSClientConfig: baseTransport.TLSClientConfig,
TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout,
TLSNextProto: baseTransport.TLSNextProto,
WriteBufferSize: baseTransport.WriteBufferSize,
}
}
func HTTPClientIP(r *http.Request) net.IP {
addr := r.Header.Get("X-FORWARDED-FOR")
if addr == "" {

View File

@ -1,14 +1,39 @@
package util
import (
"context"
"net"
"net/http"
"net/url"
"reflect"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("HTTP Util", func() {
Describe("DefaultHTTPTransport", func() {
It("returns a new transport", func() {
a := DefaultHTTPTransport()
Expect(a).Should(BeIdenticalTo(a))
b := DefaultHTTPTransport()
Expect(a).ShouldNot(BeIdenticalTo(b))
})
It("returns a copy of http.DefaultTransport", func() {
Expect(cmp.Diff(
DefaultHTTPTransport(), http.DefaultTransport,
cmpopts.IgnoreUnexported(http.Transport{}),
// Non nil func field comparers
cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]),
cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]),
)).Should(BeEmpty())
})
})
Describe("HTTPClientIP", func() {
It("extracts the IP from RemoteAddr", func() {
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
@ -43,3 +68,10 @@ var _ = Describe("HTTP Util", func() {
})
})
})
// Go and cmp don't define func comparisons, besides with nil.
// In practice we can just compare them as pointers.
// See https://github.com/google/go-cmp/issues/162
func cmpAsPtrs[T any](x, y T) bool {
return reflect.ValueOf(x).Pointer() == reflect.ValueOf(y).Pointer()
}