diff --git a/engine/playlists.go b/engine/playlists.go index a0e075ae..efba8f25 100644 --- a/engine/playlists.go +++ b/engine/playlists.go @@ -111,11 +111,7 @@ func (p *playlists) Update(ctx context.Context, playlistId string, name *string, } func (p *playlists) GetAll(ctx context.Context) (model.Playlists, error) { - all, err := p.ds.Playlist(ctx).GetAll(model.QueryOptions{}) - for i := range all { - all[i].Public = true - } - return all, err + return p.ds.Playlist(ctx).GetAll() } type PlaylistInfo struct { diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go index dbaabc4a..0f0381a2 100644 --- a/persistence/persistence_suite_test.go +++ b/persistence/persistence_suite_test.go @@ -65,7 +65,6 @@ var ( var ( plsBest = model.Playlist{ - ID: "10", Name: "Best", Comment: "No Comments", Owner: "userid", @@ -73,8 +72,8 @@ var ( SongCount: 2, Tracks: model.MediaFiles{{ID: "1001"}, {ID: "1003"}}, } - plsCool = model.Playlist{ID: "11", Name: "Cool", Tracks: model.MediaFiles{{ID: "1004"}}} - testPlaylists = model.Playlists{plsBest, plsCool} + plsCool = model.Playlist{Name: "Cool", Owner: "userid", Tracks: model.MediaFiles{{ID: "1004"}}} + testPlaylists = []*model.Playlist{&plsBest, &plsCool} ) func P(path string) string { @@ -117,8 +116,7 @@ var _ = Describe("Initialize test DB", func() { pr := NewPlaylistRepository(ctx, o) for i := range testPlaylists { - pls := testPlaylists[i] - err := pr.Put(&pls) + err := pr.Put(testPlaylists[i]) if err != nil { panic(err) } diff --git a/persistence/playlist_repository.go b/persistence/playlist_repository.go index e5136f92..acc0dfcf 100644 --- a/persistence/playlist_repository.go +++ b/persistence/playlist_repository.go @@ -24,26 +24,47 @@ func NewPlaylistRepository(ctx context.Context, o orm.Ormer) model.PlaylistRepos return r } +func (r *playlistRepository) userFilter() Sqlizer { + user := loggedUser(r.ctx) + if user.IsAdmin { + return And{} + } + return Or{ + Eq{"public": true}, + Eq{"owner": user.UserName}, + } +} + func (r *playlistRepository) CountAll(options ...model.QueryOptions) (int64, error) { - return r.count(Select(), options...) + sql := Select().Where(r.userFilter()) + return r.count(sql, options...) } func (r *playlistRepository) Exists(id string) (bool, error) { - return r.exists(Select().Where(Eq{"id": id})) + return r.exists(Select().Where(And{Eq{"id": id}, r.userFilter()})) } func (r *playlistRepository) Delete(id string) error { - del := Delete("playlist_tracks").Where(Eq{"playlist_id": id}) - _, err := r.executeSQL(del) + err := r.delete(And{Eq{"id": id}, r.userFilter()}) if err != nil { return err } - return r.delete(Eq{"id": id}) + del := Delete("playlist_tracks").Where(Eq{"playlist_id": id}) + _, err = r.executeSQL(del) + return err } func (r *playlistRepository) Put(p *model.Playlist) error { if p.ID == "" { p.CreatedAt = time.Now() + } else { + ok, err := r.Exists(p.ID) + if err != nil { + return err + } + if !ok { + return model.ErrNotAuthorized + } } p.UpdatedAt = time.Now() @@ -55,12 +76,16 @@ func (r *playlistRepository) Put(p *model.Playlist) error { if err != nil { return err } + p.ID = id err = r.updateTracks(id, tracks) - return err + if err != nil { + return err + } + return r.loadTracks(p) } func (r *playlistRepository) Get(id string) (*model.Playlist, error) { - sel := r.newSelect().Columns("*").Where(Eq{"id": id}) + sel := r.newSelect().Columns("*").Where(And{Eq{"id": id}, r.userFilter()}) var pls model.Playlist err := r.queryOne(sel, &pls) if err != nil { @@ -71,7 +96,7 @@ func (r *playlistRepository) Get(id string) (*model.Playlist, error) { } func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) { - sel := r.newSelect(options...).Columns("*") + sel := r.newSelect(options...).Columns("*").Where(r.userFilter()) res := model.Playlists{} err := r.queryAll(sel, &res) return res, err @@ -85,7 +110,7 @@ func (r *playlistRepository) updateTracks(id string, tracks model.MediaFiles) er return r.Tracks(id).Update(ids) } -func (r *playlistRepository) loadTracks(pls *model.Playlist) (err error) { +func (r *playlistRepository) loadTracks(pls *model.Playlist) error { tracksQuery := Select().From("playlist_tracks"). LeftJoin("annotation on ("+ "annotation.item_id = media_file_id"+ @@ -94,11 +119,11 @@ func (r *playlistRepository) loadTracks(pls *model.Playlist) (err error) { Columns("starred", "starred_at", "play_count", "play_date", "rating", "f.*"). Join("media_file f on f.id = media_file_id"). Where(Eq{"playlist_id": pls.ID}).OrderBy("playlist_tracks.id") - err = r.queryAll(tracksQuery, &pls.Tracks) + err := r.queryAll(tracksQuery, &pls.Tracks) if err != nil { log.Error("Error loading playlist tracks", "playlist", pls.Name, "id", pls.ID) } - return + return err } func (r *playlistRepository) Count(options ...rest.QueryOptions) (int64, error) { diff --git a/persistence/playlist_repository_test.go b/persistence/playlist_repository_test.go index da8f8b41..9983358d 100644 --- a/persistence/playlist_repository_test.go +++ b/persistence/playlist_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/model/request" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,7 +15,9 @@ var _ = Describe("PlaylistRepository", func() { var repo model.PlaylistRepository BeforeEach(func() { - repo = NewPlaylistRepository(log.NewContext(context.TODO()), orm.NewOrm()) + ctx := log.NewContext(context.TODO()) + ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) + repo = NewPlaylistRepository(ctx, orm.NewOrm()) }) Describe("Count", func() { @@ -25,7 +28,7 @@ var _ = Describe("PlaylistRepository", func() { Describe("Exists", func() { It("returns true for an existing playlist", func() { - Expect(repo.Exists("11")).To(BeTrue()) + Expect(repo.Exists(plsCool.ID)).To(BeTrue()) }) It("returns false for a non-existing playlist", func() { Expect(repo.Exists("666")).To(BeFalse()) @@ -34,7 +37,7 @@ var _ = Describe("PlaylistRepository", func() { Describe("Get", func() { It("returns an existing playlist", func() { - p, err := repo.Get("10") + p, err := repo.Get(plsBest.ID) Expect(err).To(BeNil()) // Compare all but Tracks and timestamps p2 := *p @@ -52,7 +55,7 @@ var _ = Describe("PlaylistRepository", func() { Expect(err).To(MatchError(model.ErrNotFound)) }) It("returns all tracks", func() { - pls, err := repo.Get("10") + pls, err := repo.Get(plsBest.ID) Expect(err).To(BeNil()) Expect(pls.Name).To(Equal(plsBest.Name)) Expect(pls.Tracks).To(Equal(model.MediaFiles{ @@ -62,32 +65,31 @@ var _ = Describe("PlaylistRepository", func() { }) }) - Describe("Put/Exists/Delete", func() { - var newPls model.Playlist - BeforeEach(func() { - newPls = model.Playlist{ID: "22", Name: "Great!", Tracks: model.MediaFiles{{ID: "1004"}, {ID: "1003"}}} - }) - It("saves the playlist to the DB", func() { - Expect(repo.Put(&newPls)).To(BeNil()) - }) - It("adds repeated songs to a playlist and keeps the order", func() { - newPls.Tracks = append(newPls.Tracks, model.MediaFile{ID: "1004"}) - Expect(repo.Put(&newPls)).To(BeNil()) - saved, _ := repo.Get("22") - Expect(saved.Tracks).To(HaveLen(3)) - Expect(saved.Tracks[0].ID).To(Equal("1004")) - Expect(saved.Tracks[1].ID).To(Equal("1003")) - Expect(saved.Tracks[2].ID).To(Equal("1004")) - }) - It("returns the newly created playlist", func() { - Expect(repo.Exists("22")).To(BeTrue()) - }) - It("returns deletes the playlist", func() { - Expect(repo.Delete("22")).To(BeNil()) - }) - It("returns error if tries to retrieve the deleted playlist", func() { - Expect(repo.Exists("22")).To(BeFalse()) - }) + It("Put/Exists/Delete", func() { + By("saves the playlist to the DB") + newPls := model.Playlist{Name: "Great!", Owner: "userid", + Tracks: model.MediaFiles{{ID: "1004"}, {ID: "1003"}}} + + By("saves the playlist to the DB") + Expect(repo.Put(&newPls)).To(BeNil()) + + By("adds repeated songs to a playlist and keeps the order") + newPls.Tracks = append(newPls.Tracks, model.MediaFile{ID: "1004"}) + Expect(repo.Put(&newPls)).To(BeNil()) + saved, _ := repo.Get(newPls.ID) + Expect(saved.Tracks).To(HaveLen(3)) + Expect(saved.Tracks[0].ID).To(Equal("1004")) + Expect(saved.Tracks[1].ID).To(Equal("1003")) + Expect(saved.Tracks[2].ID).To(Equal("1004")) + + By("returns the newly created playlist") + Expect(repo.Exists(newPls.ID)).To(BeTrue()) + + By("returns deletes the playlist") + Expect(repo.Delete(newPls.ID)).To(BeNil()) + + By("returns error if tries to retrieve the deleted playlist") + Expect(repo.Exists(newPls.ID)).To(BeFalse()) }) Describe("GetAll", func() {