diff --git a/server/subsonic/api.go b/server/subsonic/api.go index ae28e4ec..467d1d7e 100644 --- a/server/subsonic/api.go +++ b/server/subsonic/api.go @@ -19,6 +19,7 @@ import ( "github.com/navidrome/navidrome/server/events" "github.com/navidrome/navidrome/server/subsonic/responses" "github.com/navidrome/navidrome/utils" + "github.com/navidrome/navidrome/utils/req" ) const Version = "1.16.1" @@ -211,15 +212,6 @@ func hr(r chi.Router, path string, f handlerRaw) { handle := func(w http.ResponseWriter, r *http.Request) { res, err := f(w, r) if err != nil { - // If it is not a Subsonic error, convert it to an ErrorGeneric - var subErr subError - if !errors.As(err, &subErr) { - if errors.Is(err, model.ErrNotFound) { - err = newError(responses.ErrorDataNotFound, "data not found") - } else { - err = newError(responses.ErrorGeneric, fmt.Sprintf("Internal Server Error: %s", err)) - } - } sendError(w, r, err) return } @@ -264,15 +256,28 @@ func addHandler(r chi.Router, path string, handle func(w http.ResponseWriter, r r.HandleFunc("/"+path+".view", handle) } -func sendError(w http.ResponseWriter, r *http.Request, err error) { - response := newResponse() - code := responses.ErrorGeneric - var subErr subError - if errors.As(err, &subErr) { - code = subErr.code +func mapToSubsonicError(err error) subError { + switch { + case errors.Is(err, errSubsonic): // do nothing + case errors.Is(err, req.ErrMissingParam): + err = newError(responses.ErrorMissingParameter, err.Error()) + case errors.Is(err, req.ErrInvalidParam): + err = newError(responses.ErrorGeneric, err.Error()) + case errors.Is(err, model.ErrNotFound): + err = newError(responses.ErrorDataNotFound, "data not found") + default: + err = newError(responses.ErrorGeneric, fmt.Sprintf("Internal Server Error: %s", err)) } + var subErr subError + errors.As(err, &subErr) + return subErr +} + +func sendError(w http.ResponseWriter, r *http.Request, err error) { + subErr := mapToSubsonicError(err) + response := newResponse() response.Status = "failed" - response.Error = &responses.Error{Code: int32(code), Message: err.Error()} + response.Error = &responses.Error{Code: int32(subErr.code), Message: subErr.Error()} sendResponse(w, r, response) } diff --git a/server/subsonic/helpers.go b/server/subsonic/helpers.go index 87ace90b..d006bd7a 100644 --- a/server/subsonic/helpers.go +++ b/server/subsonic/helpers.go @@ -2,6 +2,7 @@ package subsonic import ( "context" + "errors" "fmt" "mime" "net/http" @@ -62,6 +63,13 @@ func newError(code int, message ...interface{}) error { } } +// errSubsonic and Unwrap are used to allow `errors.Is(err, errSubsonic)` to work +var errSubsonic = errors.New("subsonic API error") + +func (e subError) Unwrap() error { + return fmt.Errorf("%w: %d", errSubsonic, e.code) +} + func (e subError) Error() string { var msg string if len(e.messages) == 0 { diff --git a/server/subsonic/media_annotation.go b/server/subsonic/media_annotation.go index f6576b15..36c67373 100644 --- a/server/subsonic/media_annotation.go +++ b/server/subsonic/media_annotation.go @@ -2,7 +2,6 @@ package subsonic import ( "context" - "errors" "fmt" "net/http" "time" @@ -13,26 +12,22 @@ import ( "github.com/navidrome/navidrome/model/request" "github.com/navidrome/navidrome/server/events" "github.com/navidrome/navidrome/server/subsonic/responses" - "github.com/navidrome/navidrome/utils" + "github.com/navidrome/navidrome/utils/req" ) func (api *Router) SetRating(r *http.Request) (*responses.Subsonic, error) { - id, err := requiredParamString(r, "id") + p := req.Params(r) + id, err := p.String("id") if err != nil { return nil, err } - rating, err := requiredParamInt(r, "rating") + rating, err := p.Int("rating") if err != nil { return nil, err } log.Debug(r, "Setting rating", "rating", rating, "id", id) err = api.setRating(r.Context(), id, rating) - - if errors.Is(err, model.ErrNotFound) { - log.Error(r, err) - return nil, newError(responses.ErrorDataNotFound, "ID not found") - } if err != nil { log.Error(r, err) return nil, err @@ -70,9 +65,10 @@ func (api *Router) setRating(ctx context.Context, id string, rating int) error { } func (api *Router) Star(r *http.Request) (*responses.Subsonic, error) { - ids := utils.ParamStrings(r, "id") - albumIds := utils.ParamStrings(r, "albumId") - artistIds := utils.ParamStrings(r, "artistId") + p := req.Params(r) + ids, _ := p.Strings("id") + albumIds, _ := p.Strings("albumId") + artistIds, _ := p.Strings("artistId") if len(ids)+len(albumIds)+len(artistIds) == 0 { return nil, newError(responses.ErrorMissingParameter, "Required id parameter is missing") } @@ -88,9 +84,10 @@ func (api *Router) Star(r *http.Request) (*responses.Subsonic, error) { } func (api *Router) Unstar(r *http.Request) (*responses.Subsonic, error) { - ids := utils.ParamStrings(r, "id") - albumIds := utils.ParamStrings(r, "albumId") - artistIds := utils.ParamStrings(r, "artistId") + p := req.Params(r) + ids, _ := p.Strings("id") + albumIds, _ := p.Strings("albumId") + artistIds, _ := p.Strings("artistId") if len(ids)+len(albumIds)+len(artistIds) == 0 { return nil, newError(responses.ErrorMissingParameter, "Required id parameter is missing") } @@ -150,11 +147,6 @@ func (api *Router) setStar(ctx context.Context, star bool, ids ...string) error api.broker.SendMessage(ctx, event) return nil }) - - if errors.Is(err, model.ErrNotFound) { - log.Error(ctx, err) - return newError(responses.ErrorDataNotFound, "ID not found") - } if err != nil { log.Error(ctx, err) return err @@ -163,15 +155,16 @@ func (api *Router) setStar(ctx context.Context, star bool, ids ...string) error } func (api *Router) Scrobble(r *http.Request) (*responses.Subsonic, error) { - ids, err := requiredParamStrings(r, "id") + p := req.Params(r) + ids, err := p.Strings("id") if err != nil { return nil, err } - times := utils.ParamTimes(r, "time") + times := p.Times("time") if len(times) > 0 && len(times) != len(ids) { return nil, newError(responses.ErrorGeneric, "Wrong number of timestamps: %d, should be %d", len(times), len(ids)) } - submission := utils.ParamBool(r, "submission", true) + submission := p.BoolOr("submission", true) ctx := r.Context() if submission { diff --git a/utils/req/req.go b/utils/req/req.go new file mode 100644 index 00000000..a314556c --- /dev/null +++ b/utils/req/req.go @@ -0,0 +1,140 @@ +package req + +import ( + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/navidrome/navidrome/log" + "github.com/navidrome/navidrome/utils" +) + +type Values struct { + *http.Request +} + +func Params(r *http.Request) *Values { + return &Values{r} +} + +var ( + ErrMissingParam = errors.New("missing parameter") + ErrInvalidParam = errors.New("invalid parameter") +) + +func newError(err error, param string) error { + return fmt.Errorf("%w: '%s'", err, param) +} +func (r *Values) String(param string) (string, error) { + v := r.URL.Query().Get(param) + if v == "" { + return "", newError(ErrMissingParam, param) + } + return v, nil +} + +func (r *Values) StringOr(param, def string) string { + v, _ := r.String(param) + if v == "" { + return def + } + return v +} + +func (r *Values) Strings(param string) ([]string, error) { + values := r.URL.Query()[param] + if len(values) == 0 { + return nil, newError(ErrMissingParam, param) + } + return values, nil +} + +func (r *Values) TimeOr(param string, def time.Time) time.Time { + v, _ := r.String(param) + if v == "" || v == "-1" { + return def + } + value, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return def + } + t := utils.ToTime(value) + if t.Before(time.Date(1970, time.January, 2, 0, 0, 0, 0, time.UTC)) { + return def + } + return t +} + +func (r *Values) Times(param string) []time.Time { + pStr, _ := r.Strings(param) + times := make([]time.Time, len(pStr)) + for i, t := range pStr { + ti, err := strconv.ParseInt(t, 10, 64) + if err != nil { + log.Warn(r.Context(), "Ignoring invalid time param", "time", t, err) + times[i] = time.Now() + continue + } + times[i] = utils.ToTime(ti) + } + return times +} + +func (r *Values) Int64(param string) (int64, error) { + v, err := r.String(param) + if err != nil { + return 0, err + } + value, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, fmt.Errorf("%w '%s': expected integer, got '%s'", ErrInvalidParam, param, v) + } + return value, nil +} + +func (r *Values) Int(param string) (int, error) { + v, err := r.Int64(param) + if err != nil { + return 0, err + } + return int(v), nil +} + +func (r *Values) IntOr(param string, def int) int { + v, err := r.Int64(param) + if err != nil { + return def + } + return int(v) +} + +func (r *Values) Int64Or(param string, def int64) int64 { + v, err := r.Int64(param) + if err != nil { + return def + } + return v +} + +func (r *Values) Ints(param string) []int { + pStr, _ := r.Strings(param) + ints := make([]int, 0, len(pStr)) + for _, s := range pStr { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + ints = append(ints, int(i)) + } + } + return ints +} + +func (r *Values) BoolOr(param string, def bool) bool { + v, _ := r.String(param) + if v == "" { + return def + } + return strings.Contains("/true/on/1/", "/"+strings.ToLower(v)+"/") +} diff --git a/utils/req/req_test.go b/utils/req/req_test.go new file mode 100644 index 00000000..0433ae6d --- /dev/null +++ b/utils/req/req_test.go @@ -0,0 +1,208 @@ +package req_test + +import ( + "fmt" + "net/http/httptest" + "testing" + "time" + + "github.com/navidrome/navidrome/utils" + "github.com/navidrome/navidrome/utils/req" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Request Helpers Suite") +} + +var _ = Describe("Request Helpers", func() { + var r *req.Values + + Describe("ParamString", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?a=123", nil)) + }) + + It("returns param as string", func() { + Expect(r.String("a")).To(Equal("123")) + }) + + It("returns empty string if param does not exist", func() { + v, err := r.String("NON_EXISTENT_PARAM") + Expect(err).To(MatchError(req.ErrMissingParam)) + Expect(err.Error()).To(ContainSubstring("NON_EXISTENT_PARAM")) + Expect(v).To(BeEmpty()) + }) + }) + + Describe("ParamStringDefault", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?a=123", nil)) + }) + + It("returns param as string", func() { + Expect(r.StringOr("a", "default_value")).To(Equal("123")) + }) + + It("returns default string if param does not exist", func() { + Expect(r.StringOr("xx", "default_value")).To(Equal("default_value")) + }) + }) + + Describe("ParamStrings", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?a=123&a=456", nil)) + }) + + It("returns all param occurrences as []string", func() { + Expect(r.Strings("a")).To(Equal([]string{"123", "456"})) + }) + + It("returns empty array if param does not exist", func() { + v, err := r.Strings("xx") + Expect(err).To(MatchError(req.ErrMissingParam)) + Expect(v).To(BeEmpty()) + }) + }) + + Describe("ParamTime", func() { + d := time.Date(2002, 8, 9, 12, 11, 13, 1000000, time.Local) + t := utils.ToMillis(d) + now := time.Now() + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", fmt.Sprintf("/ping?t=%d&inv=abc", t), nil)) + }) + + It("returns parsed time", func() { + Expect(r.TimeOr("t", now)).To(Equal(d)) + }) + + It("returns default time if param does not exist", func() { + Expect(r.TimeOr("xx", now)).To(Equal(now)) + }) + + It("returns default time if param is an invalid timestamp", func() { + Expect(r.TimeOr("inv", now)).To(Equal(now)) + }) + }) + + Describe("ParamTimes", func() { + d1 := time.Date(2002, 8, 9, 12, 11, 13, 1000000, time.Local) + d2 := time.Date(2002, 8, 9, 12, 13, 56, 0000000, time.Local) + t1 := utils.ToMillis(d1) + t2 := utils.ToMillis(d2) + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", fmt.Sprintf("/ping?t=%d&t=%d", t1, t2), nil)) + }) + + It("returns all param occurrences as []time.Time", func() { + Expect(r.Times("t")).To(Equal([]time.Time{d1, d2})) + }) + + It("returns empty string if param does not exist", func() { + Expect(r.Times("xx")).To(BeEmpty()) + }) + + It("returns current time as default if param is invalid", func() { + now := time.Now() + r = req.Params(httptest.NewRequest("GET", "/ping?t=null", nil)) + times := r.Times("t") + Expect(times).To(HaveLen(1)) + Expect(times[0]).To(BeTemporally(">=", now)) + }) + }) + + Describe("ParamInt", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?i=123&inv=123.45", nil)) + }) + Context("int", func() { + It("returns parsed int", func() { + Expect(r.IntOr("i", 999)).To(Equal(123)) + }) + + It("returns default value if param does not exist", func() { + Expect(r.IntOr("xx", 999)).To(Equal(999)) + }) + + It("returns default value if param is an invalid int", func() { + Expect(r.IntOr("inv", 999)).To(Equal(999)) + }) + }) + Context("int64", func() { + It("returns parsed int64", func() { + Expect(r.IntOr("i", 999)).To(Equal(123)) + }) + + It("returns default value if param does not exist", func() { + Expect(r.IntOr("xx", 999)).To(Equal(999)) + }) + + It("returns default value if param is an invalid int", func() { + Expect(r.IntOr("inv", 999)).To(Equal(999)) + }) + }) + }) + + Describe("ParamInts", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?i=123&i=456", nil)) + }) + + It("returns array of occurrences found", func() { + Expect(r.Ints("i")).To(Equal([]int{123, 456})) + }) + + It("returns empty array if param does not exist", func() { + Expect(r.Ints("xx")).To(BeEmpty()) + }) + }) + + Describe("ParamBool", func() { + Context("value is true", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?b=true&c=on&d=1&e=True", nil)) + }) + + It("parses 'true'", func() { + Expect(r.BoolOr("b", false)).To(BeTrue()) + }) + + It("parses 'on'", func() { + Expect(r.BoolOr("c", false)).To(BeTrue()) + }) + + It("parses '1'", func() { + Expect(r.BoolOr("d", false)).To(BeTrue()) + }) + + It("parses 'True'", func() { + Expect(r.BoolOr("e", false)).To(BeTrue()) + }) + }) + + Context("value is false", func() { + BeforeEach(func() { + r = req.Params(httptest.NewRequest("GET", "/ping?b=false&c=off&d=0", nil)) + }) + + It("parses 'false'", func() { + Expect(r.BoolOr("b", true)).To(BeFalse()) + }) + + It("parses 'off'", func() { + Expect(r.BoolOr("c", true)).To(BeFalse()) + }) + + It("parses '0'", func() { + Expect(r.BoolOr("d", true)).To(BeFalse()) + }) + + It("returns default value if param does not exist", func() { + Expect(r.BoolOr("xx", true)).To(BeTrue()) + }) + }) + }) +})