diff --git a/persistence/genre_repository.go b/persistence/genre_repository.go index 15e2fba1..fe6c6f52 100644 --- a/persistence/genre_repository.go +++ b/persistence/genre_repository.go @@ -3,6 +3,8 @@ package persistence import ( "context" + "github.com/google/uuid" + . "github.com/Masterminds/squirrel" "github.com/beego/beego/v2/client/orm" "github.com/deluan/rest" @@ -35,10 +37,21 @@ func (r *genreRepository) GetAll(opt ...model.QueryOptions) (model.Genres, error return res, err } +// Put is an Upsert operation, based on the name of the genre: If the name already exists, returns its ID, or else +// insert the new genre in the DB and returns its new created ID. func (r *genreRepository) Put(m *model.Genre) error { - id, err := r.put(m.ID, m) - m.ID = id - return err + if m.ID == "" { + m.ID = uuid.NewString() + } + sql := Insert("genre").Columns("id", "name").Values(m.ID, m.Name). + Suffix("on conflict (name) do update set name=excluded.name returning id") + resp := model.Genre{} + err := r.queryOne(sql, &resp) + if err != nil { + return err + } + m.ID = resp.ID + return nil } func (r *genreRepository) Count(options ...rest.QueryOptions) (int64, error) { diff --git a/persistence/genre_repository_test.go b/persistence/genre_repository_test.go index d02e18c0..212e1a1b 100644 --- a/persistence/genre_repository_test.go +++ b/persistence/genre_repository_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/beego/beego/v2/client/orm" + "github.com/google/uuid" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/persistence" @@ -18,12 +19,39 @@ var _ = Describe("GenreRepository", func() { repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), orm.NewOrm()) }) - It("returns all records", func() { - genres, err := repo.GetAll() - Expect(err).To(BeNil()) - Expect(genres).To(ConsistOf( - model.Genre{ID: "gn-1", Name: "Electronic", AlbumCount: 1, SongCount: 2}, - model.Genre{ID: "gn-2", Name: "Rock", AlbumCount: 3, SongCount: 3}, - )) + Describe("GetAll()", func() { + It("returns all records", func() { + genres, err := repo.GetAll() + Expect(err).To(BeNil()) + Expect(genres).To(ConsistOf( + model.Genre{ID: "gn-1", Name: "Electronic", AlbumCount: 1, SongCount: 2}, + model.Genre{ID: "gn-2", Name: "Rock", AlbumCount: 3, SongCount: 3}, + )) + }) + }) + Describe("Put()", Ordered, func() { + It("does not insert existing genre names", func() { + g := model.Genre{Name: "Rock"} + err := repo.Put(&g) + Expect(err).To(BeNil()) + Expect(g.ID).To(Equal("gn-2")) + + genres, _ := repo.GetAll() + Expect(genres).To(HaveLen(2)) + }) + + It("insert non-existent genre names", func() { + g := model.Genre{Name: "Reggae"} + err := repo.Put(&g) + Expect(err).To(BeNil()) + + // ID is a uuid + _, err = uuid.Parse(g.ID) + Expect(err).To(BeNil()) + + genres, _ := repo.GetAll() + Expect(genres).To(HaveLen(3)) + Expect(genres).To(ContainElement(model.Genre{ID: g.ID, Name: "Reggae", AlbumCount: 0, SongCount: 0})) + }) }) })