navidrome/server/subsonic/middlewares_test.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

361 lines
10 KiB
Go
Raw Normal View History

package subsonic
2020-01-09 21:56:44 +01:00
import (
"context"
"errors"
2020-01-09 21:56:44 +01:00
"net/http"
"net/http/httptest"
2020-01-10 03:58:03 +01:00
"strings"
2021-05-11 23:21:18 +02:00
"time"
2020-01-09 21:56:44 +01:00
2021-05-11 23:21:18 +02:00
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/consts"
2020-10-27 16:01:40 +01:00
"github.com/navidrome/navidrome/core"
2020-08-14 16:10:17 +02:00
"github.com/navidrome/navidrome/core/auth"
2020-01-24 01:44:08 +01:00
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
2020-05-13 22:49:55 +02:00
"github.com/navidrome/navidrome/model/request"
2020-10-27 16:01:40 +01:00
"github.com/navidrome/navidrome/tests"
2022-07-26 22:47:16 +02:00
. "github.com/onsi/ginkgo/v2"
2020-01-09 21:56:44 +01:00
. "github.com/onsi/gomega"
)
// newGetRequest builds a GET request for "/ping" whose query string is the
// given "key=value" pairs joined with "&". The returned request carries the
// context produced by log.NewContext.
func newGetRequest(queryParams ...string) *http.Request {
	target := "/ping?" + strings.Join(queryParams, "&")
	req := httptest.NewRequest("GET", target, nil)
	return req.WithContext(log.NewContext(req.Context()))
}
func newPostRequest(queryParam string, formFields ...string) *http.Request {
r, err := http.NewRequest("POST", "/ping?"+queryParam, strings.NewReader(strings.Join(formFields, "&")))
if err != nil {
panic(err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
2020-01-09 21:56:44 +01:00
ctx := r.Context()
return r.WithContext(log.NewContext(ctx))
}
var _ = Describe("Middlewares", func() {
var next *mockHandler
var w *httptest.ResponseRecorder
2021-05-02 00:03:45 +02:00
var ds model.DataStore
2020-01-09 21:56:44 +01:00
BeforeEach(func() {
next = &mockHandler{}
w = httptest.NewRecorder()
2021-05-02 00:03:45 +02:00
ds = &tests.MockDataStore{}
2020-01-09 21:56:44 +01:00
})
Describe("ParsePostForm", func() {
It("converts any filed in a x-www-form-urlencoded POST into query params", func() {
r := newPostRequest("a=abc", "u=user", "v=1.15", "c=test")
cp := postFormToQueryParams(next)
cp.ServeHTTP(w, r)
Expect(next.req.URL.Query().Get("a")).To(Equal("abc"))
Expect(next.req.URL.Query().Get("u")).To(Equal("user"))
Expect(next.req.URL.Query().Get("v")).To(Equal("1.15"))
Expect(next.req.URL.Query().Get("c")).To(Equal("test"))
})
It("adds repeated params", func() {
r := newPostRequest("a=abc", "id=1", "id=2")
cp := postFormToQueryParams(next)
cp.ServeHTTP(w, r)
Expect(next.req.URL.Query().Get("a")).To(Equal("abc"))
Expect(next.req.URL.Query()["id"]).To(ConsistOf("1", "2"))
})
It("overrides query params with same key", func() {
r := newPostRequest("a=query", "a=body")
cp := postFormToQueryParams(next)
cp.ServeHTTP(w, r)
Expect(next.req.URL.Query().Get("a")).To(Equal("body"))
})
})
2020-01-09 21:56:44 +01:00
Describe("CheckParams", func() {
It("passes when all required params are available (subsonicauth case)", func() {
r := newGetRequest("u=user", "v=1.15", "c=test")
2020-01-09 21:56:44 +01:00
cp := checkRequiredParameters(next)
cp.ServeHTTP(w, r)
2020-05-13 22:49:55 +02:00
username, _ := request.UsernameFrom(next.req.Context())
Expect(username).To(Equal("user"))
version, _ := request.VersionFrom(next.req.Context())
Expect(version).To(Equal("1.15"))
client, _ := request.ClientFrom(next.req.Context())
Expect(client).To(Equal("test"))
2020-01-09 21:56:44 +01:00
Expect(next.called).To(BeTrue())
})
It("passes when all required params are available (reverse-proxy case)", func() {
conf.Server.ReverseProxyWhitelist = "127.0.0.234/32"
conf.Server.ReverseProxyUserHeader = "Remote-User"
r := newGetRequest("v=1.15", "c=test")
r.Header.Add("Remote-User", "user")
r = r.WithContext(request.WithReverseProxyIp(r.Context(), "127.0.0.234"))
cp := checkRequiredParameters(next)
cp.ServeHTTP(w, r)
username, _ := request.UsernameFrom(next.req.Context())
Expect(username).To(Equal("user"))
version, _ := request.VersionFrom(next.req.Context())
Expect(version).To(Equal("1.15"))
client, _ := request.ClientFrom(next.req.Context())
Expect(client).To(Equal("test"))
Expect(next.called).To(BeTrue())
})
2020-01-09 21:56:44 +01:00
It("fails when user is missing", func() {
r := newGetRequest("v=1.15", "c=test")
2020-01-09 21:56:44 +01:00
cp := checkRequiredParameters(next)
cp.ServeHTTP(w, r)
Expect(w.Body.String()).To(ContainSubstring(`code="10"`))
Expect(next.called).To(BeFalse())
})
It("fails when version is missing", func() {
r := newGetRequest("u=user", "c=test")
2020-01-09 21:56:44 +01:00
cp := checkRequiredParameters(next)
cp.ServeHTTP(w, r)
Expect(w.Body.String()).To(ContainSubstring(`code="10"`))
Expect(next.called).To(BeFalse())
})
It("fails when client is missing", func() {
r := newGetRequest("u=user", "v=1.15")
2020-01-09 21:56:44 +01:00
cp := checkRequiredParameters(next)
cp.ServeHTTP(w, r)
Expect(w.Body.String()).To(ContainSubstring(`code="10"`))
Expect(next.called).To(BeFalse())
})
})
Describe("Authenticate", func() {
BeforeEach(func() {
2021-05-02 00:03:45 +02:00
ur := ds.User(context.TODO())
_ = ur.Put(&model.User{
UserName: "admin",
NewPassword: "wordpass",
})
2020-01-09 21:56:44 +01:00
})
2020-08-14 16:10:17 +02:00
It("passes authentication with correct credentials", func() {
r := newGetRequest("u=admin", "p=wordpass")
cp := authenticate(ds)(next)
cp.ServeHTTP(w, r)
2020-01-09 21:56:44 +01:00
Expect(next.called).To(BeTrue())
2020-05-13 22:49:55 +02:00
user, _ := request.UserFrom(next.req.Context())
2020-08-14 16:10:17 +02:00
Expect(user.UserName).To(Equal("admin"))
2020-01-09 21:56:44 +01:00
})
It("fails authentication with wrong password", func() {
r := newGetRequest("u=invalid", "", "", "")
2020-08-14 16:10:17 +02:00
cp := authenticate(ds)(next)
cp.ServeHTTP(w, r)
2020-01-09 21:56:44 +01:00
Expect(w.Body.String()).To(ContainSubstring(`code="40"`))
Expect(next.called).To(BeFalse())
2020-01-09 21:56:44 +01:00
})
})
Describe("GetPlayer", func() {
var mockedPlayers *mockPlayers
var r *http.Request
BeforeEach(func() {
mockedPlayers = &mockPlayers{}
r = newGetRequest()
2020-05-13 22:49:55 +02:00
ctx := request.WithUsername(r.Context(), "someone")
ctx = request.WithClient(ctx, "client")
r = r.WithContext(ctx)
})
It("returns a new player in the cookies when none is specified", func() {
gp := getPlayer(mockedPlayers)(next)
gp.ServeHTTP(w, r)
cookieStr := w.Header().Get("Set-Cookie")
Expect(cookieStr).To(ContainSubstring(playerIDCookieName("someone")))
})
It("does not add the cookie if there was an error", func() {
2020-05-13 22:49:55 +02:00
ctx := request.WithClient(r.Context(), "error")
r = r.WithContext(ctx)
gp := getPlayer(mockedPlayers)(next)
gp.ServeHTTP(w, r)
cookieStr := w.Header().Get("Set-Cookie")
Expect(cookieStr).To(BeEmpty())
})
Context("PlayerId specified in Cookies", func() {
BeforeEach(func() {
cookie := &http.Cookie{
Name: playerIDCookieName("someone"),
Value: "123",
MaxAge: consts.CookieExpiry,
}
r.AddCookie(cookie)
gp := getPlayer(mockedPlayers)(next)
gp.ServeHTTP(w, r)
})
It("stores the player in the context", func() {
Expect(next.called).To(BeTrue())
2020-05-13 22:49:55 +02:00
player, _ := request.PlayerFrom(next.req.Context())
Expect(player.ID).To(Equal("123"))
2020-05-13 22:49:55 +02:00
_, ok := request.TranscodingFrom(next.req.Context())
Expect(ok).To(BeFalse())
})
It("returns the playerId in the cookie", func() {
cookieStr := w.Header().Get("Set-Cookie")
Expect(cookieStr).To(ContainSubstring(playerIDCookieName("someone") + "=123"))
})
})
Context("Player has transcoding configured", func() {
BeforeEach(func() {
cookie := &http.Cookie{
Name: playerIDCookieName("someone"),
Value: "123",
MaxAge: consts.CookieExpiry,
}
r.AddCookie(cookie)
mockedPlayers.transcoding = &model.Transcoding{ID: "12"}
gp := getPlayer(mockedPlayers)(next)
gp.ServeHTTP(w, r)
})
It("stores the player in the context", func() {
2020-05-13 22:49:55 +02:00
player, _ := request.PlayerFrom(next.req.Context())
Expect(player.ID).To(Equal("123"))
2020-05-13 22:49:55 +02:00
transcoding, _ := request.TranscodingFrom(next.req.Context())
Expect(transcoding.ID).To(Equal("12"))
})
})
})
2020-08-14 16:10:17 +02:00
Describe("validateCredentials", func() {
var usr *model.User
2020-08-14 16:10:17 +02:00
BeforeEach(func() {
2021-05-02 00:03:45 +02:00
ur := ds.User(context.TODO())
_ = ur.Put(&model.User{
UserName: "admin",
NewPassword: "wordpass",
})
var err error
usr, err = ur.FindByUsernameWithPassword("admin")
if err != nil {
panic(err)
}
2020-08-14 16:10:17 +02:00
})
2020-08-14 16:10:17 +02:00
Context("Plaintext password", func() {
It("authenticates with plaintext password ", func() {
err := validateCredentials(usr, "wordpass", "", "", "")
2020-08-14 16:10:17 +02:00
Expect(err).NotTo(HaveOccurred())
})
It("fails authentication with wrong password", func() {
err := validateCredentials(usr, "INVALID", "", "", "")
2020-08-14 16:10:17 +02:00
Expect(err).To(MatchError(model.ErrInvalidAuth))
})
})
Context("Encoded password", func() {
It("authenticates with simple encoded password ", func() {
err := validateCredentials(usr, "enc:776f726470617373", "", "", "")
2020-08-14 16:10:17 +02:00
Expect(err).NotTo(HaveOccurred())
})
})
Context("Token based authentication", func() {
It("authenticates with token based authentication", func() {
err := validateCredentials(usr, "", "23b342970e25c7928831c3317edd0b67", "retnlmjetrymazgkt", "")
2020-08-14 16:10:17 +02:00
Expect(err).NotTo(HaveOccurred())
})
It("fails if salt is missing", func() {
err := validateCredentials(usr, "", "23b342970e25c7928831c3317edd0b67", "", "")
2020-08-14 16:10:17 +02:00
Expect(err).To(MatchError(model.ErrInvalidAuth))
})
})
Context("JWT based authentication", func() {
var usr *model.User
2020-08-14 16:10:17 +02:00
var validToken string
2020-08-14 16:10:17 +02:00
BeforeEach(func() {
2021-05-11 23:21:18 +02:00
conf.Server.SessionTimeout = time.Minute
2021-05-12 00:55:58 +02:00
auth.Init(ds)
2021-05-11 23:21:18 +02:00
usr = &model.User{UserName: "admin"}
2020-08-14 16:10:17 +02:00
var err error
validToken, err = auth.CreateToken(usr)
2020-08-14 16:10:17 +02:00
if err != nil {
panic(err)
}
})
2020-08-14 16:10:17 +02:00
It("authenticates with JWT token based authentication", func() {
err := validateCredentials(usr, "", "", "", validToken)
2020-08-14 16:10:17 +02:00
Expect(err).NotTo(HaveOccurred())
})
It("fails if JWT token is invalid", func() {
err := validateCredentials(usr, "", "", "", "invalid.token")
2020-08-14 16:10:17 +02:00
Expect(err).To(MatchError(model.ErrInvalidAuth))
})
It("fails if JWT token sub is different than username", func() {
u := &model.User{UserName: "hacker"}
validToken, _ = auth.CreateToken(u)
err := validateCredentials(usr, "", "", "", validToken)
2020-08-14 16:10:17 +02:00
Expect(err).To(MatchError(model.ErrInvalidAuth))
})
})
})
2020-01-09 21:56:44 +01:00
})
// mockHandler is an http.Handler stand-in that records whether it was invoked
// and with which request, so specs can inspect what a middleware passed on.
type mockHandler struct {
	req    *http.Request // request received by the most recent ServeHTTP call
	called bool          // true once ServeHTTP has been invoked
}

// ServeHTTP stores r and marks the handler as called. It writes nothing to w.
func (mh *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	mh.req, mh.called = r, true
}
type mockPlayers struct {
2020-10-27 16:01:40 +01:00
core.Players
transcoding *model.Transcoding
}
// Get returns a player whose ID is the requested playerId; it never fails.
func (mp *mockPlayers) Get(ctx context.Context, playerId string) (*model.Player, error) {
	player := &model.Player{ID: playerId}
	return player, nil
}
// Register returns a player with the given id together with the configured
// transcoding. When client is "error" it returns an error whose message is the
// client name, and nil for both results.
func (mp *mockPlayers) Register(ctx context.Context, id, client, typ, ip string) (*model.Player, *model.Transcoding, error) {
	if client != "error" {
		return &model.Player{ID: id}, mp.transcoding, nil
	}
	return nil, nil, errors.New(client)
}