diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 25ca7a0c..2915d05a 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -29,16 +29,16 @@ import ( // Injectors from wire_injectors.go: func CreateServer(musicFolder string) *server.Server { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) broker := events.GetBroker() serverServer := server.New(dataStore, broker) return serverServer } func CreateNativeAPIRouter() *nativeapi.Router { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) share := core.NewShare(dataStore) playlists := core.NewPlaylists(dataStore) router := nativeapi.New(dataStore, share, playlists) @@ -46,8 +46,8 @@ func CreateNativeAPIRouter() *nativeapi.Router { } func CreateSubsonicAPIRouter() *subsonic.Router { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) fileCache := artwork.GetImageCache() fFmpeg := ffmpeg.New() agentsAgents := agents.New(dataStore) @@ -69,8 +69,8 @@ func CreateSubsonicAPIRouter() *subsonic.Router { } func CreatePublicRouter() *public.Router { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) fileCache := artwork.GetImageCache() fFmpeg := ffmpeg.New() agentsAgents := agents.New(dataStore) @@ -85,22 +85,22 @@ func CreatePublicRouter() *public.Router { } func CreateLastFMRouter() *lastfm.Router { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) router := lastfm.NewRouter(dataStore) return router } func CreateListenBrainzRouter() *listenbrainz.Router { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) router := listenbrainz.NewRouter(dataStore) return router } func GetScanner() scanner.Scanner { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) playlists := core.NewPlaylists(dataStore) fileCache := artwork.GetImageCache() fFmpeg := ffmpeg.New() @@ -114,8 +114,8 @@ func GetScanner() scanner.Scanner { } func GetPlaybackServer() playback.PlaybackServer { - sqlDB := db.Db() - dataStore := persistence.New(sqlDB) + dbDB := db.Db() + dataStore := persistence.New(dbDB) playbackServer := playback.GetInstance(dataStore) return playbackServer } diff --git a/consts/consts.go b/consts/consts.go index 28dfc942..e9d0457d 100644 --- a/consts/consts.go +++ b/consts/consts.go @@ -11,7 +11,7 @@ import ( const ( AppName = "navidrome" - DefaultDbPath = "navidrome.db?cache=shared&_busy_timeout=15000&_journal_mode=WAL&_foreign_keys=on" + DefaultDbPath = "navidrome.db?cache=shared&_cache_size=1000000000&_busy_timeout=5000&_journal_mode=WAL&_synchronous=NORMAL&_foreign_keys=on&_txlock=immediate" InitialSetupFlagKey = "InitialSetup" UIAuthorizationHeader = "X-ND-Authorization" diff --git a/core/scrobbler/play_tracker.go b/core/scrobbler/play_tracker.go index 5899b1fa..a8d75f3a 100644 --- a/core/scrobbler/play_tracker.go +++ b/core/scrobbler/play_tracker.go @@ -162,15 +162,15 @@ func (p *playTracker) Submit(ctx context.Context, submissions []Submission) erro func (p *playTracker) incPlay(ctx context.Context, track *model.MediaFile, timestamp time.Time) error { return p.ds.WithTx(func(tx model.DataStore) error { - err := p.ds.MediaFile(ctx).IncPlayCount(track.ID, timestamp) + err := tx.MediaFile(ctx).IncPlayCount(track.ID, timestamp) if err != nil { return err } - err = p.ds.Album(ctx).IncPlayCount(track.AlbumID, timestamp) + err = tx.Album(ctx).IncPlayCount(track.AlbumID, timestamp) if err != nil { return err } - err = p.ds.Artist(ctx).IncPlayCount(track.ArtistID, timestamp) + err = tx.Artist(ctx).IncPlayCount(track.ArtistID, timestamp) return err }) } diff --git a/db/db.go b/db/db.go index cf0ce2cf..eb17bca7 100644 --- a/db/db.go +++ b/db/db.go @@ -4,6 +4,7 @@ import ( "database/sql" "embed" "fmt" + "runtime" "github.com/mattn/go-sqlite3" "github.com/navidrome/navidrome/conf" @@ -24,8 +25,36 @@ var embedMigrations embed.FS const migrationsFolder = "migrations" -func Db() *sql.DB { - return singleton.GetInstance(func() *sql.DB { +type DB interface { + ReadDB() *sql.DB + WriteDB() *sql.DB + Close() +} + +type db struct { + readDB *sql.DB + writeDB *sql.DB +} + +func (d *db) ReadDB() *sql.DB { + return d.readDB +} + +func (d *db) WriteDB() *sql.DB { + return d.writeDB +} + +func (d *db) Close() { + if err := d.readDB.Close(); err != nil { + log.Error("Error closing read DB", err) + } + if err := d.writeDB.Close(); err != nil { + log.Error("Error closing write DB", err) + } +} + +func Db() DB { + return singleton.GetInstance(func() *db { sql.Register(Driver+"_custom", &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { return conn.RegisterFunc("SEEDEDRAND", hasher.HashFunc(), false) @@ -38,21 +67,32 @@ func Db() *sql.DB { conf.Server.DbPath = Path } log.Debug("Opening DataBase", "dbPath", Path, "driver", Driver) - instance, err := sql.Open(Driver+"_custom", Path) + + rdb, err := sql.Open(Driver+"_custom", Path) if err != nil { panic(err) } - return instance + rdb.SetMaxOpenConns(max(4, runtime.NumCPU())) + + wdb, err := sql.Open(Driver+"_custom", Path) + if err != nil { + panic(err) + } + wdb.SetMaxOpenConns(1) + return &db{ + readDB: rdb, + writeDB: wdb, + } }) } -func Close() error { +func Close() { log.Info("Closing Database") - return Db().Close() + Db().Close() } func Init() func() { - db := Db() + db := Db().WriteDB() // Disable foreign_keys to allow re-creating tables in migrations _, err := db.Exec("PRAGMA foreign_keys=off") @@ -82,11 +122,7 @@ func Init() func() { log.Fatal("Failed to apply new migrations", err) } - return func() { - if err := Close(); err != nil { - log.Error("Error closing DB", err) - } - } + return Close } type statusLogger struct{ numPending int } diff --git a/persistence/album_repository_test.go b/persistence/album_repository_test.go index 34949c4f..503675be 100644 --- a/persistence/album_repository_test.go +++ b/persistence/album_repository_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/navidrome/navidrome/conf" "github.com/navidrome/navidrome/consts" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -20,7 +21,7 @@ var _ = Describe("AlbumRepository", func() { BeforeEach(func() { ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe"}) - repo = NewAlbumRepository(ctx, getDBXBuilder()) + repo = NewAlbumRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("Get", func() { diff --git a/persistence/artist_repository_test.go b/persistence/artist_repository_test.go index 5e0304ef..1b5b7176 100644 --- a/persistence/artist_repository_test.go +++ b/persistence/artist_repository_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/fatih/structs" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -18,7 +19,7 @@ var _ = Describe("ArtistRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid"}) - repo = NewArtistRepository(ctx, getDBXBuilder()) + repo = NewArtistRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("Count", func() { diff --git a/persistence/dbx_builder.go b/persistence/dbx_builder.go new file mode 100644 index 00000000..bdb4dc5c --- /dev/null +++ b/persistence/dbx_builder.go @@ -0,0 +1,50 @@ +package persistence + +import ( + "github.com/navidrome/navidrome/db" + "github.com/pocketbase/dbx" +) + +type dbxBuilder struct { + dbx.Builder + rdb dbx.Builder +} + +func NewDBXBuilder(d db.DB) *dbxBuilder { + b := &dbxBuilder{} + b.Builder = dbx.NewFromDB(d.WriteDB(), db.Driver) + b.rdb = dbx.NewFromDB(d.ReadDB(), db.Driver) + return b +} + +func (d *dbxBuilder) NewQuery(s string) *dbx.Query { + return d.rdb.NewQuery(s) +} + +func (d *dbxBuilder) Select(s ...string) *dbx.SelectQuery { + return d.rdb.Select(s...) +} + +func (d *dbxBuilder) GeneratePlaceholder(i int) string { + return d.rdb.GeneratePlaceholder(i) +} + +func (d *dbxBuilder) Quote(s string) string { + return d.rdb.Quote(s) +} + +func (d *dbxBuilder) QuoteSimpleTableName(s string) string { + return d.rdb.QuoteSimpleTableName(s) +} + +func (d *dbxBuilder) QuoteSimpleColumnName(s string) string { + return d.rdb.QuoteSimpleColumnName(s) +} + +func (d *dbxBuilder) QueryBuilder() dbx.QueryBuilder { + return d.rdb.QueryBuilder() +} + +func (d *dbxBuilder) Transactional(f func(*dbx.Tx) error) (err error) { + return d.Builder.(*dbx.DB).Transactional(f) +} diff --git a/persistence/genre_repository_test.go b/persistence/genre_repository_test.go index 971bdfc1..3d6ae392 100644 --- a/persistence/genre_repository_test.go +++ b/persistence/genre_repository_test.go @@ -10,14 +10,13 @@ import ( "github.com/navidrome/navidrome/persistence" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/pocketbase/dbx" ) var _ = Describe("GenreRepository", func() { var repo model.GenreRepository BeforeEach(func() { - repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), dbx.NewFromDB(db.Db(), db.Driver)) + repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), persistence.NewDBXBuilder(db.Db())) }) Describe("GetAll()", func() { diff --git a/persistence/mediafile_repository_test.go b/persistence/mediafile_repository_test.go index f85275e6..ac017380 100644 --- a/persistence/mediafile_repository_test.go +++ b/persistence/mediafile_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/Masterminds/squirrel" "github.com/google/uuid" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -19,7 +20,7 @@ var _ = Describe("MediaRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid"}) - mr = NewMediaFileRepository(ctx, getDBXBuilder()) + mr = NewMediaFileRepository(ctx, NewDBXBuilder(db.Db())) }) It("gets mediafile from the DB", func() { diff --git a/persistence/persistence.go b/persistence/persistence.go index cd446b2f..882f33da 100644 --- a/persistence/persistence.go +++ b/persistence/persistence.go @@ -2,7 +2,6 @@ package persistence import ( "context" - "database/sql" "reflect" "github.com/navidrome/navidrome/db" @@ -15,8 +14,8 @@ type SQLStore struct { db dbx.Builder } -func New(conn *sql.DB) model.DataStore { - return &SQLStore{db: dbx.NewFromDB(conn, db.Driver)} +func New(d db.DB) model.DataStore { + return &SQLStore{db: NewDBXBuilder(d)} } func (s *SQLStore) Album(ctx context.Context) model.AlbumRepository { @@ -106,14 +105,18 @@ func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRe return nil } +type transactional interface { + Transactional(f func(*dbx.Tx) error) (err error) +} + func (s *SQLStore) WithTx(block func(tx model.DataStore) error) error { - conn, ok := s.db.(*dbx.DB) - if !ok { - conn = dbx.NewFromDB(db.Db(), db.Driver) + // If we are already in a transaction, just pass it down + if conn, ok := s.db.(*dbx.Tx); ok { + return block(&SQLStore{db: conn}) } - return conn.Transactional(func(tx *dbx.Tx) error { - newDb := &SQLStore{db: tx} - return block(newDb) + + return s.db.(transactional).Transactional(func(tx *dbx.Tx) error { + return block(&SQLStore{db: tx}) }) } @@ -172,7 +175,7 @@ func (s *SQLStore) GC(ctx context.Context, rootFolder string) error { func (s *SQLStore) getDBXBuilder() dbx.Builder { if s.db == nil { - return dbx.NewFromDB(db.Db(), db.Driver) + return NewDBXBuilder(db.Db()) } return s.db } diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go index c9ba6e0d..7741cec0 100644 --- a/persistence/persistence_suite_test.go +++ b/persistence/persistence_suite_test.go @@ -14,7 +14,6 @@ import ( "github.com/navidrome/navidrome/tests" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/pocketbase/dbx" ) func TestPersistence(t *testing.T) { @@ -29,10 +28,6 @@ func TestPersistence(t *testing.T) { RunSpecs(t, "Persistence Suite") } -func getDBXBuilder() *dbx.DB { - return dbx.NewFromDB(db.Db(), db.Driver) -} - var ( genreElectronic = model.Genre{ID: "gn-1", Name: "Electronic"} genreRock = model.Genre{ID: "gn-2", Name: "Rock"} @@ -95,7 +90,7 @@ func P(path string) string { // Initialize test DB // TODO Load this data setup from file(s) var _ = BeforeSuite(func() { - conn := getDBXBuilder() + conn := NewDBXBuilder(db.Db()) ctx := log.NewContext(context.TODO()) user := model.User{ID: "userid", UserName: "userid", IsAdmin: true} ctx = request.WithUser(ctx, user) diff --git a/persistence/playlist_repository_test.go b/persistence/playlist_repository_test.go index 3721b32e..86372167 100644 --- a/persistence/playlist_repository_test.go +++ b/persistence/playlist_repository_test.go @@ -3,6 +3,7 @@ package persistence import ( "context" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/criteria" @@ -17,7 +18,7 @@ var _ = Describe("PlaylistRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) - repo = NewPlaylistRepository(ctx, getDBXBuilder()) + repo = NewPlaylistRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("Count", func() { diff --git a/persistence/playqueue_repository_test.go b/persistence/playqueue_repository_test.go index 1d983694..434ee6a1 100644 --- a/persistence/playqueue_repository_test.go +++ b/persistence/playqueue_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/Masterminds/squirrel" "github.com/google/uuid" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -19,7 +20,7 @@ var _ = Describe("PlayQueueRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) - repo = NewPlayQueueRepository(ctx, getDBXBuilder()) + repo = NewPlayQueueRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("PlayQueues", func() { diff --git a/persistence/property_repository_test.go b/persistence/property_repository_test.go index 172a7b0d..edc38fc9 100644 --- a/persistence/property_repository_test.go +++ b/persistence/property_repository_test.go @@ -3,6 +3,7 @@ package persistence import ( "context" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" . "github.com/onsi/ginkgo/v2" @@ -13,7 +14,7 @@ var _ = Describe("Property Repository", func() { var pr model.PropertyRepository BeforeEach(func() { - pr = NewPropertyRepository(log.NewContext(context.TODO()), getDBXBuilder()) + pr = NewPropertyRepository(log.NewContext(context.TODO()), NewDBXBuilder(db.Db())) }) It("saves and restore a new property", func() { diff --git a/persistence/radio_repository_test.go b/persistence/radio_repository_test.go index 2d819929..87bdb84f 100644 --- a/persistence/radio_repository_test.go +++ b/persistence/radio_repository_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/deluan/rest" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -22,7 +23,7 @@ var _ = Describe("RadioRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) - repo = NewRadioRepository(ctx, getDBXBuilder()) + repo = NewRadioRepository(ctx, NewDBXBuilder(db.Db())) _ = repo.Put(&radioWithHomePage) }) @@ -119,7 +120,7 @@ var _ = Describe("RadioRepository", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: false}) - repo = NewRadioRepository(ctx, getDBXBuilder()) + repo = NewRadioRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("Count", func() { diff --git a/persistence/sql_bookmarks_test.go b/persistence/sql_bookmarks_test.go index c5447af2..07ad6146 100644 --- a/persistence/sql_bookmarks_test.go +++ b/persistence/sql_bookmarks_test.go @@ -3,6 +3,7 @@ package persistence import ( "context" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model/request" @@ -16,7 +17,7 @@ var _ = Describe("sqlBookmarks", func() { BeforeEach(func() { ctx := log.NewContext(context.TODO()) ctx = request.WithUser(ctx, model.User{ID: "userid"}) - mr = NewMediaFileRepository(ctx, getDBXBuilder()) + mr = NewMediaFileRepository(ctx, NewDBXBuilder(db.Db())) }) Describe("Bookmarks", func() { diff --git a/persistence/user_repository_test.go b/persistence/user_repository_test.go index 762c2127..7085e899 100644 --- a/persistence/user_repository_test.go +++ b/persistence/user_repository_test.go @@ -7,6 +7,7 @@ import ( "github.com/deluan/rest" "github.com/google/uuid" "github.com/navidrome/navidrome/consts" + "github.com/navidrome/navidrome/db" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/tests" @@ -18,7 +19,7 @@ var _ = Describe("UserRepository", func() { var repo model.UserRepository BeforeEach(func() { - repo = NewUserRepository(log.NewContext(context.TODO()), getDBXBuilder()) + repo = NewUserRepository(log.NewContext(context.TODO()), NewDBXBuilder(db.Db())) }) Describe("Put/Get/FindByUsername", func() { diff --git a/server/initial_setup.go b/server/initial_setup.go index b5533c6b..5f314218 100644 --- a/server/initial_setup.go +++ b/server/initial_setup.go @@ -17,22 +17,22 @@ import ( func initialSetup(ds model.DataStore) { ctx := context.TODO() _ = ds.WithTx(func(tx model.DataStore) error { - if err := ds.Library(ctx).StoreMusicFolder(); err != nil { + if err := tx.Library(ctx).StoreMusicFolder(); err != nil { return err } - properties := ds.Property(ctx) + properties := tx.Property(ctx) _, err := properties.Get(consts.InitialSetupFlagKey) if err == nil { return nil } log.Info("Running initial setup") - if err = createJWTSecret(ds); err != nil { + if err = createJWTSecret(tx); err != nil { return err } if conf.Server.DevAutoCreateAdminPassword != "" { - if err = createInitialAdminUser(ds, conf.Server.DevAutoCreateAdminPassword); err != nil { + if err = createInitialAdminUser(tx, conf.Server.DevAutoCreateAdminPassword); err != nil { return err } }