From 65eeb5ec1a7fe50dca8c51e666964d7ff89827bb Mon Sep 17 00:00:00 2001 From: Deluan Date: Sun, 26 Mar 2023 13:29:57 -0400 Subject: [PATCH] Add tests for serverAddressMiddleware --- server/middlewares.go | 4 +- server/middlewares_test.go | 96 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/server/middlewares.go b/server/middlewares.go index 52ff3609..8948b706 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -164,8 +164,8 @@ func serverAddressMiddleware(h http.Handler) http.Handler { var ( xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") - xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Scheme") - xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Proto") + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") + xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme") ) func serverAddress(r *http.Request) (scheme, host string) { diff --git a/server/middlewares_test.go b/server/middlewares_test.go index a1ea59dc..946ac613 100644 --- a/server/middlewares_test.go +++ b/server/middlewares_test.go @@ -48,4 +48,100 @@ var _ = Describe("middlewares", func() { Expect(nextCalled).To(BeTrue()) }) }) + + Describe("serverAddressMiddleware", func() { + var ( + nextHandler http.Handler + middleware http.Handler + recorder *httptest.ResponseRecorder + req *http.Request + ) + + BeforeEach(func() { + nextHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + middleware = serverAddressMiddleware(nextHandler) + recorder = httptest.NewRecorder() + }) + + Context("with no X-Forwarded headers", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + }) + + It("should not modify the request", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("example.com")) + Expect(req.URL.Scheme).To(Equal("http")) + }) + }) + + Context("with X-Forwarded-Host header", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("X-Forwarded-Host", "forwarded.example.com") + }) + + It("should modify the request with the X-Forwarded-Host header value", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("forwarded.example.com")) + Expect(req.URL.Scheme).To(Equal("http")) + }) + }) + + Context("with X-Forwarded-Proto header", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("X-Forwarded-Proto", "https") + }) + + It("should modify the request with the X-Forwarded-Proto header value", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("example.com")) + Expect(req.URL.Scheme).To(Equal("https")) + }) + }) + + Context("with X-Forwarded-Scheme header", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("X-Forwarded-Scheme", "https") + }) + + It("should modify the request with the X-Forwarded-Scheme header value", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("example.com")) + Expect(req.URL.Scheme).To(Equal("https")) + }) + }) + + Context("with multiple X-Forwarded headers", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("X-Forwarded-Host", "forwarded.example.com") + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Scheme", "http") + }) + + It("should modify the request with the first non-empty X-Forwarded header value", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("forwarded.example.com")) + Expect(req.URL.Scheme).To(Equal("https")) + }) + }) + + Context("with multiple values in X-Forwarded-Host header", func() { + BeforeEach(func() { + req, _ = http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("X-Forwarded-Host", "forwarded1.example.com, forwarded2.example.com") + }) + + It("should modify the request with the first value in X-Forwarded-Host header", func() { + middleware.ServeHTTP(recorder, req) + Expect(req.Host).To(Equal("forwarded1.example.com")) + Expect(req.URL.Scheme).To(Equal("http")) + }) + }) + }) })