diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..b3e79320 --- /dev/null +++ b/db/db.go @@ -0,0 +1,34 @@ +package db + +import ( + "database/sql" + "os" + + "github.com/deluan/navidrome/conf" + _ "github.com/deluan/navidrome/db/migrations" + "github.com/deluan/navidrome/log" + _ "github.com/mattn/go-sqlite3" + "github.com/pressly/goose" +) + +const driver = "sqlite3" + +func EnsureDB() { + db, err := sql.Open(driver, conf.Server.DbPath) + defer db.Close() + if err != nil { + log.Error("Failed to open DB", err) + os.Exit(1) + } + + err = goose.SetDialect(driver) + if err != nil { + log.Error("Invalid DB driver", "driver", driver, err) + os.Exit(1) + } + err = goose.Run("up", db, "./") + if err != nil { + log.Error("Failed to apply new migrations", err) + os.Exit(1) + } +} diff --git a/db/migrations/20200130083147_create_schema.go b/db/migrations/20200130083147_create_schema.go new file mode 100644 index 00000000..766bb31c --- /dev/null +++ b/db/migrations/20200130083147_create_schema.go @@ -0,0 +1,22 @@ +package migrations + +import ( + "database/sql" + + "github.com/deluan/navidrome/log" + "github.com/pressly/goose" +) + +func init() { + goose.AddMigration(Up20200130083147, Down20200130083147) +} + +func Up20200130083147(tx *sql.Tx) error { + log.Info("Creating DB Schema") + _, err := tx.Exec(schema) + return err +} + +func Down20200130083147(tx *sql.Tx) error { + return nil +} diff --git a/db/migrations/schema.go b/db/migrations/schema.go new file mode 100644 index 00000000..7f2f5155 --- /dev/null +++ b/db/migrations/schema.go @@ -0,0 +1,136 @@ +package migrations + +var schema = ` +create table if not exists media_file +( + id varchar(255) not null + primary key, + title varchar(255) not null, + album varchar(255) default '' not null, + artist varchar(255) default '' not null, + artist_id varchar(255) default '' not null, + album_artist varchar(255) default '' not null, + album_id varchar(255) default '' not null, + has_cover_art bool default FALSE not null, + track_number integer default 0 not null, + disc_number integer default 0 not null, + year integer default 0 not null, + size integer default 0 not null, + path varchar(1024) not null, + suffix varchar(255) default '' not null, + duration integer default 0 not null, + bit_rate integer default 0 not null, + genre varchar(255) default '' not null, + compilation bool default FALSE not null, + created_at datetime, + updated_at datetime +); + +create index if not exists media_file_title + on media_file (title); + +create index if not exists media_file_album_id + on media_file (album_id); + +create index if not exists media_file_album + on media_file (album); + +create index if not exists media_file_artist_id + on media_file (artist_id); + +create index if not exists media_file_artist + on media_file (artist); + +create index if not exists media_file_album_artist + on media_file (album_artist); + +create index if not exists media_file_genre + on media_file (genre); + +create index if not exists media_file_year + on media_file (year); + +create index if not exists media_file_compilation + on media_file (compilation); + +create index if not exists media_file_path + on media_file (path); + +create table if not exists annotation +( + ann_id varchar(255) not null + primary key, + user_id varchar(255) default '' not null, + item_id varchar(255) default '' not null, + item_type varchar(255) default '' not null, + play_count integer, + play_date datetime, + rating integer, + starred bool default FALSE not null, + starred_at datetime, + unique (user_id, item_id, item_type) +); + +create index if not exists annotation_play_count + on annotation (play_count); + +create index if not exists annotation_play_date + on annotation (play_date); + +create index if not exists annotation_starred + on annotation (starred); + +create table if not exists playlist +( + id varchar(255) not null + primary key, + name varchar(255) not null, + comment varchar(255) default '' not null, + duration integer default 0 not null, + owner varchar(255) default '' not null, + public bool default FALSE not null, + tracks text not null, + unique (owner, name) +); + +create index if not exists playlist_name + on playlist (name); + +create table if not exists property +( + id varchar(255) not null + primary key, + value varchar(1024) default '' not null +); + +create table if not exists search +( + id varchar(255) not null + primary key, + "table" varchar(255) not null, + full_text varchar(1024) not null +); + +create index if not exists search_full_text + on search (full_text); + +create index if not exists search_table + on search ("table"); + +create table if not exists user +( + id varchar(255) not null + primary key, + user_name varchar(255) default '' not null + unique, + name varchar(255) default '' not null, + email varchar(255) default '' not null + unique, + password varchar(255) default '' not null, + is_admin bool default FALSE not null, + last_login_at datetime, + last_access_at datetime, + created_at datetime not null, + updated_at datetime not null +); +` diff --git a/engine/common.go b/engine/common.go index c95a04c6..8801dbf6 100644 --- a/engine/common.go +++ b/engine/common.go @@ -52,9 +52,9 @@ func FromArtist(ar *model.Artist, ann *model.Annotation) Entry { e.Title = ar.Name e.AlbumCount = ar.AlbumCount e.IsDir = true - if ann != nil { - e.Starred = ann.StarredAt - } + //if ann != nil { + e.Starred = ar.StarredAt + //} return e } @@ -74,11 +74,11 @@ func FromAlbum(al *model.Album, ann *model.Annotation) Entry { e.ArtistId = al.ArtistID e.Duration = al.Duration e.SongCount = al.SongCount - if ann != nil { - e.Starred = ann.StarredAt - e.PlayCount = int32(ann.PlayCount) - e.UserRating = ann.Rating - } + //if ann != nil { + e.Starred = al.StarredAt + e.PlayCount = int32(al.PlayCount) + e.UserRating = al.Rating + //} return e } @@ -111,11 +111,11 @@ func FromMediaFile(mf *model.MediaFile, ann *model.Annotation) Entry { e.AlbumId = mf.AlbumID e.ArtistId = mf.ArtistID e.Type = "music" // TODO Hardcoded for now - if ann != nil { - e.PlayCount = int32(ann.PlayCount) - e.Starred = ann.StarredAt - e.UserRating = ann.Rating - } + //if ann != nil { + e.PlayCount = int32(mf.PlayCount) + e.Starred = mf.StarredAt + e.UserRating = mf.Rating + //} return e } diff --git a/engine/cover_test.go b/engine/cover_test.go deleted file mode 100644 index 91caad92..00000000 --- a/engine/cover_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package engine_test - -import ( - "bytes" - "context" - "image" - "testing" - - "github.com/deluan/navidrome/engine" - "github.com/deluan/navidrome/model" - "github.com/deluan/navidrome/persistence" - . "github.com/deluan/navidrome/tests" - . "github.com/smartystreets/goconvey/convey" -) - -func TestCover(t *testing.T) { - Init(t, false) - - ds := &persistence.MockDataStore{} - mockMediaFileRepo := ds.MediaFile(nil).(*persistence.MockMediaFile) - mockAlbumRepo := ds.Album(nil).(*persistence.MockAlbum) - - cover := engine.NewCover(ds) - out := new(bytes.Buffer) - - Convey("Subject: GetCoverArt Endpoint", t, func() { - Convey("When id is not found", func() { - mockMediaFileRepo.SetData(`[]`, 1) - err := cover.Get(context.TODO(), "1", 0, out) - - Convey("Then return default cover", func() { - So(err, ShouldBeNil) - So(out.Bytes(), ShouldMatchMD5, "963552b04e87a5a55e993f98a0fbdf82") - }) - }) - Convey("When id is found", func() { - mockMediaFileRepo.SetData(`[{"ID":"2","HasCoverArt":true,"Path":"tests/fixtures/01 Invisible (RED) Edit Version.mp3"}]`, 1) - err := cover.Get(context.TODO(), "2", 0, out) - - Convey("Then it should return the cover from the file", func() { - So(err, ShouldBeNil) - So(out.Bytes(), ShouldMatchMD5, "e859a71cd1b1aaeb1ad437d85b306668") - }) - }) - Convey("When there is an error accessing the database", func() { - mockMediaFileRepo.SetData(`[{"ID":"2","HasCoverArt":true,"Path":"tests/fixtures/01 Invisible (RED) Edit Version.mp3"}]`, 1) - mockMediaFileRepo.SetError(true) - err := cover.Get(context.TODO(), "2", 0, out) - - Convey("Then error should not be nil", func() { - So(err, ShouldNotBeNil) - }) - }) - Convey("When id is found but file is not present", func() { - mockMediaFileRepo.SetData(`[{"ID":"2","HasCoverArt":true,"Path":"tests/fixtures/NOT_FOUND.mp3"}]`, 1) - err := cover.Get(context.TODO(), "2", 0, out) - - Convey("Then it should return DatNotFound error", func() { - So(err, ShouldEqual, model.ErrNotFound) - }) - }) - Convey("When specifying a size", func() { - mockMediaFileRepo.SetData(`[{"ID":"2","HasCoverArt":true,"Path":"tests/fixtures/01 Invisible (RED) Edit Version.mp3"}]`, 1) - err := cover.Get(context.TODO(), "2", 100, out) - - Convey("Then image returned should be 100x100", func() { - So(err, ShouldBeNil) - So(out.Bytes(), ShouldMatchMD5, "04378f523ca3e8ead33bf7140d39799e") - img, _, err := image.Decode(bytes.NewReader(out.Bytes())) - So(err, ShouldBeNil) - So(img.Bounds().Max.X, ShouldEqual, 100) - So(img.Bounds().Max.Y, ShouldEqual, 100) - }) - }) - Convey("When id is for an album", func() { - mockAlbumRepo.SetData(`[{"ID":"1","CoverArtPath":"tests/fixtures/01 Invisible (RED) Edit Version.mp3"}]`, 1) - err := cover.Get(context.TODO(), "al-1", 0, out) - - Convey("Then it should return the cover for the album", func() { - So(err, ShouldBeNil) - So(out.Bytes(), ShouldMatchMD5, "e859a71cd1b1aaeb1ad437d85b306668") - }) - }) - - Reset(func() { - mockMediaFileRepo.SetData("[]", 0) - mockMediaFileRepo.SetError(false) - out = new(bytes.Buffer) - }) - }) -} diff --git a/go.mod b/go.mod index 6c0c9d04..7c06219a 100644 --- a/go.mod +++ b/go.mod @@ -15,21 +15,26 @@ require ( github.com/go-chi/chi v4.0.3+incompatible github.com/go-chi/cors v1.0.0 github.com/go-chi/jwtauth v4.0.3+incompatible + github.com/go-sql-driver/mysql v1.5.0 // indirect + github.com/golang/protobuf v1.3.1 // indirect github.com/google/uuid v1.1.1 github.com/google/wire v0.4.0 github.com/kennygrant/sanitize v0.0.0-20170120101633-6a0bfdde8629 github.com/koding/multiconfig v0.0.0-20170327155832-26b6dfd3a84a github.com/kr/pretty v0.1.0 // indirect - github.com/lib/pq v1.3.0 + github.com/lib/pq v1.3.0 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/onsi/ginkgo v1.11.0 github.com/onsi/gomega v1.8.1 + github.com/pkg/errors v0.9.1 // indirect + github.com/pressly/goose v2.6.0+incompatible github.com/sirupsen/logrus v1.4.2 github.com/smartystreets/assertions v1.0.1 // indirect github.com/smartystreets/goconvey v1.6.4 github.com/stretchr/testify v1.4.0 // indirect + golang.org/x/net v0.0.0-20190603091049-60506f45cf65 // indirect golang.org/x/sys v0.0.0-20200107162124-548cf772de50 // indirect - google.golang.org/appengine v1.6.5 // indirect + golang.org/x/text v0.3.2 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index 57b27f09..63abf357 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/go-chi/jwtauth v4.0.3+incompatible/go.mod h1:Q5EIArY/QnD6BdS+IyDw7B2m github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -96,9 +98,14 @@ github.com/onsi/gomega v1.8.1 h1:C5Dqfs/LeauYDX0jJXIe2SWmwCbGzx9yF8C8xy3Lh34= github.com/onsi/gomega v1.8.1/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose v2.6.0+incompatible h1:3f8zIQ8rfgP9tyI0Hmcs2YNAqUCL1c+diLe3iU8Qd/k= +github.com/pressly/goose v2.6.0+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8= github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM= github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 h1:p6IxqQMjab30l4lb9mmkIkkcE1yv6o0SKbPhW5pxqHI= @@ -150,8 +157,6 @@ golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b h1:NVD8gBK33xpdqCaZVVtd6OF golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= -google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/log/log_test.go b/log/log_test.go index 4087c0bc..69f951a1 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -30,6 +30,12 @@ func TestLog(t *testing.T) { So(hook.LastEntry().Data, ShouldBeEmpty) }) + SkipConvey("Empty context", func() { + Error(context.Background(), "Simple Message") + So(hook.LastEntry().Message, ShouldEqual, "Simple Message") + So(hook.LastEntry().Data, ShouldBeEmpty) + }) + Convey("Message with two kv pairs", func() { Error("Simple Message", "key1", "value1", "key2", "value2") So(hook.LastEntry().Message, ShouldEqual, "Simple Message") diff --git a/main.go b/main.go index 9ccfbc24..88014b69 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "github.com/deluan/navidrome/conf" + "github.com/deluan/navidrome/db" ) func main() { @@ -10,6 +11,7 @@ func main() { } conf.Load() + db.EnsureDB() a := CreateServer(conf.Server.MusicFolder) a.MountRouter("/rest", CreateSubsonicAPIRouter()) diff --git a/model/album.go b/model/album.go index de1e0cae..c081669a 100644 --- a/model/album.go +++ b/model/album.go @@ -3,26 +3,33 @@ package model import "time" type Album struct { - ID string - Name string - ArtistID string - CoverArtPath string - CoverArtId string - Artist string - AlbumArtist string - Year int - Compilation bool - SongCount int - Duration int - Genre string - CreatedAt time.Time - UpdatedAt time.Time + ID string `json:"id" orm:"column(id)"` + Name string `json:"name"` + ArtistID string `json:"artistId"` + CoverArtPath string `json:"-"` + CoverArtId string `json:"-"` + Artist string `json:"artist"` + AlbumArtist string `json:"albumArtist"` + Year int `json:"year"` + Compilation bool `json:"compilation"` + SongCount int `json:"songCount"` + Duration int `json:"duration"` + Genre string `json:"genre"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + + // Annotations + PlayCount int `orm:"-"` + PlayDate time.Time `orm:"-"` + Rating int `orm:"-"` + Starred bool `orm:"-"` + StarredAt time.Time `orm:"-"` } type Albums []Album type AlbumRepository interface { - CountAll() (int64, error) + CountAll(...QueryOptions) (int64, error) Exists(id string) (bool, error) Put(m *Album) error Get(id string) (*Album, error) diff --git a/model/annotation.go b/model/annotation.go index fef0d37a..2261fc6c 100644 --- a/model/annotation.go +++ b/model/annotation.go @@ -5,7 +5,7 @@ import "time" const ( ArtistItemType = "artist" AlbumItemType = "album" - MediaItemType = "mediaFile" + MediaItemType = "media_file" ) type Annotation struct { diff --git a/model/artist.go b/model/artist.go index 16d32557..3a8c94fa 100644 --- a/model/artist.go +++ b/model/artist.go @@ -1,10 +1,20 @@ package model +import "time" + type Artist struct { - ID string - Name string - AlbumCount int + ID string `json:"id" orm:"column(id)"` + Name string `json:"name"` + AlbumCount int `json:"albumCount" orm:"column(album_count)"` + + // Annotations + PlayCount int `json:"playCount"` + PlayDate time.Time `json:"playDate"` + Rating int `json:"rating"` + Starred bool `json:"starred"` + StarredAt time.Time `json:"starredAt"` } + type Artists []Artist type ArtistIndex struct { @@ -14,12 +24,11 @@ type ArtistIndex struct { type ArtistIndexes []ArtistIndex type ArtistRepository interface { - CountAll() (int64, error) + CountAll(options ...QueryOptions) (int64, error) Exists(id string) (bool, error) Put(m *Artist) error Get(id string) (*Artist, error) GetStarred(userId string, options ...QueryOptions) (Artists, error) - SetStar(star bool, ids ...string) error Search(q string, offset int, size int) (Artists, error) Refresh(ids ...string) error GetIndex() (ArtistIndexes, error) diff --git a/model/datastore.go b/model/datastore.go index 61f1ce86..b8531b4a 100644 --- a/model/datastore.go +++ b/model/datastore.go @@ -20,7 +20,6 @@ type QueryOptions struct { type ResourceRepository interface { rest.Repository - rest.Persistable } type DataStore interface { diff --git a/model/mediafile.go b/model/mediafile.go index d46d038b..9da98c72 100644 --- a/model/mediafile.go +++ b/model/mediafile.go @@ -6,26 +6,33 @@ import ( ) type MediaFile struct { - ID string - Path string - Title string - Album string - Artist string - ArtistID string - AlbumArtist string - AlbumID string - HasCoverArt bool - TrackNumber int - DiscNumber int - Year int - Size int - Suffix string - Duration int - BitRate int - Genre string - Compilation bool - CreatedAt time.Time - UpdatedAt time.Time + ID string `json:"id" orm:"pk;column(id)"` + Path string `json:"path"` + Title string `json:"title"` + Album string `json:"album"` + Artist string `json:"artist"` + ArtistID string `json:"artistId"` + AlbumArtist string `json:"albumArtist"` + AlbumID string `json:"albumId"` + HasCoverArt bool `json:"hasCoverArt"` + TrackNumber int `json:"trackNumber"` + DiscNumber int `json:"discNumber"` + Year int `json:"year"` + Size int `json:"size"` + Suffix string `json:"suffix"` + Duration int `json:"duration"` + BitRate int `json:"bitRate"` + Genre string `json:"genre"` + Compilation bool `json:"compilation"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + + // Annotations + PlayCount int `json:"-" orm:"-"` + PlayDate time.Time `json:"-" orm:"-"` + Rating int `json:"-" orm:"-"` + Starred bool `json:"-" orm:"-"` + StarredAt time.Time `json:"-" orm:"-"` } func (mf *MediaFile) ContentType() string { @@ -35,12 +42,13 @@ func (mf *MediaFile) ContentType() string { type MediaFiles []MediaFile type MediaFileRepository interface { - CountAll() (int64, error) + CountAll(options ...QueryOptions) (int64, error) Exists(id string) (bool, error) Put(m *MediaFile) error Get(id string) (*MediaFile, error) FindByAlbum(albumId string) (MediaFiles, error) FindByPath(path string) (MediaFiles, error) + // TODO Remove userId GetStarred(userId string, options ...QueryOptions) (MediaFiles, error) GetRandom(options ...QueryOptions) (MediaFiles, error) Search(q string, offset int, size int) (MediaFiles, error) diff --git a/model/user.go b/model/user.go index 90af1de4..5857f4a0 100644 --- a/model/user.go +++ b/model/user.go @@ -3,18 +3,21 @@ package model import "time" type User struct { - ID string - UserName string - Name string - Email string - Password string - IsAdmin bool - LastLoginAt *time.Time - LastAccessAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time + ID string `json:"id" orm:"column(id)"` + UserName string `json:"userName"` + Name string `json:"name"` + Email string `json:"email"` + Password string `json:"password"` + IsAdmin bool `json:"isAdmin"` + LastLoginAt *time.Time `json:"lastLoginAt"` + LastAccessAt *time.Time `json:"lastAccessAt"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + // TODO ChangePassword string `json:"password"` } +type Users []User + type UserRepository interface { CountAll(...QueryOptions) (int64, error) Get(id string) (*User, error) diff --git a/persistence/album_repository.go b/persistence/album_repository.go index 15e5691c..206d7eff 100644 --- a/persistence/album_repository.go +++ b/persistence/album_repository.go @@ -1,220 +1,136 @@ package persistence import ( - "fmt" - "strings" - "time" + "context" - "github.com/Masterminds/squirrel" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" - "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/rest" ) -type album struct { - ID string `json:"id" orm:"pk;column(id)"` - Name string `json:"name" orm:"index"` - ArtistID string `json:"artistId" orm:"column(artist_id);index"` - CoverArtPath string `json:"-"` - CoverArtId string `json:"-"` - Artist string `json:"artist" orm:"index"` - AlbumArtist string `json:"albumArtist"` - Year int `json:"year" orm:"index"` - Compilation bool `json:"compilation"` - SongCount int `json:"songCount"` - Duration int `json:"duration"` - Genre string `json:"genre" orm:"index"` - CreatedAt time.Time `json:"createdAt" orm:"null"` - UpdatedAt time.Time `json:"updatedAt" orm:"null"` -} - type albumRepository struct { - searchableRepository + sqlRepository } -func NewAlbumRepository(o orm.Ormer) model.AlbumRepository { +func NewAlbumRepository(ctx context.Context, o orm.Ormer) model.AlbumRepository { r := &albumRepository{} + r.ctx = ctx r.ormer = o - r.tableName = "album" + r.tableName = "media_file" return r } +func (r *albumRepository) CountAll(options ...model.QueryOptions) (int64, error) { + sel := r.selectAlbum(options...) + return r.count(sel, options...) +} + +func (r *albumRepository) Exists(id string) (bool, error) { + return r.exists(Select().Where(Eq{"album_id": id})) +} + func (r *albumRepository) Put(a *model.Album) error { - ta := album(*a) - return r.put(a.ID, a.Name, &ta) + return nil +} + +func (r *albumRepository) selectAlbum(options ...model.QueryOptions) SelectBuilder { + //select album_id as id, album as name, f.artist, f.album_artist, f.artist_id, f.compilation, f.genre, + // max(f.year) as year, sum(f.duration) as duration, max(f.updated_at) as updated_at, + // min(f.created_at) as created_at, count(*) as song_count, a.id as current_id, f.id as cover_art_id, + // f.path as cover_art_path, f.has_cover_art + // group by album_id + return r.newSelectWithAnnotation(model.AlbumItemType, "album_id", options...). + Columns("album_id as id", "album as name", "artist", "album_artist", "artist", "artist_id", + "compilation", "genre", "id as cover_art_id", "path as cover_art_path", "has_cover_art", + "max(year) as year", "sum(duration) as duration", "max(updated_at) as updated_at", + "min(created_at) as created_at", "count(*) as song_count").GroupBy("album_id") } func (r *albumRepository) Get(id string) (*model.Album, error) { - ta := album{ID: id} - err := r.ormer.Read(&ta) - if err == orm.ErrNoRows { - return nil, model.ErrNotFound - } + sq := r.selectAlbum().Where(Eq{"album_id": id}) + var res model.Album + err := r.queryOne(sq, &res) if err != nil { return nil, err } - a := model.Album(ta) - return &a, err + return &res, nil } func (r *albumRepository) FindByArtist(artistId string) (model.Albums, error) { - var albums []album - _, err := r.newQuery().Filter("artist_id", artistId).OrderBy("year", "name").All(&albums) - if err != nil { - return nil, err - } - return r.toAlbums(albums), nil + sq := r.selectAlbum().Where(Eq{"artist_id": artistId}).OrderBy("album") + var res model.Albums + err := r.queryAll(sq, &res) + return res, err } func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, error) { - var all []album - _, err := r.newQuery(options...).All(&all) - if err != nil { - return nil, err - } - return r.toAlbums(all), nil + sq := r.selectAlbum(options...) + var res model.Albums + err := r.queryAll(sq, &res) + return res, err } func (r *albumRepository) GetMap(ids []string) (map[string]model.Album, error) { - var all []album - if len(ids) == 0 { - return nil, nil - } - _, err := r.newQuery().Filter("id__in", ids).All(&all) - if err != nil { - return nil, err - } - res := make(map[string]model.Album) - for _, a := range all { - res[a.ID] = model.Album(a) - } - return res, nil + return nil, nil } // TODO Keep order when paginating func (r *albumRepository) GetRandom(options ...model.QueryOptions) (model.Albums, error) { - sq := r.newRawQuery(options...) + sq := r.selectAlbum(options...) switch r.ormer.Driver().Type() { case orm.DRMySQL: sq = sq.OrderBy("RAND()") default: sq = sq.OrderBy("RANDOM()") } - sql, args, err := sq.ToSql() + sql, args, err := r.toSql(sq) if err != nil { return nil, err } - var results []album + var results model.Albums _, err = r.ormer.Raw(sql, args...).QueryRows(&results) - return r.toAlbums(results), err -} - -func (r *albumRepository) toAlbums(all []album) model.Albums { - result := make(model.Albums, len(all)) - for i, a := range all { - result[i] = model.Album(a) - } - return result + return results, err } func (r *albumRepository) Refresh(ids ...string) error { - type refreshAlbum struct { - album - CurrentId string - HasCoverArt bool - } - var albums []refreshAlbum - o := r.ormer - sql := fmt.Sprintf(` -select album_id as id, album as name, f.artist, f.album_artist, f.artist_id, f.compilation, f.genre, - max(f.year) as year, sum(f.duration) as duration, max(f.updated_at) as updated_at, - min(f.created_at) as created_at, count(*) as song_count, a.id as current_id, f.id as cover_art_id, - f.path as cover_art_path, f.has_cover_art -from media_file f left outer join album a on f.album_id = a.id -where f.album_id in ('%s') -group by album_id order by f.id`, strings.Join(ids, "','")) - _, err := o.Raw(sql).QueryRows(&albums) - if err != nil { - return err - } - - var toInsert []album - var toUpdate []album - for _, al := range albums { - if !al.HasCoverArt { - al.CoverArtId = "" - } - if al.Compilation { - al.AlbumArtist = "Various Artists" - } - if al.AlbumArtist == "" { - al.AlbumArtist = al.Artist - } - if al.CurrentId != "" { - toUpdate = append(toUpdate, al.album) - } else { - toInsert = append(toInsert, al.album) - } - err := r.addToIndex(r.tableName, al.ID, al.Name) - if err != nil { - return err - } - } - if len(toInsert) > 0 { - n, err := o.InsertMulti(10, toInsert) - if err != nil { - return err - } - log.Debug("Inserted new albums", "num", n) - } - if len(toUpdate) > 0 { - for _, al := range toUpdate { - _, err := o.Update(&al, "name", "artist_id", "cover_art_path", "cover_art_id", "artist", "album_artist", - "year", "compilation", "song_count", "duration", "updated_at", "created_at") - if err != nil { - return err - } - } - log.Debug("Updated albums", "num", len(toUpdate)) - } - return err + return nil } func (r *albumRepository) PurgeEmpty() error { - _, err := r.ormer.Raw("delete from album where id not in (select distinct(album_id) from media_file)").Exec() - return err + return nil } func (r *albumRepository) GetStarred(userId string, options ...model.QueryOptions) (model.Albums, error) { - var starred []album - sq := r.newRawQuery(options...).Join("annotation").Where("annotation.item_id = " + r.tableName + ".id") - sq = sq.Where(squirrel.And{ - squirrel.Eq{"annotation.user_id": userId}, - squirrel.Eq{"annotation.starred": true}, - }) - sql, args, err := sq.ToSql() - if err != nil { - return nil, err - } - _, err = r.ormer.Raw(sql, args...).QueryRows(&starred) - if err != nil { - return nil, err - } - return r.toAlbums(starred), nil + sq := r.selectAlbum(options...).Where("starred = true") + var starred model.Albums + err := r.queryAll(sq, &starred) + return starred, err } func (r *albumRepository) Search(q string, offset int, size int) (model.Albums, error) { - if len(q) <= 2 { - return nil, nil - } + return nil, nil +} - var results []album - err := r.doSearch(r.tableName, q, offset, size, &results, "name") - if err != nil { - return nil, err - } - return r.toAlbums(results), nil +func (r *albumRepository) Count(options ...rest.QueryOptions) (int64, error) { + return r.CountAll(r.parseRestOptions(options...)) +} + +func (r *albumRepository) Read(id string) (interface{}, error) { + return r.Get(id) +} + +func (r *albumRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { + return r.GetAll(r.parseRestOptions(options...)) +} + +func (r *albumRepository) EntityName() string { + return "album" +} + +func (r *albumRepository) NewInstance() interface{} { + return &model.Album{} } var _ model.AlbumRepository = (*albumRepository)(nil) -var _ = model.Album(album{}) +var _ model.ResourceRepository = (*albumRepository)(nil) diff --git a/persistence/album_repository_test.go b/persistence/album_repository_test.go index e60cd526..8e088181 100644 --- a/persistence/album_repository_test.go +++ b/persistence/album_repository_test.go @@ -1,6 +1,8 @@ package persistence import ( + "context" + "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" . "github.com/onsi/ginkgo" @@ -11,7 +13,18 @@ var _ = Describe("AlbumRepository", func() { var repo model.AlbumRepository BeforeEach(func() { - repo = NewAlbumRepository(orm.NewOrm()) + ctx := context.WithValue(context.Background(), "user", &model.User{ID: "userid"}) + repo = NewAlbumRepository(ctx, orm.NewOrm()) + }) + + Describe("Get", func() { + It("returns an existent album", func() { + Expect(repo.Get("3")).To(Equal(&albumRadioactivity)) + }) + It("returns ErrNotFound when the album does not exist", func() { + _, err := repo.Get("666") + Expect(err).To(MatchError(model.ErrNotFound)) + }) }) Describe("GetAll", func() { @@ -20,7 +33,7 @@ var _ = Describe("AlbumRepository", func() { }) It("returns all records sorted", func() { - Expect(repo.GetAll(model.QueryOptions{Sort: "Name"})).To(Equal(model.Albums{ + Expect(repo.GetAll(model.QueryOptions{Sort: "name"})).To(Equal(model.Albums{ albumAbbeyRoad, albumRadioactivity, albumSgtPeppers, @@ -28,7 +41,7 @@ var _ = Describe("AlbumRepository", func() { }) It("returns all records sorted desc", func() { - Expect(repo.GetAll(model.QueryOptions{Sort: "Name", Order: "desc"})).To(Equal(model.Albums{ + Expect(repo.GetAll(model.QueryOptions{Sort: "name", Order: "desc"})).To(Equal(model.Albums{ albumSgtPeppers, albumRadioactivity, albumAbbeyRoad, @@ -52,7 +65,7 @@ var _ = Describe("AlbumRepository", func() { Describe("FindByArtist", func() { It("returns all records from a given ArtistID", func() { - Expect(repo.FindByArtist("1")).To(Equal(model.Albums{ + Expect(repo.FindByArtist("3")).To(Equal(model.Albums{ albumAbbeyRoad, albumSgtPeppers, })) diff --git a/persistence/annotation_repository.go b/persistence/annotation_repository.go index cf3a86a8..84fdc66f 100644 --- a/persistence/annotation_repository.go +++ b/persistence/annotation_repository.go @@ -1,8 +1,10 @@ package persistence import ( + "context" "time" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" "github.com/google/uuid" @@ -13,16 +15,16 @@ type annotation struct { UserID string `orm:"column(user_id)"` ItemID string `orm:"column(item_id)"` ItemType string `orm:"column(item_type)"` - PlayCount int `orm:"index;null"` - PlayDate time.Time `orm:"index;null"` - Rating int `orm:"index;null"` + PlayCount int `orm:"column(play_count);index;null"` + PlayDate time.Time `orm:"column(play_date);index;null"` + Rating int `orm:"null"` Starred bool `orm:"index"` - StarredAt time.Time `orm:"null"` + StarredAt time.Time `orm:"column(starred_at);null"` } func (u *annotation) TableUnique() [][]string { return [][]string{ - []string{"UserID", "ItemID", "ItemType"}, + {"UserID", "ItemID", "ItemType"}, } } @@ -30,40 +32,40 @@ type annotationRepository struct { sqlRepository } -func NewAnnotationRepository(o orm.Ormer) model.AnnotationRepository { +func NewAnnotationRepository(ctx context.Context, o orm.Ormer) model.AnnotationRepository { r := &annotationRepository{} + r.ctx = ctx r.ormer = o r.tableName = "annotation" return r } func (r *annotationRepository) Get(userID, itemType string, itemID string) (*model.Annotation, error) { - if userID == "" { - return nil, model.ErrInvalidAuth - } - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id", itemID) + q := Select("*").From(r.tableName).Where(And{ + Eq{"user_id": userId(r.ctx)}, + Eq{"item_type": itemType}, + Eq{"item_id": itemID}, + }) var ann annotation - err := q.One(&ann) - if err == orm.ErrNoRows { + err := r.queryOne(q, &ann) + if err == model.ErrNotFound { return nil, nil } - if err != nil { - return nil, err - } resp := model.Annotation(ann) return &resp, nil } -func (r *annotationRepository) GetMap(userID, itemType string, itemID []string) (model.AnnotationMap, error) { - if userID == "" { - return nil, model.ErrInvalidAuth - } - if len(itemID) == 0 { +func (r *annotationRepository) GetMap(userID, itemType string, itemIDs []string) (model.AnnotationMap, error) { + if len(itemIDs) == 0 { return nil, nil } - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id__in", itemID) + q := Select("*").From(r.tableName).Where(And{ + Eq{"user_id": userId(r.ctx)}, + Eq{"item_type": itemType}, + Eq{"item_id": itemIDs}, + }) var res []annotation - _, err := q.All(&res) + err := r.queryAll(q, &res) if err != nil { return nil, err } @@ -76,12 +78,12 @@ func (r *annotationRepository) GetMap(userID, itemType string, itemID []string) } func (r *annotationRepository) GetAll(userID, itemType string, options ...model.QueryOptions) ([]model.Annotation, error) { - if userID == "" { - return nil, model.ErrInvalidAuth - } - q := r.newQuery(options...).Filter("user_id", userID).Filter("item_type", itemType) + q := Select("*").From(r.tableName).Where(And{ + Eq{"user_id": userId(r.ctx)}, + Eq{"item_type": itemType}, + }) var res []annotation - _, err := q.All(&res) + err := r.queryAll(q, &res) if err != nil { return nil, err } @@ -104,16 +106,18 @@ func (r *annotationRepository) new(userID, itemType string, itemID string) *anno } func (r *annotationRepository) IncPlayCount(userID, itemType string, itemID string, ts time.Time) error { - if userID == "" { - return model.ErrInvalidAuth - } - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id", itemID) - c, err := q.Update(orm.Params{ - "play_count": orm.ColValue(orm.ColAdd, 1), - "play_date": ts, - }) + uid := userId(r.ctx) + q := Update(r.tableName). + Set("play_count", Expr("play_count + 1")). + Set("play_date", ts). + Where(And{ + Eq{"user_id": uid}, + Eq{"item_type": itemType}, + Eq{"item_id": itemID}, + }) + c, err := r.executeSQL(q) if c == 0 || err == orm.ErrNoRows { - ann := r.new(userID, itemType, itemID) + ann := r.new(uid, itemType, itemID) ann.PlayCount = 1 ann.PlayDate = ts _, err = r.ormer.Insert(ann) @@ -122,26 +126,30 @@ func (r *annotationRepository) IncPlayCount(userID, itemType string, itemID stri } func (r *annotationRepository) SetStar(starred bool, userID, itemType string, ids ...string) error { - if userID == "" { - return model.ErrInvalidAuth - } - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id__in", ids) + uid := userId(r.ctx) var starredAt time.Time if starred { starredAt = time.Now() } - c, err := q.Update(orm.Params{ - "starred": starred, - "starred_at": starredAt, - }) + q := Update(r.tableName). + Set("starred", starred). + Set("starred_at", starredAt). + Where(And{ + Eq{"user_id": uid}, + Eq{"item_type": itemType}, + Eq{"item_id": ids}, + }) + c, err := r.executeSQL(q) if c == 0 || err == orm.ErrNoRows { for _, id := range ids { - ann := r.new(userID, itemType, id) + ann := r.new(uid, itemType, id) ann.Starred = starred ann.StarredAt = starredAt _, err = r.ormer.Insert(ann) if err != nil { - return err + if err.Error() != "LastInsertId is not supported by this driver" { + return err + } } } } @@ -149,24 +157,27 @@ func (r *annotationRepository) SetStar(starred bool, userID, itemType string, id } func (r *annotationRepository) SetRating(rating int, userID, itemType string, itemID string) error { - if userID == "" { - return model.ErrInvalidAuth - } - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id", itemID) - c, err := q.Update(orm.Params{ - "rating": rating, - }) + uid := userId(r.ctx) + q := Update(r.tableName). + Set("rating", rating). + Where(And{ + Eq{"user_id": uid}, + Eq{"item_type": itemType}, + Eq{"item_id": itemID}, + }) + c, err := r.executeSQL(q) if c == 0 || err == orm.ErrNoRows { - ann := r.new(userID, itemType, itemID) + ann := r.new(uid, itemType, itemID) ann.Rating = rating _, err = r.ormer.Insert(ann) } return err - } -func (r *annotationRepository) Delete(userID, itemType string, itemID ...string) error { - q := r.newQuery().Filter("user_id", userID).Filter("item_type", itemType).Filter("item_id__in", itemID) - _, err := q.Delete() - return err +func (r *annotationRepository) Delete(userID, itemType string, ids ...string) error { + return r.delete(And{ + Eq{"user_id": userId(r.ctx)}, + Eq{"item_type": itemType}, + Eq{"item_id": ids}, + }) } diff --git a/persistence/artist_repository.go b/persistence/artist_repository.go index 9b6986ca..90c01e7f 100644 --- a/persistence/artist_repository.go +++ b/persistence/artist_repository.go @@ -1,39 +1,49 @@ package persistence import ( - "fmt" + "context" "sort" "strings" - "time" - "github.com/Masterminds/squirrel" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/conf" - "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" "github.com/deluan/navidrome/utils" + "github.com/deluan/rest" ) -type artist struct { - ID string `json:"id" orm:"pk;column(id)"` - Name string `json:"name" orm:"index"` - AlbumCount int `json:"albumCount" orm:"column(album_count)"` -} - type artistRepository struct { - searchableRepository + sqlRepository indexGroups utils.IndexGroups } -func NewArtistRepository(o orm.Ormer) model.ArtistRepository { +func NewArtistRepository(ctx context.Context, o orm.Ormer) model.ArtistRepository { r := &artistRepository{} + r.ctx = ctx r.ormer = o r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups) - r.tableName = "artist" + r.tableName = "media_file" return r } -func (r *artistRepository) getIndexKey(a *artist) string { +func (r *artistRepository) selectArtist(options ...model.QueryOptions) SelectBuilder { + // FIXME Handle AlbumArtist/Various Artists... + return r.newSelectWithAnnotation(model.ArtistItemType, "album_id", options...). + Columns("artist_id as id", "artist as name", "count(distinct album_id) as album_count"). + GroupBy("artist_id").Where(Eq{"compilation": false}) +} + +func (r *artistRepository) CountAll(options ...model.QueryOptions) (int64, error) { + sel := r.selectArtist(options...).Where(Eq{"compilation": false}) + return r.count(sel, options...) +} + +func (r *artistRepository) Exists(id string) (bool, error) { + return r.exists(Select().Where(Eq{"artist_id": id})) +} + +func (r *artistRepository) getIndexKey(a *model.Artist) string { name := strings.ToLower(utils.NoArticle(a.Name)) for k, v := range r.indexGroups { key := strings.ToLower(k) @@ -45,28 +55,31 @@ func (r *artistRepository) getIndexKey(a *artist) string { } func (r *artistRepository) Put(a *model.Artist) error { - ta := artist(*a) - return r.put(a.ID, a.Name, &ta) + return nil } func (r *artistRepository) Get(id string) (*model.Artist, error) { - ta := artist{ID: id} - err := r.ormer.Read(&ta) - if err == orm.ErrNoRows { - return nil, model.ErrNotFound - } - if err != nil { - return nil, err - } - a := model.Artist(ta) - return &a, nil + sel := Select("artist_id as id", "artist as name", "count(distinct album_id) as album_count"). + From("media_file").GroupBy("artist_id").Where(Eq{"artist_id": id}) + var res model.Artist + err := r.queryOne(sel, &res) + return &res, err +} + +func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists, error) { + sel := r.selectArtist(options...) + var res model.Artists + err := r.queryAll(sel, &res) + return res, err } // TODO Cache the index (recalculate when there are changes to the DB) func (r *artistRepository) GetIndex() (model.ArtistIndexes, error) { - var all []artist + sq := Select("artist_id as id", "artist as name", "count(distinct album_id) as album_count"). + From("media_file").GroupBy("artist_id").OrderBy("name") + var all model.Artists // TODO Paginate - _, err := r.newQuery().OrderBy("name").All(&all) + err := r.queryAll(sq, &all) if err != nil { return nil, err } @@ -92,127 +105,41 @@ func (r *artistRepository) GetIndex() (model.ArtistIndexes, error) { } func (r *artistRepository) Refresh(ids ...string) error { - type refreshArtist struct { - artist - CurrentId string - AlbumArtist string - Compilation bool - } - var artists []refreshArtist - o := r.ormer - sql := fmt.Sprintf(` -select f.artist_id as id, - f.artist as name, - f.album_artist, - f.compilation, - count(*) as album_count, - a.id as current_id -from album f - left outer join artist a on f.artist_id = a.id -where f.artist_id in ('%s') group by f.artist_id order by f.id`, strings.Join(ids, "','")) - _, err := o.Raw(sql).QueryRows(&artists) - if err != nil { - return err - } - - var toInsert []artist - var toUpdate []artist - for _, ar := range artists { - if ar.Compilation { - ar.AlbumArtist = "Various Artists" - } - if ar.AlbumArtist != "" { - ar.Name = ar.AlbumArtist - } - if ar.CurrentId != "" { - toUpdate = append(toUpdate, ar.artist) - } else { - toInsert = append(toInsert, ar.artist) - } - err := r.addToIndex(r.tableName, ar.ID, ar.Name) - if err != nil { - return err - } - } - if len(toInsert) > 0 { - n, err := o.InsertMulti(10, toInsert) - if err != nil { - return err - } - log.Debug("Inserted new artists", "num", n) - } - if len(toUpdate) > 0 { - for _, al := range toUpdate { - // Don't update Starred - _, err := o.Update(&al, "name", "album_count") - if err != nil { - return err - } - } - log.Debug("Updated artists", "num", len(toUpdate)) - } - return err + return nil } func (r *artistRepository) GetStarred(userId string, options ...model.QueryOptions) (model.Artists, error) { - var starred []artist - sq := r.newRawQuery(options...).Join("annotation").Where("annotation.item_id = " + r.tableName + ".id") - sq = sq.Where(squirrel.And{ - squirrel.Eq{"annotation.user_id": userId}, - squirrel.Eq{"annotation.starred": true}, - }) - sql, args, err := sq.ToSql() - if err != nil { - return nil, err - } - _, err = r.ormer.Raw(sql, args...).QueryRows(&starred) - if err != nil { - return nil, err - } - return r.toArtists(starred), nil -} - -func (r *artistRepository) SetStar(starred bool, ids ...string) error { - if len(ids) == 0 { - return model.ErrNotFound - } - var starredAt time.Time - if starred { - starredAt = time.Now() - } - _, err := r.newQuery().Filter("id__in", ids).Update(orm.Params{ - "starred": starred, - "starred_at": starredAt, - }) - return err + return nil, nil // TODO } func (r *artistRepository) PurgeEmpty() error { - _, err := r.ormer.Raw("delete from artist where id not in (select distinct(artist_id) from album)").Exec() - return err + return nil } func (r *artistRepository) Search(q string, offset int, size int) (model.Artists, error) { - if len(q) <= 2 { - return nil, nil - } - - var results []artist - err := r.doSearch(r.tableName, q, offset, size, &results, "name") - if err != nil { - return nil, err - } - - return r.toArtists(results), nil + return nil, nil // TODO } -func (r *artistRepository) toArtists(all []artist) model.Artists { - result := make(model.Artists, len(all)) - for i, a := range all { - result[i] = model.Artist(a) - } - return result +func (r *artistRepository) Count(options ...rest.QueryOptions) (int64, error) { + return r.CountAll(r.parseRestOptions(options...)) +} + +func (r *artistRepository) Read(id string) (interface{}, error) { + return r.Get(id) +} + +func (r *artistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { + return r.GetAll(r.parseRestOptions(options...)) +} + +func (r *artistRepository) EntityName() string { + return "artist" +} + +func (r *artistRepository) NewInstance() interface{} { + return &model.Artist{} } var _ model.ArtistRepository = (*artistRepository)(nil) -var _ = model.Artist(artist{}) +var _ model.ArtistRepository = (*artistRepository)(nil) +var _ model.ResourceRepository = (*artistRepository)(nil) diff --git a/persistence/artist_repository_test.go b/persistence/artist_repository_test.go index 22036dc6..0da70d95 100644 --- a/persistence/artist_repository_test.go +++ b/persistence/artist_repository_test.go @@ -1,6 +1,8 @@ package persistence import ( + "context" + "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" . "github.com/onsi/ginkgo" @@ -11,22 +13,27 @@ var _ = Describe("ArtistRepository", func() { var repo model.ArtistRepository BeforeEach(func() { - repo = NewArtistRepository(orm.NewOrm()) + repo = NewArtistRepository(context.Background(), orm.NewOrm()) }) - Describe("Put/Get", func() { + Describe("Count", func() { + It("returns the number of artists in the DB", func() { + Expect(repo.CountAll()).To(Equal(int64(2))) + }) + }) + + Describe("Exist", func() { + It("returns true for an artist that is in the DB", func() { + Expect(repo.Exists("3")).To(BeTrue()) + }) + It("returns false for an artist that is in the DB", func() { + Expect(repo.Exists("666")).To(BeFalse()) + }) + }) + + Describe("Get", func() { It("saves and retrieves data", func() { - Expect(repo.Get("1")).To(Equal(&artistSaaraSaara)) - }) - - It("overrides data if ID already exists", func() { - Expect(repo.Put(&model.Artist{ID: "1", Name: "Saara Saara is The Best!", AlbumCount: 3})).To(BeNil()) - Expect(repo.Get("1")).To(Equal(&model.Artist{ID: "1", Name: "Saara Saara is The Best!", AlbumCount: 3})) - }) - - It("returns ErrNotFound when the ID does not exist", func() { - _, err := repo.Get("999") - Expect(err).To(MatchError(model.ErrNotFound)) + Expect(repo.Get("2")).To(Equal(&artistKraftwerk)) }) }) @@ -47,12 +54,6 @@ var _ = Describe("ArtistRepository", func() { artistKraftwerk, }, }, - { - ID: "S", - Artists: model.Artists{ - {ID: "1", Name: "Saara Saara is The Best!", AlbumCount: 3}, - }, - }, })) }) }) diff --git a/persistence/genre_repository.go b/persistence/genre_repository.go index 700f5d70..08f5efb2 100644 --- a/persistence/genre_repository.go +++ b/persistence/genre_repository.go @@ -1,60 +1,33 @@ package persistence import ( - "strconv" + "context" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" ) type genreRepository struct { - ormer orm.Ormer + sqlRepository } -func NewGenreRepository(o orm.Ormer) model.GenreRepository { - return &genreRepository{ormer: o} +func NewGenreRepository(ctx context.Context, o orm.Ormer) model.GenreRepository { + r := &genreRepository{} + r.ctx = ctx + r.ormer = o + r.tableName = "media_file" + return r } func (r genreRepository) GetAll() (model.Genres, error) { - genres := make(map[string]model.Genre) - - // Collect SongCount - var res []orm.Params - _, err := r.ormer.Raw("select genre, count(*) as c from media_file group by genre").Values(&res) + sq := Select("genre as name", "count(distinct album_id) as album_count", "count(distinct id) as song_count"). + From("media_file").GroupBy("genre") + sql, args, err := r.toSql(sq) if err != nil { return nil, err } - for _, r := range res { - name := r["genre"].(string) - count := r["c"].(string) - g, ok := genres[name] - if !ok { - g = model.Genre{Name: name} - } - g.SongCount, _ = strconv.Atoi(count) - genres[name] = g - } - - // Collect AlbumCount - _, err = r.ormer.Raw("select genre, count(*) as c from album group by genre").Values(&res) - if err != nil { - return nil, err - } - for _, r := range res { - name := r["genre"].(string) - count := r["c"].(string) - g, ok := genres[name] - if !ok { - g = model.Genre{Name: name} - } - g.AlbumCount, _ = strconv.Atoi(count) - genres[name] = g - } - - // Build response - result := model.Genres{} - for _, g := range genres { - result = append(result, g) - } - return result, err + var res model.Genres + _, err = r.ormer.Raw(sql, args).QueryRows(&res) + return res, err } diff --git a/persistence/genre_repository_test.go b/persistence/genre_repository_test.go index f21ab955..79ab2b99 100644 --- a/persistence/genre_repository_test.go +++ b/persistence/genre_repository_test.go @@ -1,8 +1,11 @@ -package persistence +package persistence_test import ( + "context" + "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" + "github.com/deluan/navidrome/persistence" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -11,7 +14,7 @@ var _ = Describe("GenreRepository", func() { var repo model.GenreRepository BeforeEach(func() { - repo = NewGenreRepository(orm.NewOrm()) + repo = persistence.NewGenreRepository(context.Background(), orm.NewOrm()) }) It("returns all records", func() { diff --git a/persistence/helpers.go b/persistence/helpers.go new file mode 100644 index 00000000..b443a807 --- /dev/null +++ b/persistence/helpers.go @@ -0,0 +1,62 @@ +package persistence + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +func toSqlArgs(rec interface{}) (map[string]interface{}, error) { + // Convert to JSON... + b, err := json.Marshal(rec) + if err != nil { + return nil, err + } + + // ... then convert to map + var m map[string]interface{} + err = json.Unmarshal(b, &m) + r := make(map[string]interface{}, len(m)) + for f, v := range m { + r[toSnakeCase(f)] = v + } + return r, err +} + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func toSnakeCase(str string) string { + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} + +func ToStruct(m map[string]interface{}, rec interface{}, fieldNames []string) error { + var r = make(map[string]interface{}, len(m)) + for _, f := range fieldNames { + v, ok := m[f] + if !ok { + return fmt.Errorf("invalid field '%s'", f) + } + r[toCamelCase(f)] = v + } + // Convert to JSON... + b, err := json.Marshal(r) + if err != nil { + return err + } + + // ... then convert to struct + err = json.Unmarshal(b, &rec) + return err +} + +var matchUnderscore = regexp.MustCompile("_([A-Za-z])") + +func toCamelCase(str string) string { + return matchUnderscore.ReplaceAllStringFunc(str, func(s string) string { + return strings.ToUpper(strings.Replace(s, "_", "", -1)) + }) +} diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go index 19798bc9..d772bbda 100644 --- a/persistence/mediafile_repository.go +++ b/persistence/mediafile_repository.go @@ -1,176 +1,161 @@ package persistence import ( - "os" + "context" "strings" - "time" - "github.com/Masterminds/squirrel" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" + "github.com/deluan/rest" + "github.com/kennygrant/sanitize" ) -type mediaFile struct { - ID string `json:"id" orm:"pk;column(id)"` - Path string `json:"path" orm:"index"` - Title string `json:"title" orm:"index"` - Album string `json:"album"` - Artist string `json:"artist"` - ArtistID string `json:"artistId" orm:"column(artist_id)"` - AlbumArtist string `json:"albumArtist"` - AlbumID string `json:"albumId" orm:"column(album_id);index"` - HasCoverArt bool `json:"-"` - TrackNumber int `json:"trackNumber"` - DiscNumber int `json:"discNumber"` - Year int `json:"year"` - Size int `json:"size"` - Suffix string `json:"suffix"` - Duration int `json:"duration"` - BitRate int `json:"bitRate"` - Genre string `json:"genre" orm:"index"` - Compilation bool `json:"compilation"` - CreatedAt time.Time `json:"createdAt" orm:"null"` - UpdatedAt time.Time `json:"updatedAt" orm:"null"` -} - type mediaFileRepository struct { - searchableRepository + sqlRepository } -func NewMediaFileRepository(o orm.Ormer) model.MediaFileRepository { +func NewMediaFileRepository(ctx context.Context, o orm.Ormer) *mediaFileRepository { r := &mediaFileRepository{} + r.ctx = ctx r.ormer = o r.tableName = "media_file" return r } -func (r *mediaFileRepository) Put(m *model.MediaFile) error { - tm := mediaFile(*m) - // Don't update media annotation fields (playcount, starred, etc..) - // TODO Validate if this is still necessary, now that we don't have annotations in the mediafile model - return r.put(m.ID, m.Title, &tm, "path", "title", "album", "artist", "artist_id", "album_artist", - "album_id", "has_cover_art", "track_number", "disc_number", "year", "size", "suffix", "duration", - "bit_rate", "genre", "compilation", "updated_at") +func (r mediaFileRepository) CountAll(options ...model.QueryOptions) (int64, error) { + return r.count(Select(), options...) } -func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) { - tm := mediaFile{ID: id} - err := r.ormer.Read(&tm) - if err == orm.ErrNoRows { - return nil, model.ErrNotFound - } - if err != nil { - return nil, err - } - a := model.MediaFile(tm) - return &a, nil +func (r mediaFileRepository) Exists(id string) (bool, error) { + return r.exists(Select().Where(Eq{"id": id})) } -func (r *mediaFileRepository) toMediaFiles(all []mediaFile) model.MediaFiles { - result := make(model.MediaFiles, len(all)) - for i, m := range all { - result[i] = model.MediaFile(m) - } - return result -} - -func (r *mediaFileRepository) FindByAlbum(albumId string) (model.MediaFiles, error) { - var mfs []mediaFile - _, err := r.newQuery().Filter("album_id", albumId).OrderBy("disc_number", "track_number").All(&mfs) - if err != nil { - return nil, err - } - return r.toMediaFiles(mfs), nil -} - -func (r *mediaFileRepository) FindByPath(path string) (model.MediaFiles, error) { - var mfs []mediaFile - _, err := r.newQuery().Filter("path__istartswith", path).OrderBy("disc_number", "track_number").All(&mfs) - if err != nil { - return nil, err - } - var filtered []mediaFile - path = strings.ToLower(path) + string(os.PathSeparator) - for _, mf := range mfs { - filename := strings.TrimPrefix(strings.ToLower(mf.Path), path) - if len(strings.Split(filename, string(os.PathSeparator))) > 1 { - continue - } - filtered = append(filtered, mf) - } - return r.toMediaFiles(filtered), nil -} - -func (r *mediaFileRepository) DeleteByPath(path string) error { - var mfs []mediaFile - // TODO Paginate this (and all other situations similar) - _, err := r.newQuery().Filter("path__istartswith", path).OrderBy("disc_number", "track_number").All(&mfs) +func (r mediaFileRepository) Put(m *model.MediaFile) error { + values, _ := toSqlArgs(*m) + update := Update(r.tableName).Where(Eq{"id": m.ID}).SetMap(values) + count, err := r.executeSQL(update) if err != nil { return err } - var filtered []string - path = strings.ToLower(path) + string(os.PathSeparator) - for _, mf := range mfs { - filename := strings.TrimPrefix(strings.ToLower(mf.Path), path) - if len(strings.Split(filename, string(os.PathSeparator))) > 1 { - continue - } - filtered = append(filtered, mf.ID) - } - if len(filtered) == 0 { + if count > 0 { return nil } - _, err = r.newQuery().Filter("id__in", filtered).Delete() + insert := Insert(r.tableName).SetMap(values) + _, err = r.executeSQL(insert) return err } -func (r *mediaFileRepository) GetRandom(options ...model.QueryOptions) (model.MediaFiles, error) { - sq := r.newRawQuery(options...) +func (r mediaFileRepository) selectMediaFile(options ...model.QueryOptions) SelectBuilder { + return r.newSelectWithAnnotation(model.MediaItemType, "media_file.id", options...).Columns("media_file.*") +} + +func (r mediaFileRepository) Get(id string) (*model.MediaFile, error) { + sel := r.selectMediaFile().Where(Eq{"id": id}) + var res model.MediaFile + err := r.queryOne(sel, &res) + return &res, err +} + +func (r mediaFileRepository) GetAll(options ...model.QueryOptions) (model.MediaFiles, error) { + sq := r.selectMediaFile(options...) + var res model.MediaFiles + err := r.queryAll(sq, &res) + return res, err +} + +func (r mediaFileRepository) FindByAlbum(albumId string) (model.MediaFiles, error) { + sel := r.selectMediaFile().Where(Eq{"album_id": albumId}) + var res model.MediaFiles + err := r.queryAll(sel, &res) + return res, err +} + +func (r mediaFileRepository) FindByPath(path string) (model.MediaFiles, error) { + sel := r.selectMediaFile().Where(Like{"path": path + "%"}) + var res model.MediaFiles + err := r.queryAll(sel, &res) + return res, err +} + +func (r mediaFileRepository) GetStarred(userId string, options ...model.QueryOptions) (model.MediaFiles, error) { + sq := r.selectMediaFile(options...).Where("starred = true") + var starred model.MediaFiles + err := r.queryAll(sq, &starred) + return starred, err +} + +// TODO Keep order when paginating +func (r mediaFileRepository) GetRandom(options ...model.QueryOptions) (model.MediaFiles, error) { + sq := r.selectMediaFile(options...) switch r.ormer.Driver().Type() { case orm.DRMySQL: sq = sq.OrderBy("RAND()") default: sq = sq.OrderBy("RANDOM()") } - sql, args, err := sq.ToSql() + sql, args, err := r.toSql(sq) if err != nil { return nil, err } - var results []mediaFile + var results model.MediaFiles _, err = r.ormer.Raw(sql, args...).QueryRows(&results) - return r.toMediaFiles(results), err + return results, err } -func (r *mediaFileRepository) GetStarred(userId string, options ...model.QueryOptions) (model.MediaFiles, error) { - var starred []mediaFile - sq := r.newRawQuery(options...).Join("annotation").Where("annotation.item_id = " + r.tableName + ".id") - sq = sq.Where(squirrel.And{ - squirrel.Eq{"annotation.user_id": userId}, - squirrel.Eq{"annotation.starred": true}, - }) - sql, args, err := sq.ToSql() - if err != nil { - return nil, err - } - _, err = r.ormer.Raw(sql, args...).QueryRows(&starred) - if err != nil { - return nil, err - } - return r.toMediaFiles(starred), nil +func (r mediaFileRepository) Delete(id string) error { + return r.delete(Eq{"id": id}) } -func (r *mediaFileRepository) Search(q string, offset int, size int) (model.MediaFiles, error) { +func (r mediaFileRepository) DeleteByPath(path string) error { + del := Delete(r.tableName).Where(Like{"path": path + "%"}) + _, err := r.executeSQL(del) + return err +} + +func (r mediaFileRepository) Search(q string, offset int, size int) (model.MediaFiles, error) { + q = strings.TrimSpace(sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*")))) if len(q) <= 2 { - return nil, nil + return model.MediaFiles{}, nil } - - var results []mediaFile - err := r.doSearch(r.tableName, q, offset, size, &results, "title") + sq := Select("*").From(r.tableName) + sq = sq.Limit(uint64(size)).Offset(uint64(offset)).OrderBy("title") + sq = sq.Join("search").Where("search.id = " + r.tableName + ".id") + parts := strings.Split(q, " ") + for _, part := range parts { + sq = sq.Where(Or{ + Like{"full_text": part + "%"}, + Like{"full_text": "%" + part + "%"}, + }) + } + sql, args, err := r.toSql(sq) if err != nil { return nil, err } - return r.toMediaFiles(results), nil + var results model.MediaFiles + _, err = r.ormer.Raw(sql, args...).QueryRows(results) + return results, err +} + +func (r mediaFileRepository) Count(options ...rest.QueryOptions) (int64, error) { + return r.CountAll(r.parseRestOptions(options...)) +} + +func (r mediaFileRepository) Read(id string) (interface{}, error) { + return r.Get(id) +} + +func (r mediaFileRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { + return r.GetAll(r.parseRestOptions(options...)) +} + +func (r mediaFileRepository) EntityName() string { + return "mediafile" +} + +func (r mediaFileRepository) NewInstance() interface{} { + return model.MediaFile{} } var _ model.MediaFileRepository = (*mediaFileRepository)(nil) -var _ = model.MediaFile(mediaFile{}) +var _ model.ResourceRepository = (*mediaFileRepository)(nil) diff --git a/persistence/mediafile_repository_test.go b/persistence/mediafile_repository_test.go index a972aaf5..4861995d 100644 --- a/persistence/mediafile_repository_test.go +++ b/persistence/mediafile_repository_test.go @@ -1,29 +1,86 @@ package persistence import ( - "os" - "path/filepath" + "context" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" + "github.com/google/uuid" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("MediaFileRepository", func() { - var repo model.MediaFileRepository +var _ = Describe("MediaRepository", func() { + var mr model.MediaFileRepository BeforeEach(func() { - repo = NewMediaFileRepository(orm.NewOrm()) + ctx := context.WithValue(context.Background(), "user", &model.User{ID: "userid"}) + mr = NewMediaFileRepository(ctx, orm.NewOrm()) }) - Describe("FindByPath", func() { - It("returns all records from a given ArtistID", func() { - path := string(os.PathSeparator) + filepath.Join("beatles", "1") - Expect(repo.FindByPath(path)).To(Equal(model.MediaFiles{ - songComeTogether, - })) - }) + It("gets mediafile from the DB", func() { + Expect(mr.Get("4")).To(Equal(&songAntenna)) }) + It("returns ErrNotFound", func() { + _, err := mr.Get("56") + Expect(err).To(MatchError(model.ErrNotFound)) + }) + + It("counts the number of mediafiles in the DB", func() { + Expect(mr.CountAll()).To(Equal(int64(4))) + }) + + It("checks existence of mediafiles in the DB", func() { + Expect(mr.Exists(songAntenna.ID)).To(BeTrue()) + Expect(mr.Exists("666")).To(BeFalse()) + }) + + It("find mediafiles by album", func() { + Expect(mr.FindByAlbum("3")).To(Equal(model.MediaFiles{ + songRadioactivity, + songAntenna, + })) + }) + + It("returns empty array when no tracks are found", func() { + Expect(mr.FindByAlbum("67")).To(Equal(model.MediaFiles{})) + }) + + It("finds tracks by path", func() { + Expect(mr.FindByPath(P("/beatles/1/sgt"))).To(Equal(model.MediaFiles{ + songDayInALife, + })) + }) + + It("returns starred tracks", func() { + Expect(mr.GetStarred("userid")).To(Equal(model.MediaFiles{ + songComeTogether, + })) + }) + + It("delete tracks by id", func() { + random, _ := uuid.NewRandom() + id := random.String() + Expect(mr.Put(&model.MediaFile{ID: id})).To(BeNil()) + + Expect(mr.Delete(id)).To(BeNil()) + + _, err := mr.Get(id) + Expect(err).To(MatchError(model.ErrNotFound)) + }) + + It("delete tracks by path", func() { + id1 := "1111" + Expect(mr.Put(&model.MediaFile{ID: id1, Path: P("/abc/123/" + id1 + ".mp3")})).To(BeNil()) + id2 := "2222" + Expect(mr.Put(&model.MediaFile{ID: id2, Path: P("/abc/123/" + id2 + ".mp3")})).To(BeNil()) + + Expect(mr.DeleteByPath(P("/abc"))).To(BeNil()) + + _, err := mr.Get(id1) + Expect(err).To(MatchError(model.ErrNotFound)) + _, err = mr.Get(id2) + Expect(err).To(MatchError(model.ErrNotFound)) + }) }) diff --git a/persistence/mediafolders_repository.go b/persistence/mediafolders_repository.go index 50932d51..be47ebec 100644 --- a/persistence/mediafolders_repository.go +++ b/persistence/mediafolders_repository.go @@ -1,17 +1,19 @@ package persistence import ( + "context" + "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/conf" "github.com/deluan/navidrome/model" ) type mediaFolderRepository struct { - model.MediaFolderRepository + ctx context.Context } -func NewMediaFolderRepository(o orm.Ormer) model.MediaFolderRepository { - return &mediaFolderRepository{} +func NewMediaFolderRepository(ctx context.Context, o orm.Ormer) model.MediaFolderRepository { + return &mediaFolderRepository{ctx} } func (*mediaFolderRepository) GetAll() (model.MediaFolders, error) { diff --git a/persistence/persistence.go b/persistence/persistence.go index 1b19a911..009a204b 100644 --- a/persistence/persistence.go +++ b/persistence/persistence.go @@ -10,19 +10,14 @@ import ( "github.com/deluan/navidrome/conf" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" ) -const batchSize = 100 - var ( - once sync.Once - driver = "sqlite3" - mappedModels map[interface{}]interface{} + once sync.Once + driver = "sqlite3" ) -type SQLStore struct { +type NewSQLStore struct { orm orm.Ormer } @@ -39,57 +34,72 @@ func New() model.DataStore { panic(err) } }) - return &SQLStore{} + return &NewSQLStore{} } -func (db *SQLStore) Album(context.Context) model.AlbumRepository { - return NewAlbumRepository(db.getOrmer()) +func (db *NewSQLStore) Album(ctx context.Context) model.AlbumRepository { + return NewAlbumRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Artist(context.Context) model.ArtistRepository { - return NewArtistRepository(db.getOrmer()) +func (db *NewSQLStore) Artist(ctx context.Context) model.ArtistRepository { + return NewArtistRepository(ctx, db.getOrmer()) } -func (db *SQLStore) MediaFile(context.Context) model.MediaFileRepository { - return NewMediaFileRepository(db.getOrmer()) +func (db *NewSQLStore) MediaFile(ctx context.Context) model.MediaFileRepository { + return NewMediaFileRepository(ctx, db.getOrmer()) } -func (db *SQLStore) MediaFolder(context.Context) model.MediaFolderRepository { - return NewMediaFolderRepository(db.getOrmer()) +func (db *NewSQLStore) MediaFolder(ctx context.Context) model.MediaFolderRepository { + return NewMediaFolderRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Genre(context.Context) model.GenreRepository { - return NewGenreRepository(db.getOrmer()) +func (db *NewSQLStore) Genre(ctx context.Context) model.GenreRepository { + return NewGenreRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Playlist(context.Context) model.PlaylistRepository { - return NewPlaylistRepository(db.getOrmer()) +func (db *NewSQLStore) Playlist(ctx context.Context) model.PlaylistRepository { + return NewPlaylistRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Property(context.Context) model.PropertyRepository { - return NewPropertyRepository(db.getOrmer()) +func (db *NewSQLStore) Property(ctx context.Context) model.PropertyRepository { + return NewPropertyRepository(ctx, db.getOrmer()) } -func (db *SQLStore) User(context.Context) model.UserRepository { - return NewUserRepository(db.getOrmer()) +func (db *NewSQLStore) User(ctx context.Context) model.UserRepository { + return NewUserRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Annotation(context.Context) model.AnnotationRepository { - return NewAnnotationRepository(db.getOrmer()) +func (db *NewSQLStore) Annotation(ctx context.Context) model.AnnotationRepository { + return NewAnnotationRepository(ctx, db.getOrmer()) } -func (db *SQLStore) Resource(ctx context.Context, model interface{}) model.ResourceRepository { - return NewResource(db.getOrmer(), model, getMappedModel(model)) +func getTypeName(model interface{}) string { + return reflect.TypeOf(model).Name() } -func (db *SQLStore) WithTx(block func(tx model.DataStore) error) error { +func (db *NewSQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository { + switch m.(type) { + case model.User: + return db.User(ctx).(model.ResourceRepository) + case model.Artist: + return db.Artist(ctx).(model.ResourceRepository) + case model.Album: + return db.Album(ctx).(model.ResourceRepository) + case model.MediaFile: + return db.MediaFile(ctx).(model.ResourceRepository) + } + log.Error("Resource no implemented", "model", getTypeName(m)) + return nil +} + +func (db *NewSQLStore) WithTx(block func(tx model.DataStore) error) error { o := orm.NewOrm() err := o.Begin() if err != nil { return err } - newDb := &SQLStore{orm: o} + newDb := &NewSQLStore{orm: o} err = block(newDb) if err != nil { @@ -107,7 +117,7 @@ func (db *SQLStore) WithTx(block func(tx model.DataStore) error) error { return nil } -func (db *SQLStore) getOrmer() orm.Ormer { +func (db *NewSQLStore) getOrmer() orm.Ormer { if db.orm == nil { return orm.NewOrm() } @@ -115,56 +125,17 @@ func (db *SQLStore) getOrmer() orm.Ormer { } func initORM(dbPath string) error { - verbose := conf.Server.LogLevel == "trace" - orm.Debug = verbose + //verbose := conf.Server.LogLevel == "trace" + //orm.Debug = verbose if strings.Contains(dbPath, "postgres") { driver = "postgres" } err := orm.RegisterDataBase("default", driver, dbPath) if err != nil { - panic(err) + return err } - return orm.RunSyncdb("default", false, verbose) -} - -func collectField(collection interface{}, getValue func(item interface{}) string) []string { - s := reflect.ValueOf(collection) - result := make([]string, s.Len()) - - for i := 0; i < s.Len(); i++ { - result[i] = getValue(s.Index(i).Interface()) - } - - return result -} - -func getType(myvar interface{}) string { - if t := reflect.TypeOf(myvar); t.Kind() == reflect.Ptr { - return t.Elem().Name() - } else { - return t.Name() - } -} - -func registerModel(model interface{}, mappedModel interface{}) { - mappedModels[getType(model)] = mappedModel - orm.RegisterModel(mappedModel) -} - -func getMappedModel(model interface{}) interface{} { - return mappedModels[getType(model)] -} - -func init() { - mappedModels = map[interface{}]interface{}{} - - registerModel(model.Artist{}, new(artist)) - registerModel(model.Album{}, new(album)) - registerModel(model.MediaFile{}, new(mediaFile)) - registerModel(model.Property{}, new(property)) - registerModel(model.Playlist{}, new(playlist)) - registerModel(model.User{}, new(user)) - registerModel(model.Annotation{}, new(annotation)) - - orm.RegisterModel(new(search)) + // TODO Remove all RegisterModels (i.e. don't use orm.Insert/Update) + orm.RegisterModel(new(annotation)) + + return nil } diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go index cd3dd10a..8c15d40f 100644 --- a/persistence/persistence_suite_test.go +++ b/persistence/persistence_suite_test.go @@ -7,47 +7,44 @@ import ( "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/conf" + "github.com/deluan/navidrome/db" "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" "github.com/deluan/navidrome/tests" + _ "github.com/mattn/go-sqlite3" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) func TestPersistence(t *testing.T) { tests.Init(t, true) + + //os.Remove("./test-123.db") + //conf.Server.DbPath = "./test-123.db" + conf.Server.DbPath = "file::memory:?cache=shared" + New() + db.EnsureDB() log.SetLevel(log.LevelCritical) RegisterFailHandler(Fail) RunSpecs(t, "Persistence Suite") } -var artistSaaraSaara = model.Artist{ID: "1", Name: "Saara Saara", AlbumCount: 2} -var artistKraftwerk = model.Artist{ID: "2", Name: "Kraftwerk"} -var artistBeatles = model.Artist{ID: "3", Name: "The Beatles"} -var testArtists = model.Artists{ - artistSaaraSaara, - artistKraftwerk, - artistBeatles, -} +var artistKraftwerk = model.Artist{ID: "2", Name: "Kraftwerk", AlbumCount: 1} +var artistBeatles = model.Artist{ID: "3", Name: "The Beatles", AlbumCount: 2} -var albumSgtPeppers = model.Album{ID: "1", Name: "Sgt Peppers", Artist: "The Beatles", ArtistID: "1", Genre: "Rock"} -var albumAbbeyRoad = model.Album{ID: "2", Name: "Abbey Road", Artist: "The Beatles", ArtistID: "1", Genre: "Rock"} -var albumRadioactivity = model.Album{ID: "3", Name: "Radioactivity", Artist: "Kraftwerk", ArtistID: "2", Genre: "Electronic"} +var albumSgtPeppers = model.Album{ID: "1", Name: "Sgt Peppers", Artist: "The Beatles", ArtistID: "3", Genre: "Rock", CoverArtId: "1", CoverArtPath: P("/beatles/1/sgt/a day.mp3"), SongCount: 1} +var albumAbbeyRoad = model.Album{ID: "2", Name: "Abbey Road", Artist: "The Beatles", ArtistID: "3", Genre: "Rock", CoverArtId: "2", CoverArtPath: P("/beatles/1/come together.mp3"), SongCount: 1} +var albumRadioactivity = model.Album{ID: "3", Name: "Radioactivity", Artist: "Kraftwerk", ArtistID: "2", Genre: "Electronic", CoverArtId: "3", CoverArtPath: P("/kraft/radio/radio.mp3"), SongCount: 2, Starred: true} var testAlbums = model.Albums{ albumSgtPeppers, albumAbbeyRoad, albumRadioactivity, } -var annRadioactivity = model.Annotation{AnnotationID: "1", UserID: "userid", ItemType: model.AlbumItemType, ItemID: "3", Starred: true} -var testAnnotations = []model.Annotation{ - annRadioactivity, -} - -var songDayInALife = model.MediaFile{ID: "1", Title: "A Day In A Life", ArtistID: "3", AlbumID: "1", Genre: "Rock", Path: P("/beatles/1/sgt/a day.mp3")} -var songComeTogether = model.MediaFile{ID: "2", Title: "Come Together", ArtistID: "3", AlbumID: "2", Genre: "Rock", Path: P("/beatles/1/come together.mp3")} -var songRadioactivity = model.MediaFile{ID: "3", Title: "Radioactivity", ArtistID: "2", AlbumID: "3", Genre: "Electronic", Path: P("/kraft/radio/radio.mp3")} -var songAntenna = model.MediaFile{ID: "4", Title: "Antenna", ArtistID: "2", AlbumID: "3", Genre: "Electronic", Path: P("/kraft/radio/antenna.mp3")} +var songDayInALife = model.MediaFile{ID: "1", Title: "A Day In A Life", ArtistID: "3", Artist: "The Beatles", AlbumID: "1", Album: "Sgt Peppers", Genre: "Rock", Path: P("/beatles/1/sgt/a day.mp3")} +var songComeTogether = model.MediaFile{ID: "2", Title: "Come Together", ArtistID: "3", Artist: "The Beatles", AlbumID: "2", Album: "Abbey Road", Genre: "Rock", Path: P("/beatles/1/come together.mp3"), Starred: true} +var songRadioactivity = model.MediaFile{ID: "3", Title: "Radioactivity", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "3", Album: "Radioactivity", Genre: "Electronic", Path: P("/kraft/radio/radio.mp3")} +var songAntenna = model.MediaFile{ID: "4", Title: "Antenna", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "3", Genre: "Electronic", Path: P("/kraft/radio/antenna.mp3")} var testSongs = model.MediaFiles{ songDayInALife, songComeTogether, @@ -55,37 +52,43 @@ var testSongs = model.MediaFiles{ songAntenna, } +var annAlbumRadioactivity = model.Annotation{AnnotationID: "1", UserID: "userid", ItemType: model.AlbumItemType, ItemID: "3", Starred: true} +var annSongComeTogether = model.Annotation{AnnotationID: "2", UserID: "userid", ItemType: model.MediaItemType, ItemID: "2", Starred: true} +var testAnnotations = []model.Annotation{ + annAlbumRadioactivity, + annSongComeTogether, +} + +var ( + plsBest = model.Playlist{ + ID: "10", + Name: "Best", + Comment: "No Comments", + Duration: 10, + Owner: "userid", + Public: true, + Tracks: model.MediaFiles{{ID: "1"}, {ID: "3"}}, + } + plsCool = model.Playlist{ID: "11", Name: "Cool", Tracks: model.MediaFiles{{ID: "4"}}} + testPlaylists = model.Playlists{plsBest, plsCool} +) + func P(path string) string { return strings.ReplaceAll(path, "/", string(os.PathSeparator)) } var _ = Describe("Initialize test DB", func() { BeforeSuite(func() { - conf.Server.DbPath = ":memory:" - ds := New() - artistRepo := ds.Artist(nil) - for _, a := range testArtists { - err := artistRepo.Put(&a) - if err != nil { - panic(err) - } - } - albumRepository := ds.Album(nil) - for _, a := range testAlbums { - err := albumRepository.Put(&a) - if err != nil { - panic(err) - } - } - mediaFileRepository := ds.MediaFile(nil) + o := orm.NewOrm() + mr := NewMediaFileRepository(nil, o) + for _, s := range testSongs { - err := mediaFileRepository.Put(&s) + err := mr.Put(&s) if err != nil { panic(err) } } - o := orm.NewOrm() for _, a := range testAnnotations { ann := annotation(a) _, err := o.Insert(&ann) @@ -93,5 +96,13 @@ var _ = Describe("Initialize test DB", func() { panic(err) } } + + pr := NewPlaylistRepository(nil, o) + for _, pls := range testPlaylists { + err := pr.Put(&pls) + if err != nil { + panic(err) + } + } }) }) diff --git a/persistence/playlist_repository.go b/persistence/playlist_repository.go index 6eb23b82..c1c8b48c 100644 --- a/persistence/playlist_repository.go +++ b/persistence/playlist_repository.go @@ -1,57 +1,73 @@ package persistence import ( + "context" "strings" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" "github.com/google/uuid" ) type playlist struct { - ID string `orm:"pk;column(id)"` - Name string `orm:"index"` + ID string `orm:"column(id)"` + Name string Comment string Duration int Owner string Public bool - Tracks string `orm:"type(text)"` + Tracks string } type playlistRepository struct { sqlRepository } -func NewPlaylistRepository(o orm.Ormer) model.PlaylistRepository { +func NewPlaylistRepository(ctx context.Context, o orm.Ormer) model.PlaylistRepository { r := &playlistRepository{} + r.ctx = ctx r.ormer = o r.tableName = "playlist" return r } +func (r *playlistRepository) CountAll() (int64, error) { + return r.count(Select()) +} + +func (r *playlistRepository) Exists(id string) (bool, error) { + return r.exists(Select().Where(Eq{"id": id})) +} + +func (r *playlistRepository) Delete(id string) error { + return r.delete(Eq{"id": id}) +} + func (r *playlistRepository) Put(p *model.Playlist) error { if p.ID == "" { id, _ := uuid.NewRandom() p.ID = id.String() } - tp := r.fromModel(p) - err := r.put(p.ID, &tp) + values, _ := toSqlArgs(r.fromModel(p)) + update := Update(r.tableName).Where(Eq{"id": p.ID}).SetMap(values) + count, err := r.executeSQL(update) if err != nil { return err } + if count > 0 { + return nil + } + insert := Insert(r.tableName).SetMap(values) + _, err = r.executeSQL(insert) return err } func (r *playlistRepository) Get(id string) (*model.Playlist, error) { - tp := &playlist{ID: id} - err := r.ormer.Read(tp) - if err == orm.ErrNoRows { - return nil, model.ErrNotFound - } - if err != nil { - return nil, err - } - pls := r.toModel(tp) + sel := r.newSelect().Columns("*").Where(Eq{"id": id}) + var res playlist + err := r.queryOne(sel, &res) + pls := r.toModel(&res) return &pls, err } @@ -60,35 +76,34 @@ func (r *playlistRepository) GetWithTracks(id string) (*model.Playlist, error) { if err != nil { return nil, err } - qs := r.ormer.QueryTable(&mediaFile{}) + mfRepo := NewMediaFileRepository(r.ctx, r.ormer) pls.Duration = 0 var newTracks model.MediaFiles for _, t := range pls.Tracks { - mf := &mediaFile{} - if err := qs.Filter("id", t.ID).One(mf); err == nil { - pls.Duration += mf.Duration - newTracks = append(newTracks, model.MediaFile(*mf)) + mf, err := mfRepo.Get(t.ID) + if err != nil { + continue } + pls.Duration += mf.Duration + newTracks = append(newTracks, model.MediaFile(*mf)) } pls.Tracks = newTracks return pls, err } func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) { - var all []playlist - _, err := r.newQuery(options...).All(&all) - if err != nil { - return nil, err - } - return r.toModels(all) + sel := r.newSelect(options...).Columns("*") + var res []playlist + err := r.queryAll(sel, &res) + return r.toModels(res), err } -func (r *playlistRepository) toModels(all []playlist) (model.Playlists, error) { +func (r *playlistRepository) toModels(all []playlist) model.Playlists { result := make(model.Playlists, len(all)) for i, p := range all { result[i] = r.toModel(&p) } - return result, nil + return result } func (r *playlistRepository) toModel(p *playlist) model.Playlist { diff --git a/persistence/playlist_repository_test.go b/persistence/playlist_repository_test.go new file mode 100644 index 00000000..da1ac87f --- /dev/null +++ b/persistence/playlist_repository_test.go @@ -0,0 +1,78 @@ +package persistence + +import ( + "context" + + "github.com/astaxie/beego/orm" + "github.com/deluan/navidrome/model" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("PlaylistRepository", func() { + var repo model.PlaylistRepository + + BeforeEach(func() { + repo = NewPlaylistRepository(context.Background(), orm.NewOrm()) + }) + + Describe("Count", func() { + It("returns the number of playlists in the DB", func() { + Expect(repo.CountAll()).To(Equal(int64(2))) + }) + }) + + Describe("Exist", func() { + It("returns true for an existing playlist", func() { + Expect(repo.Exists("11")).To(BeTrue()) + }) + It("returns false for a non-existing playlist", func() { + Expect(repo.Exists("666")).To(BeFalse()) + }) + }) + + Describe("Get", func() { + It("returns an existing playlist", func() { + Expect(repo.Get("10")).To(Equal(&plsBest)) + }) + It("returns ErrNotFound for a non-existing playlist", func() { + _, err := repo.Get("666") + Expect(err).To(MatchError(model.ErrNotFound)) + }) + }) + + Describe("Put/Get/Delete", func() { + newPls := model.Playlist{ID: "22", Name: "Great!", Tracks: model.MediaFiles{{ID: "4"}}} + It("saves the playlist to the DB", func() { + Expect(repo.Put(&newPls)).To(BeNil()) + }) + It("returns the newly created playlist", func() { + Expect(repo.Get("22")).To(Equal(&newPls)) + }) + It("returns deletes the playlist", func() { + Expect(repo.Delete("22")).To(BeNil()) + }) + It("returns error if tries to retrieve the deleted playlist", func() { + _, err := repo.Get("22") + Expect(err).To(MatchError(model.ErrNotFound)) + }) + }) + + Describe("GetWithTracks", func() { + It("returns an existing playlist", func() { + pls, err := repo.GetWithTracks("10") + Expect(err).To(BeNil()) + Expect(pls.Name).To(Equal(plsBest.Name)) + Expect(pls.Tracks).To(Equal(model.MediaFiles{ + songDayInALife, + songRadioactivity, + })) + }) + }) + + Describe("GetAll", func() { + It("returns all playlists from DB", func() { + Expect(repo.GetAll()).To(Equal(model.Playlists{plsBest, plsCool})) + }) + }) +}) diff --git a/persistence/property_repository.go b/persistence/property_repository.go index b54a3c7f..efbbda4e 100644 --- a/persistence/property_repository.go +++ b/persistence/property_repository.go @@ -1,6 +1,9 @@ package persistence import ( + "context" + + "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" ) @@ -14,35 +17,41 @@ type propertyRepository struct { sqlRepository } -func NewPropertyRepository(o orm.Ormer) model.PropertyRepository { +func NewPropertyRepository(ctx context.Context, o orm.Ormer) model.PropertyRepository { r := &propertyRepository{} + r.ctx = ctx r.ormer = o r.tableName = "property" return r } -func (r *propertyRepository) Put(id string, value string) error { - p := &property{ID: id, Value: value} - num, err := r.ormer.Update(p) +func (r propertyRepository) Put(id string, value string) error { + update := squirrel.Update(r.tableName).Set("value", value).Where(squirrel.Eq{"id": id}) + count, err := r.executeSQL(update) if err != nil { return nil } - if num == 0 { - _, err = r.ormer.Insert(p) + if count > 0 { + return nil } + insert := squirrel.Insert(r.tableName).Columns("id", "value").Values(id, value) + _, err = r.executeSQL(insert) return err } -func (r *propertyRepository) Get(id string) (string, error) { - p := &property{ID: id} - err := r.ormer.Read(p) - if err == orm.ErrNoRows { - return "", model.ErrNotFound +func (r propertyRepository) Get(id string) (string, error) { + sel := squirrel.Select("value").From(r.tableName).Where(squirrel.Eq{"id": id}) + resp := struct { + Value string + }{} + err := r.queryOne(sel, &resp) + if err != nil { + return "", err } - return p.Value, err + return resp.Value, nil } -func (r *propertyRepository) DefaultGet(id string, defaultValue string) (string, error) { +func (r propertyRepository) DefaultGet(id string, defaultValue string) (string, error) { value, err := r.Get(id) if err == model.ErrNotFound { return defaultValue, nil @@ -52,5 +61,3 @@ func (r *propertyRepository) DefaultGet(id string, defaultValue string) (string, } return value, nil } - -var _ model.PropertyRepository = (*propertyRepository)(nil) diff --git a/persistence/property_repository_test.go b/persistence/property_repository_test.go index 9c75bab0..d85f7749 100644 --- a/persistence/property_repository_test.go +++ b/persistence/property_repository_test.go @@ -1,31 +1,34 @@ package persistence import ( + "context" + "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("PropertyRepository", func() { - var repo model.PropertyRepository +var _ = Describe("Property Repository", func() { + var pr model.PropertyRepository BeforeEach(func() { - repo = NewPropertyRepository(orm.NewOrm()) - repo.(*propertyRepository).DeleteAll() + pr = NewPropertyRepository(context.Background(), orm.NewOrm()) }) - It("saves and retrieves data", func() { - Expect(repo.Put("1", "test")).To(BeNil()) - Expect(repo.Get("1")).To(Equal("test")) + It("saves and restore a new property", func() { + id := "1" + value := "a_value" + Expect(pr.Put(id, value)).To(BeNil()) + Expect(pr.Get(id)).To(Equal("a_value")) }) - It("returns default if data is not found", func() { - Expect(repo.DefaultGet("2", "default")).To(Equal("default")) + It("updates a property", func() { + Expect(pr.Put("1", "another_value")).To(BeNil()) + Expect(pr.Get("1")).To(Equal("another_value")) }) - It("returns value if found", func() { - Expect(repo.Put("3", "test")).To(BeNil()) - Expect(repo.DefaultGet("3", "default")).To(Equal("test")) + It("returns a default value if property does not exist", func() { + Expect(pr.DefaultGet("2", "default")).To(Equal("default")) }) }) diff --git a/persistence/resource_repository.go b/persistence/resource_repository.go deleted file mode 100644 index 972610dd..00000000 --- a/persistence/resource_repository.go +++ /dev/null @@ -1,204 +0,0 @@ -package persistence - -import ( - "fmt" - "reflect" - "strconv" - "strings" - - "github.com/astaxie/beego/orm" - "github.com/deluan/navidrome/model" - "github.com/deluan/rest" - "github.com/google/uuid" -) - -type resourceRepository struct { - model.ResourceRepository - model interface{} - mappedModel interface{} - ormer orm.Ormer - instanceType reflect.Type - sliceType reflect.Type -} - -func NewResource(o orm.Ormer, model interface{}, mappedModel interface{}) model.ResourceRepository { - r := &resourceRepository{model: model, mappedModel: mappedModel, ormer: o} - - // Get type of mappedModel (which is a *struct) - rv := reflect.ValueOf(mappedModel) - for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { - rv = rv.Elem() - } - r.instanceType = rv.Type() - r.sliceType = reflect.SliceOf(r.instanceType) - return r -} - -func (r *resourceRepository) EntityName() string { - return r.instanceType.Name() -} - -func (r *resourceRepository) newQuery(options ...rest.QueryOptions) orm.QuerySeter { - qs := r.ormer.QueryTable(r.mappedModel) - if len(options) > 0 { - qs = r.addOptions(qs, options) - qs = r.addFilters(qs, r.buildFilters(qs, options)) - } - return qs -} - -func (r *resourceRepository) NewInstance() interface{} { - return reflect.New(r.instanceType).Interface() -} - -func (r *resourceRepository) NewSlice() interface{} { - slice := reflect.MakeSlice(r.sliceType, 0, 0) - x := reflect.New(slice.Type()) - x.Elem().Set(slice) - return x.Interface() -} - -func (r *resourceRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { - qs := r.newQuery(options...) - dataSet := r.NewSlice() - _, err := qs.All(dataSet) - if err == orm.ErrNoRows { - return dataSet, rest.ErrNotFound - } - return dataSet, err -} - -func (r *resourceRepository) Count(options ...rest.QueryOptions) (int64, error) { - qs := r.newQuery(options...) - count, err := qs.Count() - if err == orm.ErrNoRows { - err = rest.ErrNotFound - } - return count, err -} - -func (r *resourceRepository) Read(id string) (interface{}, error) { - qs := r.newQuery().Filter("id", id) - data := r.NewInstance() - err := qs.One(data) - if err == orm.ErrNoRows { - return data, rest.ErrNotFound - } - return data, err -} - -func setUUID(p interface{}) { - f := reflect.ValueOf(p).Elem().FieldByName("ID") - if f.Kind() == reflect.String { - id, _ := uuid.NewRandom() - f.SetString(id.String()) - } -} - -func (r *resourceRepository) Save(p interface{}) (string, error) { - setUUID(p) - id, err := r.ormer.Insert(p) - if err != nil { - if err.Error() != "LastInsertId is not supported by this driver" { - return "", err - } - } - return strconv.FormatInt(id, 10), nil -} - -func (r *resourceRepository) Update(p interface{}, cols ...string) error { - count, err := r.ormer.Update(p, cols...) - if err != nil { - return err - } - if count == 0 { - return rest.ErrNotFound - } - return err -} - -func (r *resourceRepository) addOptions(qs orm.QuerySeter, options []rest.QueryOptions) orm.QuerySeter { - if len(options) == 0 { - return qs - } - opt := options[0] - sort := strings.Split(opt.Sort, ",") - reverse := strings.ToLower(opt.Order) == "desc" - for i, s := range sort { - s = strings.TrimSpace(s) - if reverse { - if s[0] == '-' { - s = strings.TrimPrefix(s, "-") - } else { - s = "-" + s - } - } - sort[i] = strings.Replace(s, ".", "__", -1) - } - if opt.Sort != "" { - qs = qs.OrderBy(sort...) - } - if opt.Max > 0 { - qs = qs.Limit(opt.Max) - } - if opt.Offset > 0 { - qs = qs.Offset(opt.Offset) - } - return qs -} - -func (r *resourceRepository) addFilters(qs orm.QuerySeter, conditions ...*orm.Condition) orm.QuerySeter { - var cond *orm.Condition - for _, c := range conditions { - if c != nil { - if cond == nil { - cond = c - } else { - cond = cond.AndCond(c) - } - } - } - if cond != nil { - return qs.SetCond(cond) - } - return qs -} - -func unmarshalValue(val interface{}) string { - switch v := val.(type) { - case float64: - return strconv.FormatFloat(v, 'f', -1, 64) - case string: - return v - default: - return fmt.Sprintf("%v", val) - } -} - -func (r *resourceRepository) buildFilters(qs orm.QuerySeter, options []rest.QueryOptions) *orm.Condition { - if len(options) == 0 { - return nil - } - cond := orm.NewCondition() - clauses := cond - for f, v := range options[0].Filters { - fn := strings.Replace(f, ".", "__", -1) - s := unmarshalValue(v) - - if strings.HasSuffix(fn, "Id") || strings.HasSuffix(fn, "__id") { - clauses = IdFilter(clauses, fn, s) - } else { - clauses = StartsWithFilter(clauses, fn, s) - } - } - return clauses -} - -func IdFilter(cond *orm.Condition, field, value string) *orm.Condition { - field = strings.TrimSuffix(field, "Id") + "__id" - return cond.And(field, value) -} - -func StartsWithFilter(cond *orm.Condition, field, value string) *orm.Condition { - return cond.And(field+"__istartswith", value) -} diff --git a/persistence/searchable_repository.go b/persistence/searchable_repository.go deleted file mode 100644 index 14700dda..00000000 --- a/persistence/searchable_repository.go +++ /dev/null @@ -1,111 +0,0 @@ -package persistence - -import ( - "strings" - - "github.com/Masterminds/squirrel" - "github.com/astaxie/beego/orm" - "github.com/deluan/navidrome/log" - "github.com/kennygrant/sanitize" -) - -type search struct { - ID string `orm:"pk;column(id)"` - Table string `orm:"index"` - FullText string `orm:"index"` -} - -type searchableRepository struct { - sqlRepository -} - -func (r *searchableRepository) DeleteAll() error { - _, err := r.newQuery().Filter("id__isnull", false).Delete() - if err != nil { - return err - } - return r.removeAllFromIndex(r.ormer, r.tableName) -} - -func (r *searchableRepository) put(id string, textToIndex string, a interface{}, fields ...string) error { - c, err := r.newQuery().Filter("id", id).Count() - if err != nil { - return err - } - if c == 0 { - err = r.insert(a) - if err != nil && err.Error() == "LastInsertId is not supported by this driver" { - err = nil - } - } else { - _, err = r.ormer.Update(a, fields...) - } - if err != nil { - return err - } - return r.addToIndex(r.tableName, id, textToIndex) -} - -func (r *searchableRepository) addToIndex(table, id, text string) error { - item := search{ID: id, Table: table} - err := r.ormer.Read(&item) - if err != nil && err != orm.ErrNoRows { - return err - } - sanitizedText := strings.TrimSpace(sanitize.Accents(strings.ToLower(text))) - item = search{ID: id, Table: table, FullText: sanitizedText} - if err == orm.ErrNoRows { - err = r.insert(&item) - } else { - _, err = r.ormer.Update(&item) - } - return err -} - -func (r *searchableRepository) removeFromIndex(table string, ids []string) error { - var offset int - for { - var subset = paginateSlice(ids, offset, batchSize) - if len(subset) == 0 { - break - } - log.Trace("Deleting searchable items", "table", table, "num", len(subset), "from", offset) - offset += len(subset) - _, err := r.ormer.QueryTable(&search{}).Filter("table", table).Filter("id__in", subset).Delete() - if err != nil { - return err - } - } - return nil -} - -func (r *searchableRepository) removeAllFromIndex(o orm.Ormer, table string) error { - _, err := o.QueryTable(&search{}).Filter("table", table).Delete() - return err -} - -func (r *searchableRepository) doSearch(table string, q string, offset, size int, results interface{}, orderBys ...string) error { - q = strings.TrimSpace(sanitize.Accents(strings.ToLower(strings.TrimSuffix(q, "*")))) - if len(q) <= 2 { - return nil - } - sq := squirrel.Select("*").From(table) - sq = sq.Limit(uint64(size)).Offset(uint64(offset)) - if len(orderBys) > 0 { - sq = sq.OrderBy(orderBys...) - } - sq = sq.Join("search").Where("search.id = " + table + ".id") - parts := strings.Split(q, " ") - for _, part := range parts { - sq = sq.Where(squirrel.Or{ - squirrel.Like{"full_text": part + "%"}, - squirrel.Like{"full_text": "%" + part + "%"}, - }) - } - sql, args, err := sq.ToSql() - if err != nil { - return err - } - _, err = r.ormer.Raw(sql, args...).QueryRows(results) - return err -} diff --git a/persistence/sql_repository.go b/persistence/sql_repository.go index f6172a2a..2f5ff597 100644 --- a/persistence/sql_repository.go +++ b/persistence/sql_repository.go @@ -1,40 +1,51 @@ package persistence import ( - "github.com/Masterminds/squirrel" + "context" + "fmt" + "strings" + + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" + "github.com/deluan/navidrome/log" "github.com/deluan/navidrome/model" + "github.com/deluan/rest" ) type sqlRepository struct { - tableName string - ormer orm.Ormer + ctx context.Context + tableName string + fieldNames []string + ormer orm.Ormer } -func (r *sqlRepository) newQuery(options ...model.QueryOptions) orm.QuerySeter { - q := r.ormer.QueryTable(r.tableName) - if len(options) > 0 { - opts := options[0] - q = q.Offset(opts.Offset) - if opts.Max > 0 { - q = q.Limit(opts.Max) - } - if opts.Sort != "" { - if opts.Order == "desc" { - q = q.OrderBy("-" + opts.Sort) - } else { - q = q.OrderBy(opts.Sort) - } - } - for field, value := range opts.Filters { - q = q.Filter(field, value) - } +const invalidUserId = "-1" + +func userId(ctx context.Context) string { + user := ctx.Value("user") + if user == nil { + return invalidUserId } - return q + usr := user.(*model.User) + return usr.ID } -func (r *sqlRepository) newRawQuery(options ...model.QueryOptions) squirrel.SelectBuilder { - sq := squirrel.Select("*").From(r.tableName) +func (r *sqlRepository) newSelectWithAnnotation(itemType, idField string, options ...model.QueryOptions) SelectBuilder { + return r.newSelect(options...). + LeftJoin("annotation on ("+ + "annotation.item_id = "+idField+ + " AND annotation.item_type = '"+itemType+"'"+ + " AND annotation.user_id = '"+userId(r.ctx)+"')"). + Columns("starred", "starred_at", "play_count", "play_date", "rating") +} + +func (r *sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder { + sq := Select().From(r.tableName) + sq = r.applyOptions(sq, options...) + return sq +} + +func (r *sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder { if len(options) > 0 { if options[0].Max > 0 { sq = sq.Limit(uint64(options[0].Max)) @@ -49,83 +60,108 @@ func (r *sqlRepository) newRawQuery(options ...model.QueryOptions) squirrel.Sele sq = sq.OrderBy(options[0].Sort) } } + if len(options[0].Filters) > 0 { + for f, v := range options[0].Filters { + sq = sq.Where(Eq{f: v}) + } + } } return sq } -func (r *sqlRepository) CountAll() (int64, error) { - return r.newQuery().Count() -} - -func (r *sqlRepository) Exists(id string) (bool, error) { - c, err := r.newQuery().Filter("id", id).Count() - return c == 1, err -} - -// "Hack" to bypass Postgres driver limitation -func (r *sqlRepository) insert(record interface{}) error { - _, err := r.ormer.Insert(record) - if err != nil && err.Error() != "LastInsertId is not supported by this driver" { - return err +func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) { + query, args, err := r.toSql(sq) + if err != nil { + return 0, err } - return nil + res, err := r.ormer.Raw(query, args...).Exec() + if err != nil { + if err.Error() != "LastInsertId is not supported by this driver" { + return 0, err + } + } + return res.RowsAffected() } -func (r *sqlRepository) put(id string, a interface{}) error { - c, err := r.newQuery().Filter("id", id).Count() +func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error { + query, args, err := r.toSql(sq) if err != nil { return err } - if c == 0 { - err = r.insert(a) - if err != nil && err.Error() == "LastInsertId is not supported by this driver" { - err = nil - } + err = r.ormer.Raw(query, args...).QueryRow(response) + if err == orm.ErrNoRows { + return model.ErrNotFound + } + return err +} + +func (r sqlRepository) queryAll(sq Sqlizer, response interface{}) error { + query, args, err := r.toSql(sq) + if err != nil { return err } - _, err = r.ormer.Update(a) + _, err = r.ormer.Raw(query, args...).QueryRows(response) + if err == orm.ErrNoRows { + return model.ErrNotFound + } return err } -func paginateSlice(slice []string, skip int, size int) []string { - if skip > len(slice) { - skip = len(slice) +func (r sqlRepository) exists(existsQuery SelectBuilder) (bool, error) { + existsQuery = existsQuery.Columns("count(*) as count").From(r.tableName) + query, args, err := r.toSql(existsQuery) + if err != nil { + return false, err } - - end := skip + size - if end > len(slice) { - end = len(slice) - } - - return slice[skip:end] + var res struct{ Count int64 } + err = r.ormer.Raw(query, args...).QueryRow(&res) + return res.Count > 0, err } -func difference(slice1 []string, slice2 []string) []string { - var diffStr []string - m := map[string]int{} - - for _, s1Val := range slice1 { - m[s1Val] = 1 +func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) { + countQuery = countQuery.Columns("count(*) as count").From(r.tableName) + countQuery = r.applyOptions(countQuery, options...) + query, args, err := r.toSql(countQuery) + if err != nil { + return 0, err } - for _, s2Val := range slice2 { - m[s2Val] = m[s2Val] + 1 + var res struct{ Count int64 } + err = r.ormer.Raw(query, args...).QueryRow(&res) + if err == orm.ErrNoRows { + return 0, model.ErrNotFound } + return res.Count, nil +} - for mKey, mVal := range m { - if mVal == 1 { - diffStr = append(diffStr, mKey) +func (r sqlRepository) delete(cond Sqlizer) error { + del := Delete(r.tableName).Where(cond) + _, err := r.executeSQL(del) + if err == orm.ErrNoRows { + return model.ErrNotFound + } + return err +} + +func (r sqlRepository) toSql(sq Sqlizer) (string, []interface{}, error) { + sql, args, err := sq.ToSql() + if err == nil { + log.Trace(r.ctx, "SQL: `"+sql+"`", "args", strings.TrimPrefix(fmt.Sprintf("%#v", args), "[]interface {}")) + } + return sql, args, err +} + +func (r sqlRepository) parseRestOptions(options ...rest.QueryOptions) model.QueryOptions { + qo := model.QueryOptions{} + if len(options) > 0 { + qo.Sort = toSnakeCase(options[0].Sort) + qo.Order = options[0].Order + qo.Max = options[0].Max + qo.Offset = options[0].Offset + if len(options[0].Filters) > 0 { + for f, v := range options[0].Filters { + qo.Filters = Like{f: fmt.Sprintf("%s%%", v)} + } } } - - return diffStr -} - -func (r *sqlRepository) Delete(id string) error { - _, err := r.newQuery().Filter("id", id).Delete() - return err -} - -func (r *sqlRepository) DeleteAll() error { - _, err := r.newQuery().Filter("id__isnull", false).Delete() - return err + return qo } diff --git a/persistence/user_repository.go b/persistence/user_repository.go index 9187dccc..39fbbf5b 100644 --- a/persistence/user_repository.go +++ b/persistence/user_repository.go @@ -1,92 +1,137 @@ package persistence import ( + "context" + "strings" "time" + . "github.com/Masterminds/squirrel" "github.com/astaxie/beego/orm" "github.com/deluan/navidrome/model" "github.com/deluan/rest" + "github.com/google/uuid" ) -type user struct { - ID string `json:"id" orm:"pk;column(id)"` - UserName string `json:"userName" orm:"index;unique"` - Name string `json:"name"` - Email string `json:"email" orm:"unique"` - Password string `json:"password"` - IsAdmin bool `json:"isAdmin"` - LastLoginAt *time.Time `json:"lastLoginAt" orm:"null"` - LastAccessAt *time.Time `json:"lastAccessAt" orm:"null"` - CreatedAt time.Time `json:"createdAt" orm:"auto_now_add;type(datetime)"` - UpdatedAt time.Time `json:"updatedAt" orm:"auto_now;type(datetime)"` -} - type userRepository struct { - ormer orm.Ormer - userResource model.ResourceRepository + sqlRepository } -func NewUserRepository(o orm.Ormer) model.UserRepository { - r := &userRepository{ormer: o} - r.userResource = NewResource(o, model.User{}, new(user)) +func NewUserRepository(ctx context.Context, o orm.Ormer) model.UserRepository { + r := &userRepository{} + r.ctx = ctx + r.ormer = o + r.tableName = "user" return r } func (r *userRepository) CountAll(qo ...model.QueryOptions) (int64, error) { - if len(qo) > 0 { - return r.userResource.Count(rest.QueryOptions(qo[0])) - } - return r.userResource.Count() + return r.count(Select(), qo...) } func (r *userRepository) Get(id string) (*model.User, error) { - u, err := r.userResource.Read(id) - if err != nil { - return nil, err - } - res := model.User(u.(user)) - return &res, nil + sel := r.newSelect().Columns("*").Where(Eq{"id": id}) + var res model.User + err := r.queryOne(sel, &res) + return &res, err +} + +func (r *userRepository) GetAll(options ...model.QueryOptions) (model.Users, error) { + sel := r.newSelect(options...).Columns("*") + var res model.Users + err := r.queryAll(sel, &res) + return res, err } func (r *userRepository) Put(u *model.User) error { - tu := user(*u) - c, err := r.CountAll() + if u.ID == "" { + id, _ := uuid.NewRandom() + u.ID = id.String() + } + u.UserName = strings.ToLower(u.UserName) + values, _ := toSqlArgs(*u) + update := Update(r.tableName).Where(Eq{"id": u.ID}).SetMap(values) + count, err := r.executeSQL(update) if err != nil { return err } - if c == 0 { - _, err = r.userResource.Save(&tu) - return err + if count > 0 { + return nil } - return r.userResource.Update(&tu, "user_name", "is_admin", "password") + insert := Insert(r.tableName).SetMap(values) + _, err = r.executeSQL(insert) + return err } func (r *userRepository) FindByUsername(username string) (*model.User, error) { - tu := user{} - err := r.ormer.QueryTable(user{}).Filter("user_name__iexact", username).One(&tu) - if err == orm.ErrNoRows { - return nil, model.ErrNotFound - } - if err != nil { - return nil, err - } - u := model.User(tu) - return &u, err + sel := r.newSelect().Columns("*").Where(Eq{"user_name": username}) + var usr model.User + err := r.queryOne(sel, &usr) + return &usr, err } func (r *userRepository) UpdateLastLoginAt(id string) error { - now := time.Now() - tu := user{ID: id, LastLoginAt: &now} - _, err := r.ormer.Update(&tu, "last_login_at") + upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_login_at", time.Now()) + _, err := r.executeSQL(upd) return err } func (r *userRepository) UpdateLastAccessAt(id string) error { now := time.Now() - tu := user{ID: id, LastAccessAt: &now} - _, err := r.ormer.Update(&tu, "last_access_at") + upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_access_at", now) + _, err := r.executeSQL(upd) + return err +} + +func (r *userRepository) Count(options ...rest.QueryOptions) (int64, error) { + return r.CountAll(r.parseRestOptions(options...)) +} + +func (r *userRepository) Read(id string) (interface{}, error) { + usr, err := r.Get(id) + if err == model.ErrNotFound { + return nil, rest.ErrNotFound + } + return usr, err +} + +func (r *userRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) { + return r.GetAll(r.parseRestOptions(options...)) +} + +func (r *userRepository) EntityName() string { + return "user" +} + +func (r *userRepository) NewInstance() interface{} { + return &model.User{} +} + +func (r *userRepository) Save(entity interface{}) (string, error) { + usr := entity.(*model.User) + err := r.Put(usr) + if err != nil { + return "", err + } + return usr.ID, err +} + +func (r *userRepository) Update(entity interface{}, cols ...string) error { + usr := entity.(*model.User) + err := r.Put(usr) + if err == model.ErrNotFound { + return rest.ErrNotFound + } + return err +} + +func (r *userRepository) Delete(id string) error { + err := r.Delete(id) + if err == model.ErrNotFound { + return rest.ErrNotFound + } return err } -var _ = model.User(user{}) var _ model.UserRepository = (*userRepository)(nil) +var _ rest.Repository = (*userRepository)(nil) +var _ rest.Persistable = (*userRepository)(nil) diff --git a/scanner/scanner.go b/scanner/scanner.go index 5b1d8212..4758835d 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -34,7 +34,7 @@ func (s *Scanner) Rescan(mediaFolder string, fullRescan bool) error { log.Debug("Scanning folder (full scan)", "folder", mediaFolder) } - err := folderScanner.Scan(nil, lastModifiedSince) + err := folderScanner.Scan(log.NewContext(nil), lastModifiedSince) if err != nil { log.Error("Error importing MediaFolder", "folder", mediaFolder, err) } diff --git a/server/app/auth.go b/server/app/auth.go index edc9ff67..c837502d 100644 --- a/server/app/auth.go +++ b/server/app/auth.go @@ -110,13 +110,15 @@ func CreateAdmin(ds model.DataStore) func(w http.ResponseWriter, r *http.Request func createDefaultUser(ctx context.Context, ds model.DataStore, username, password string) error { id, _ := uuid.NewRandom() log.Warn("Creating initial user", "user", username) + now := time.Now() initialUser := model.User{ - ID: id.String(), - UserName: username, - Name: strings.Title(username), - Email: "", - Password: password, - IsAdmin: true, + ID: id.String(), + UserName: username, + Name: strings.Title(username), + Email: "", + Password: password, + IsAdmin: true, + LastLoginAt: &now, } err := ds.User(ctx).Put(&initialUser) if err != nil {