navidrome/server/middlewares.go

229 lines
5.5 KiB
Go

package server
import (
"errors"
"fmt"
"io/fs"
"net/http"
"net/url"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
"github.com/navidrome/navidrome/conf"
. "github.com/navidrome/navidrome/utils/gg"
"github.com/go-chi/chi/v5/middleware"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model/request"
"github.com/unrolled/secure"
)
// requestLogger logs every request after the wrapped handler has returned,
// recording method, URL, remote address, elapsed time, response status and
// response size. The log call is chosen from the status: log.Error for 5xx,
// log.Warn for 4xx and log.Debug for everything else. The user agent is added
// only when log.CurrentLevel() is at least log.LevelDebug.
func requestLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proto := "http"
		if r.TLS != nil {
			proto = "https"
		}

		began := time.Now()
		rw := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		next.ServeHTTP(rw, r)
		code := rw.Status()

		// r.Host is read only after next has run, so an in-place rewrite made
		// further down the chain shows up in the logged URL.
		args := []any{
			r.Context(),
			fmt.Sprintf("HTTP: %s %s://%s%s", r.Method, proto, r.Host, r.RequestURI),
			"remoteAddr", r.RemoteAddr,
			"elapsedTime", time.Since(began),
			"httpStatus", code,
			"responseSize", rw.BytesWritten(),
		}
		if log.CurrentLevel() >= log.LevelDebug {
			args = append(args, "userAgent", r.UserAgent())
		}

		switch {
		case code < 400:
			log.Debug(args...)
		case code < 500:
			log.Warn(args...)
		default:
			log.Error(args...)
		}
	})
}
func loggerInjector(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = log.NewContext(r.Context(), "requestId", middleware.GetReqID(ctx))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// robotsTXT returns a middleware that answers any request whose path ends in
// "/robots.txt", whatever prefix precedes it, with the robots.txt file at the
// root of fsys. Every other request is passed on to next unchanged.
func robotsTXT(fsys fs.FS) func(http.Handler) http.Handler {
	// Build the file server once rather than on every matching request.
	fileServer := http.FileServer(http.FS(fsys))
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !strings.HasSuffix(r.URL.Path, "/robots.txt") {
				next.ServeHTTP(w, r)
				return
			}
			// Rewrite the path on a shallow copy (the same technique as
			// http.StripPrefix) so the caller's *http.Request is not mutated.
			r2 := new(http.Request)
			*r2 = *r
			r2.URL = new(url.URL)
			*r2.URL = *r.URL
			r2.URL.Path = "/robots.txt"
			r2.URL.RawPath = ""
			fileServer.ServeHTTP(w, r2)
		})
	}
}
// corsHandler returns a CORS middleware that accepts any origin and any
// request header for the HEAD, GET, POST, PUT, PATCH and DELETE methods. It
// does not allow credentials and exposes the x-content-duration,
// x-total-count and x-nd-authorization response headers to browsers.
func corsHandler() func(http.Handler) http.Handler {
	methods := []string{
		http.MethodHead,
		http.MethodGet,
		http.MethodPost,
		http.MethodPut,
		http.MethodPatch,
		http.MethodDelete,
	}
	opts := cors.Options{
		AllowedOrigins:   []string{"*"},
		AllowedHeaders:   []string{"*"},
		AllowedMethods:   methods,
		ExposedHeaders:   []string{"x-content-duration", "x-total-count", "x-nd-authorization"},
		AllowCredentials: false,
	}
	return cors.Handler(opts)
}
// secureMiddleware returns a middleware that adds security-related response
// headers: nosniff, frame denial, a same-origin referrer policy and a
// permissions policy that disables autoplay, camera, microphone and usb.
func secureMiddleware() func(http.Handler) http.Handler {
	opts := secure.Options{
		ContentTypeNosniff: true,
		FrameDeny:          true,
		ReferrerPolicy:     "same-origin",
		PermissionsPolicy:  "autoplay=(), camera=(), microphone=(), usb=()",
		// A Content-Security-Policy is intentionally not set; previously tried:
		// ContentSecurityPolicy: "script-src 'self' 'unsafe-inline'",
	}
	return secure.New(opts).Handler
}
// compressMiddleware returns chi's response-compression middleware at level 5,
// restricted to XML, JSON, JavaScript, HTML, plain-text and CSS content types.
func compressMiddleware() func(http.Handler) http.Handler {
	const level = 5
	compressible := []string{
		"application/xml",
		"application/json",
		"application/javascript",
		"text/html",
		"text/plain",
		"text/css",
		"text/javascript",
	}
	return middleware.Compress(level, compressible...)
}
// clientUniqueIDMiddleware propagates the UI client's unique ID.
//
// When the request carries the consts.UIClientUniqueIDHeader header, its value
// is stored in an HttpOnly, Secure, SameSite=Strict cookie of the same name,
// scoped to conf.Server.BasePath (or "/" when that is empty). Otherwise the ID
// is read back from that cookie, if present. A non-empty ID is then attached to
// the request context with request.WithClientUniqueId.
func clientUniqueIDMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get(consts.UIClientUniqueIDHeader)
		switch {
		case id != "":
			http.SetCookie(w, &http.Cookie{
				Name:     consts.UIClientUniqueIDHeader,
				Value:    id,
				Path:     IfZero(conf.Server.BasePath, "/"),
				MaxAge:   consts.CookieExpiry,
				Secure:   true,
				HttpOnly: true,
				SameSite: http.SameSiteStrictMode,
			})
		default:
			if cookie, err := r.Cookie(consts.UIClientUniqueIDHeader); !errors.Is(err, http.ErrNoCookie) {
				id = cookie.Value
			}
		}

		if id == "" {
			next.ServeHTTP(w, r)
			return
		}
		next.ServeHTTP(w, r.WithContext(request.WithClientUniqueId(r.Context(), id)))
	})
}
// serverAddressMiddleware overwrites r.Host and r.URL.Scheme in place with the
// values returned by serverAddress, but only when that host is non-empty, and
// then calls h.
func serverAddressMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		scheme, host := serverAddress(r)
		if host != "" {
			r.Host = host
			r.URL.Scheme = scheme
		}
		h.ServeHTTP(w, r)
	})
}
// Canonical names of the X-Forwarded-* request headers that serverAddress
// consults to recover the host and scheme the client originally used.
var (
	xForwardedHost   = http.CanonicalHeaderKey("x-forwarded-host")
	xForwardedProto  = http.CanonicalHeaderKey("x-forwarded-proto")
	xForwardedScheme = http.CanonicalHeaderKey("x-forwarded-scheme")
)
func serverAddress(r *http.Request) (scheme, host string) {
origHost := r.Host
protocol := "http"
if r.TLS != nil {
protocol = "https"
}
xfh := r.Header.Get(xForwardedHost)
if xfh != "" {
i := strings.Index(xfh, ",")
if i == -1 {
i = len(xfh)
}
xfh = xfh[:i]
}
scheme = firstOr(
protocol,
r.Header.Get(xForwardedProto),
r.Header.Get(xForwardedScheme),
r.URL.Scheme,
)
host = firstOr(r.Host, xfh)
if host != origHost {
log.Trace(r.Context(), "Request host has changed", "origHost", origHost, "host", host, "scheme", scheme, "url", r.URL)
}
return scheme, host
}
// firstOr returns the first non-empty string in values. It returns fallback
// when every value is empty or no values are given.
//
// The variadic parameter is deliberately not named "strings": that name would
// shadow the imported strings package inside the function.
func firstOr(fallback string, values ...string) string {
	for _, v := range values {
		if v != "" {
			return v
		}
	}
	return fallback
}
// URLParamsMiddleware converts the Chi URL params found in the request context
// into query params, as expected by our REST package.
//
// Each param is appended as ":<key>=<value>", with both parts query-escaped.
// The "*" wildcard param is skipped. Any existing query string is kept and the
// new params are appended after it. Requests without a Chi route context, or
// with no params to add, are passed through unchanged, so no dangling "&" is
// ever appended.
func URLParamsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rctx := chi.RouteContext(r.Context())
		if rctx == nil {
			// chi.RouteContext returns nil outside a chi router; nothing to convert.
			next.ServeHTTP(w, r)
			return
		}
		parts := make([]string, 0, len(rctx.URLParams.Keys))
		for i, key := range rctx.URLParams.Keys {
			if key == "*" {
				continue
			}
			value := rctx.URLParams.Values[i]
			parts = append(parts, url.QueryEscape(":"+key)+"="+url.QueryEscape(value))
		}
		if len(parts) > 0 {
			q := strings.Join(parts, "&")
			if r.URL.RawQuery == "" {
				r.URL.RawQuery = q
			} else {
				r.URL.RawQuery += "&" + q
			}
		}
		next.ServeHTTP(w, r)
	})
}