diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 09c5f361..092117c3 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -46,10 +46,10 @@ func CreateSubsonicAPIRouter() *subsonic.Router { sqlDB := db.Db() dataStore := persistence.New(sqlDB) fileCache := core.GetImageCache() - transcoderTranscoder := ffmpeg.New() - artwork := core.NewArtwork(dataStore, fileCache, transcoderTranscoder) + fFmpeg := ffmpeg.New() + artwork := core.NewArtwork(dataStore, fileCache, fFmpeg) transcodingCache := core.GetTranscodingCache() - mediaStreamer := core.NewMediaStreamer(dataStore, transcoderTranscoder, transcodingCache) + mediaStreamer := core.NewMediaStreamer(dataStore, fFmpeg, transcodingCache) archiver := core.NewArchiver(mediaStreamer, dataStore) players := core.NewPlayers(dataStore) agentsAgents := agents.New(dataStore) @@ -80,8 +80,12 @@ func createScanner() scanner.Scanner { sqlDB := db.Db() dataStore := persistence.New(sqlDB) playlists := core.NewPlaylists(dataStore) + fileCache := core.GetImageCache() + fFmpeg := ffmpeg.New() + artwork := core.NewArtwork(dataStore, fileCache, fFmpeg) + cacheWarmer := core.NewArtworkCacheWarmer(artwork) broker := events.GetBroker() - scannerScanner := scanner.New(dataStore, playlists, broker) + scannerScanner := scanner.New(dataStore, playlists, cacheWarmer, broker) return scannerScanner } diff --git a/core/artwork_cache_warmer.go b/core/artwork_cache_warmer.go new file mode 100644 index 00000000..ac01e3e4 --- /dev/null +++ b/core/artwork_cache_warmer.go @@ -0,0 +1,63 @@ +package core + +import ( + "context" + "fmt" + "io" + + "github.com/navidrome/navidrome/conf" + "github.com/navidrome/navidrome/log" + "github.com/navidrome/navidrome/model" + "github.com/navidrome/navidrome/utils/pl" +) + +type ArtworkCacheWarmer interface { + PreCache(artID model.ArtworkID) +} + +func NewArtworkCacheWarmer(artwork Artwork) ArtworkCacheWarmer { + // If image cache is disabled, return a NOOP implementation + if conf.Server.ImageCacheSize == "0" { + return &noopCacheWarmer{} + } + + a := &artworkCacheWarmer{ + artwork: artwork, + input: make(chan string), + } + go a.run(context.TODO()) + return a +} + +type artworkCacheWarmer struct { + artwork Artwork + input chan string +} + +func (a *artworkCacheWarmer) PreCache(artID model.ArtworkID) { + a.input <- artID.String() +} + +func (a *artworkCacheWarmer) run(ctx context.Context) { + errs := pl.Sink(ctx, 2, a.input, a.doCacheImage) + for err := range errs { + log.Warn(ctx, "Error warming cache", err) + } +} + +func (a *artworkCacheWarmer) doCacheImage(ctx context.Context, id string) error { + r, err := a.artwork.Get(ctx, id, 0) + if err != nil { + return fmt.Errorf("error cacheing id='%s': %w", id, err) + } + defer r.Close() + _, err = io.Copy(io.Discard, r) + if err != nil { + return err + } + return nil +} + +type noopCacheWarmer struct{} + +func (a *noopCacheWarmer) PreCache(id model.ArtworkID) {} diff --git a/core/wire_providers.go b/core/wire_providers.go index 50fef676..cd0586e2 100644 --- a/core/wire_providers.go +++ b/core/wire_providers.go @@ -20,4 +20,5 @@ var Set = wire.NewSet( scrobbler.GetPlayTracker, NewShare, NewPlaylists, + NewArtworkCacheWarmer, ) diff --git a/scanner/metadata/ffmpeg/ffmpeg.go b/scanner/metadata/ffmpeg/ffmpeg.go index 4ad768b0..bbccf164 100644 --- a/scanner/metadata/ffmpeg/ffmpeg.go +++ b/scanner/metadata/ffmpeg/ffmpeg.go @@ -74,7 +74,7 @@ var ( // Stream #0:0: Audio: mp3, 44100 Hz, stereo, fltp, 192 kb/s // Stream #0:0: Audio: flac, 44100 Hz, stereo, s16 - audioStreamRx = regexp.MustCompile(`^\s{2,4}Stream #\d+:\d+.*: Audio: (.*), (.* Hz), ([\w\.]+),*(.*.,)*`) + audioStreamRx = regexp.MustCompile(`^\s{2,4}Stream #\d+:\d+.*: Audio: (.*), (.* Hz), ([\w.]+),*(.*.,)*`) // Stream #0:1: Video: mjpeg, yuvj444p(pc, bt470bg/unknown/unknown), 600x600 [SAR 1:1 DAR 1:1], 90k tbr, 90k tbn, 90k tbc` coverRx = regexp.MustCompile(`^\s{2,4}Stream #\d+:.+: (Video):.*`) diff --git a/scanner/refresher.go b/scanner/refresher.go index 15d1186a..a3f7821a 100644 --- a/scanner/refresher.go +++ b/scanner/refresher.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Masterminds/squirrel" + "github.com/navidrome/navidrome/core" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/utils" @@ -20,20 +21,20 @@ import ( // // The actual mappings happen in MediaFiles.ToAlbum() and Albums.ToAlbumArtist() type refresher struct { - ctx context.Context - ds model.DataStore - album map[string]struct{} - artist map[string]struct{} - dirMap dirMap + ds model.DataStore + album map[string]struct{} + artist map[string]struct{} + dirMap dirMap + cacheWarmer core.ArtworkCacheWarmer } -func newRefresher(ctx context.Context, ds model.DataStore, dirMap dirMap) *refresher { +func newRefresher(ds model.DataStore, cw core.ArtworkCacheWarmer, dirMap dirMap) *refresher { return &refresher{ - ctx: ctx, - ds: ds, - album: map[string]struct{}{}, - artist: map[string]struct{}{}, - dirMap: dirMap, + ds: ds, + album: map[string]struct{}{}, + artist: map[string]struct{}{}, + dirMap: dirMap, + cacheWarmer: cw, } } @@ -46,21 +47,23 @@ func (r *refresher) accumulate(mf model.MediaFile) { } } -func (r *refresher) flush() error { - err := r.flushMap(r.album, "album", r.refreshAlbums) +func (r *refresher) flush(ctx context.Context) error { + err := r.flushMap(ctx, r.album, "album", r.refreshAlbums) if err != nil { return err } - err = r.flushMap(r.artist, "artist", r.refreshArtists) + err = r.flushMap(ctx, r.artist, "artist", r.refreshArtists) if err != nil { return err } + r.album = map[string]struct{}{} + r.artist = map[string]struct{}{} return nil } -type refreshCallbackFunc = func(ids ...string) error +type refreshCallbackFunc = func(ctx context.Context, ids ...string) error -func (r *refresher) flushMap(m map[string]struct{}, entity string, refresh refreshCallbackFunc) error { +func (r *refresher) flushMap(ctx context.Context, m map[string]struct{}, entity string, refresh refreshCallbackFunc) error { if len(m) == 0 { return nil } @@ -71,17 +74,17 @@ func (r *refresher) flushMap(m map[string]struct{}, entity string, refresh refre } chunks := utils.BreakUpStringSlice(ids, 100) for _, chunk := range chunks { - err := refresh(chunk...) + err := refresh(ctx, chunk...) if err != nil { - log.Error(r.ctx, fmt.Sprintf("Error writing %ss to the DB", entity), err) + log.Error(ctx, fmt.Sprintf("Error writing %ss to the DB", entity), err) return err } } return nil } -func (r *refresher) refreshAlbums(ids ...string) error { - mfs, err := r.ds.MediaFile(r.ctx).GetAll(model.QueryOptions{Filters: squirrel.Eq{"album_id": ids}}) +func (r *refresher) refreshAlbums(ctx context.Context, ids ...string) error { + mfs, err := r.ds.MediaFile(ctx).GetAll(model.QueryOptions{Filters: squirrel.Eq{"album_id": ids}}) if err != nil { return err } @@ -89,7 +92,7 @@ func (r *refresher) refreshAlbums(ids ...string) error { return nil } - repo := r.ds.Album(r.ctx) + repo := r.ds.Album(ctx) grouped := slice.Group(mfs, func(m model.MediaFile) string { return m.AlbumID }) for _, group := range grouped { songs := model.MediaFiles(group) @@ -103,6 +106,7 @@ func (r *refresher) refreshAlbums(ids ...string) error { if err != nil { return err } + r.cacheWarmer.PreCache(a.CoverArtID()) } return nil } @@ -122,8 +126,8 @@ func (r *refresher) getImageFiles(dirs []string) (string, time.Time) { return strings.Join(imageFiles, string(filepath.ListSeparator)), updatedAt } -func (r *refresher) refreshArtists(ids ...string) error { - albums, err := r.ds.Album(r.ctx).GetAll(model.QueryOptions{Filters: squirrel.Eq{"album_artist_id": ids}}) +func (r *refresher) refreshArtists(ctx context.Context, ids ...string) error { + albums, err := r.ds.Album(ctx).GetAll(model.QueryOptions{Filters: squirrel.Eq{"album_artist_id": ids}}) if err != nil { return err } @@ -131,7 +135,7 @@ func (r *refresher) refreshArtists(ids ...string) error { return nil } - repo := r.ds.Artist(r.ctx) + repo := r.ds.Artist(ctx) grouped := slice.Group(albums, func(al model.Album) string { return al.AlbumArtistID }) for _, group := range grouped { a := model.Albums(group).ToAlbumArtist() diff --git a/scanner/scanner.go b/scanner/scanner.go index 58841f90..65dc0c55 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -40,12 +40,13 @@ type FolderScanner interface { var isScanning sync.Mutex type scanner struct { - folders map[string]FolderScanner - status map[string]*scanStatus - lock *sync.RWMutex - ds model.DataStore - pls core.Playlists - broker events.Broker + folders map[string]FolderScanner + status map[string]*scanStatus + lock *sync.RWMutex + ds model.DataStore + pls core.Playlists + broker events.Broker + cacheWarmer core.ArtworkCacheWarmer } type scanStatus struct { @@ -55,14 +56,15 @@ type scanStatus struct { lastUpdate time.Time } -func New(ds model.DataStore, playlists core.Playlists, broker events.Broker) Scanner { +func New(ds model.DataStore, playlists core.Playlists, cacheWarmer core.ArtworkCacheWarmer, broker events.Broker) Scanner { s := &scanner{ - ds: ds, - pls: playlists, - broker: broker, - folders: map[string]FolderScanner{}, - status: map[string]*scanStatus{}, - lock: &sync.RWMutex{}, + ds: ds, + pls: playlists, + broker: broker, + folders: map[string]FolderScanner{}, + status: map[string]*scanStatus{}, + lock: &sync.RWMutex{}, + cacheWarmer: cacheWarmer, } s.loadFolders() return s @@ -242,5 +244,5 @@ func (s *scanner) loadFolders() { } func (s *scanner) newScanner(f model.MediaFolder) FolderScanner { - return NewTagScanner(f.Path, s.ds, s.pls) + return NewTagScanner(f.Path, s.ds, s.pls, s.cacheWarmer) } diff --git a/scanner/tag_scanner.go b/scanner/tag_scanner.go index 37168b2c..134a17e5 100644 --- a/scanner/tag_scanner.go +++ b/scanner/tag_scanner.go @@ -22,19 +22,23 @@ import ( ) type TagScanner struct { - rootFolder string - ds model.DataStore - plsSync *playlistImporter - cnt *counters - mapper *mediaFileMapper + rootFolder string + ds model.DataStore + plsSync *playlistImporter + cnt *counters + mapper *mediaFileMapper + cacheWarmer core.ArtworkCacheWarmer } -func NewTagScanner(rootFolder string, ds model.DataStore, playlists core.Playlists) *TagScanner { - return &TagScanner{ - rootFolder: rootFolder, - plsSync: newPlaylistImporter(ds, playlists, rootFolder), - ds: ds, +func NewTagScanner(rootFolder string, ds model.DataStore, playlists core.Playlists, cacheWarmer core.ArtworkCacheWarmer) FolderScanner { + s := &TagScanner{ + rootFolder: rootFolder, + plsSync: newPlaylistImporter(ds, playlists, rootFolder), + ds: ds, + cacheWarmer: cacheWarmer, } + + return s } type dirMap map[string]dirStats @@ -96,6 +100,7 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time, prog s.cnt = &counters{} genres := newCachedGenreRepository(ctx, s.ds.Genre(ctx)) s.mapper = newMediaFileMapper(s.rootFolder, genres) + refresher := newRefresher(s.ds, s.cacheWarmer, allFSDirs) foldersFound, walkerError := s.getRootFolderWalker(ctx) for { @@ -109,7 +114,7 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time, prog if s.folderHasChanged(folderStats, allDBDirs, lastModifiedSince) { changedDirs = append(changedDirs, folderStats.Path) log.Debug("Processing changed folder", "dir", folderStats.Path) - err := s.processChangedDir(ctx, allFSDirs, folderStats.Path, fullScan) + err := s.processChangedDir(ctx, refresher, fullScan, folderStats.Path) if err != nil { log.Error("Error updating folder in the DB", "dir", folderStats.Path, err) } @@ -128,7 +133,7 @@ func (s *TagScanner) Scan(ctx context.Context, lastModifiedSince time.Time, prog } for _, dir := range deletedDirs { - err := s.processDeletedDir(ctx, allFSDirs, dir) + err := s.processDeletedDir(ctx, refresher, dir) if err != nil { log.Error("Error removing deleted folder from DB", "dir", dir, err) } @@ -221,9 +226,8 @@ func (s *TagScanner) getDeletedDirs(ctx context.Context, fsDirs dirMap, dbDirs m return deleted } -func (s *TagScanner) processDeletedDir(ctx context.Context, allFSDirs dirMap, dir string) error { +func (s *TagScanner) processDeletedDir(ctx context.Context, refresher *refresher, dir string) error { start := time.Now() - buffer := newRefresher(ctx, s.ds, allFSDirs) mfs, err := s.ds.MediaFile(ctx).FindAllByPath(dir) if err != nil { @@ -237,17 +241,16 @@ func (s *TagScanner) processDeletedDir(ctx context.Context, allFSDirs dirMap, di s.cnt.deleted += c for _, t := range mfs { - buffer.accumulate(t) + refresher.accumulate(t) } - err = buffer.flush() + err = refresher.flush(ctx) log.Info(ctx, "Finished processing deleted folder", "dir", dir, "purged", len(mfs), "elapsed", time.Since(start)) return err } -func (s *TagScanner) processChangedDir(ctx context.Context, allFSDirs dirMap, dir string, fullScan bool) error { +func (s *TagScanner) processChangedDir(ctx context.Context, refresher *refresher, fullScan bool, dir string) error { start := time.Now() - buffer := newRefresher(ctx, s.ds, allFSDirs) // Load folder's current tracks from DB into a map currentTracks := map[string]model.MediaFile{} @@ -296,7 +299,7 @@ func (s *TagScanner) processChangedDir(ctx context.Context, allFSDirs dirMap, di } // Force a refresh of the album and artist, to cater for cover art files - buffer.accumulate(c) + refresher.accumulate(c) // Only leaves in orphanTracks the ones not found in the folder. After this loop any remaining orphanTracks // are considered gone from the music folder and will be deleted from DB @@ -307,33 +310,38 @@ func (s *TagScanner) processChangedDir(ctx context.Context, allFSDirs dirMap, di numPurgedTracks := 0 if len(filesToUpdate) > 0 { - numUpdatedTracks, err = s.addOrUpdateTracksInDB(ctx, dir, currentTracks, filesToUpdate, buffer) + numUpdatedTracks, err = s.addOrUpdateTracksInDB(ctx, refresher, dir, currentTracks, filesToUpdate) if err != nil { return err } } if len(orphanTracks) > 0 { - numPurgedTracks, err = s.deleteOrphanSongs(ctx, dir, orphanTracks, buffer) + numPurgedTracks, err = s.deleteOrphanSongs(ctx, refresher, dir, orphanTracks) if err != nil { return err } } - err = buffer.flush() + err = refresher.flush(ctx) log.Info(ctx, "Finished processing changed folder", "dir", dir, "updated", numUpdatedTracks, "deleted", numPurgedTracks, "elapsed", time.Since(start)) return err } -func (s *TagScanner) deleteOrphanSongs(ctx context.Context, dir string, tracksToDelete map[string]model.MediaFile, buffer *refresher) (int, error) { +func (s *TagScanner) deleteOrphanSongs( + ctx context.Context, + refresher *refresher, + dir string, + tracksToDelete map[string]model.MediaFile, +) (int, error) { numPurgedTracks := 0 log.Debug(ctx, "Deleting orphan tracks from DB", "dir", dir, "numTracks", len(tracksToDelete)) // Remaining tracks from DB that are not in the folder are deleted for _, ct := range tracksToDelete { numPurgedTracks++ - buffer.accumulate(ct) + refresher.accumulate(ct) if err := s.ds.MediaFile(ctx).Delete(ct.ID); err != nil { return 0, err } @@ -342,7 +350,13 @@ func (s *TagScanner) deleteOrphanSongs(ctx context.Context, dir string, tracksTo return numPurgedTracks, nil } -func (s *TagScanner) addOrUpdateTracksInDB(ctx context.Context, dir string, currentTracks map[string]model.MediaFile, filesToUpdate []string, buffer *refresher) (int, error) { +func (s *TagScanner) addOrUpdateTracksInDB( + ctx context.Context, + refresher *refresher, + dir string, + currentTracks map[string]model.MediaFile, + filesToUpdate []string, +) (int, error) { numUpdatedTracks := 0 log.Trace(ctx, "Updating mediaFiles in DB", "dir", dir, "numFiles", len(filesToUpdate)) @@ -367,7 +381,7 @@ func (s *TagScanner) addOrUpdateTracksInDB(ctx context.Context, dir string, curr if err != nil { return 0, err } - buffer.accumulate(n) + refresher.accumulate(n) numUpdatedTracks++ } } diff --git a/utils/pl/pipelines.go b/utils/pl/pipelines.go new file mode 100644 index 00000000..ed85c6b9 --- /dev/null +++ b/utils/pl/pipelines.go @@ -0,0 +1,176 @@ +// Package pl implements some Data Pipeline helper functions. +// Reference: https://medium.com/amboss/applying-modern-go-concurrency-patterns-to-data-pipelines-b3b5327908d4#3a80 +// +// See also: +// +// https://www.oreilly.com/library/view/concurrency-in-go/9781491941294/ch04.html#fano_fani +// https://www.youtube.com/watch?v=f6kdp27TYZs +// https://www.youtube.com/watch?v=QDDwwePbDtw +package pl + +import ( + "context" + "errors" + "sync" + + "github.com/navidrome/navidrome/log" + "golang.org/x/sync/semaphore" +) + +func Stage[In any, Out any]( + ctx context.Context, + maxWorkers int, + inputChannel <-chan In, + fn func(context.Context, In) (Out, error), +) (chan Out, chan error) { + outputChannel := make(chan Out) + errorChannel := make(chan error) + + limit := int64(maxWorkers) + sem1 := semaphore.NewWeighted(limit) + + go func() { + defer close(outputChannel) + defer close(errorChannel) + + for s := range ReadOrDone(ctx, inputChannel) { + if err := sem1.Acquire(ctx, 1); err != nil { + if !errors.Is(err, context.Canceled) { + log.Error(ctx, "Failed to acquire semaphore", err) + } + break + } + + go func(s In) { + defer sem1.Release(1) + + result, err := fn(ctx, s) + if err != nil { + if !errors.Is(err, context.Canceled) { + errorChannel <- err + } + } else { + outputChannel <- result + } + }(s) + } + + // By using context.Background() here we are assuming the fn will stop when the context + // is canceled. This is required so we can wait for the workers to finish and avoid closing + // the outputChannel before they are done. + if err := sem1.Acquire(context.Background(), limit); err != nil { + log.Error(ctx, "Failed waiting for workers", err) + } + }() + + return outputChannel, errorChannel +} + +func Sink[In any]( + ctx context.Context, + maxWorkers int, + inputChannel <-chan In, + fn func(context.Context, In) error, +) chan error { + results, errC := Stage(ctx, maxWorkers, inputChannel, func(ctx context.Context, in In) (bool, error) { + err := fn(ctx, in) + return false, err // Only err is important, results will be discarded + }) + + // Discard results + go func() { + for range ReadOrDone(ctx, results) { + } + }() + + return errC +} + +func Merge[T any](ctx context.Context, cs ...<-chan T) <-chan T { + var wg sync.WaitGroup + out := make(chan T) + + output := func(c <-chan T) { + defer wg.Done() + for v := range ReadOrDone(ctx, c) { + select { + case out <- v: + case <-ctx.Done(): + return + } + } + } + + wg.Add(len(cs)) + for _, c := range cs { + go output(c) + } + + go func() { + wg.Wait() + close(out) + }() + + return out +} + +func SendOrDone[T any](ctx context.Context, out chan<- T, v T) { + select { + case out <- v: + case <-ctx.Done(): + return + } +} + +func ReadOrDone[T any](ctx context.Context, in <-chan T) <-chan T { + valStream := make(chan T) + go func() { + defer close(valStream) + for { + select { + case <-ctx.Done(): + return + case v, ok := <-in: + if !ok { + return + } + select { + case valStream <- v: + case <-ctx.Done(): + } + } + } + }() + return valStream +} + +func Tee[T any](ctx context.Context, in <-chan T) (<-chan T, <-chan T) { + out1 := make(chan T) + out2 := make(chan T) + go func() { + defer close(out1) + defer close(out2) + for val := range ReadOrDone(ctx, in) { + var out1, out2 = out1, out2 + for i := 0; i < 2; i++ { + select { + case <-ctx.Done(): + case out1 <- val: + out1 = nil + case out2 <- val: + out2 = nil + } + } + } + }() + return out1, out2 +} + +func FromSlice[T any](ctx context.Context, in []T) <-chan T { + output := make(chan T, len(in)) + for _, c := range in { + output <- c + } + close(output) + return output +} diff --git a/utils/pl/pipelines_test.go b/utils/pl/pipelines_test.go new file mode 100644 index 00000000..f5da6e49 --- /dev/null +++ b/utils/pl/pipelines_test.go @@ -0,0 +1,168 @@ +package pl_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/navidrome/navidrome/utils/pl" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestPipeline(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Pipeline Tests Suite") +} + +var _ = Describe("Pipeline", func() { + Describe("Stage", func() { + Context("happy path", func() { + It("calls the 'transform' function and returns values and errors", func() { + inC := make(chan int, 4) + for i := 0; i < 4; i++ { + inC <- i + } + close(inC) + + outC, errC := pl.Stage(context.Background(), 1, inC, func(ctx context.Context, i int) (int, error) { + if i%2 == 0 { + return 0, errors.New("even number") + } + return i * 2, nil + }) + + Expect(<-errC).To(MatchError("even number")) + Expect(<-outC).To(Equal(2)) + Expect(<-errC).To(MatchError("even number")) + Expect(<-outC).To(Equal(6)) + + Eventually(outC).Should(BeClosed()) + Eventually(errC).Should(BeClosed()) + }) + }) + Context("Multiple workers", func() { + const maxWorkers = 2 + const numJobs = 100 + It("starts multiple workers, respecting the limit", func() { + inC := make(chan int, numJobs) + for i := 0; i < numJobs; i++ { + inC <- i + } + close(inC) + + current := atomic.Int32{} + count := atomic.Int32{} + max := atomic.Int32{} + outC, _ := pl.Stage(context.Background(), maxWorkers, inC, func(ctx context.Context, in int) (int, error) { + defer current.Add(-1) + c := current.Add(1) + count.Add(1) + if c > max.Load() { + max.Store(c) + } + time.Sleep(10 * time.Millisecond) // Slow process + return 0, nil + }) + // Discard output and wait for completion + for range outC { + } + + Expect(count.Load()).To(Equal(int32(numJobs))) + Expect(current.Load()).To(Equal(int32(0))) + Expect(max.Load()).To(Equal(int32(maxWorkers))) + }) + }) + When("the context is canceled", func() { + It("closes its output", func() { + ctx, cancel := context.WithCancel(context.Background()) + inC := make(chan int) + outC, errC := pl.Stage(ctx, 1, inC, func(ctx context.Context, i int) (int, error) { + return i, nil + }) + cancel() + Eventually(outC).Should(BeClosed()) + Eventually(errC).Should(BeClosed()) + }) + }) + + }) + Describe("Merge", func() { + var in1, in2 chan int + BeforeEach(func() { + in1 = make(chan int, 4) + in2 = make(chan int, 4) + for i := 0; i < 4; i++ { + in1 <- i + in2 <- i + 4 + } + close(in1) + close(in2) + }) + When("ranging through the output channel", func() { + It("copies values from all input channels to its output channel", func() { + var values []int + for v := range pl.Merge(context.Background(), in1, in2) { + values = append(values, v) + } + + Expect(values).To(ConsistOf(0, 1, 2, 3, 4, 5, 6, 7)) + }) + }) + When("there's a blocked channel and the context is closed", func() { + It("closes its output", func() { + ctx, cancel := context.WithCancel(context.Background()) + in3 := make(chan int) + out := pl.Merge(ctx, in1, in2, in3) + cancel() + Eventually(out).Should(BeClosed()) + }) + }) + }) + Describe("ReadOrDone", func() { + When("values are sent", func() { + It("copies them to its output channel", func() { + in := make(chan int) + out := pl.ReadOrDone(context.Background(), in) + for i := 0; i < 4; i++ { + in <- i + j := <-out + Expect(i).To(Equal(j)) + } + close(in) + Eventually(out).Should(BeClosed()) + }) + }) + When("the context is canceled", func() { + It("closes its output", func() { + ctx, cancel := context.WithCancel(context.Background()) + in := make(chan int) + out := pl.ReadOrDone(ctx, in) + cancel() + Eventually(out).Should(BeClosed()) + }) + }) + }) + Describe("SendOrDone", func() { + When("out is unblocked", func() { + It("puts the value in the channel", func() { + out := make(chan int) + value := 1234 + go pl.SendOrDone(context.Background(), out, value) + Eventually(out).Should(Receive(&value)) + }) + }) + When("out is blocked", func() { + It("can be canceled by the context", func() { + ctx, cancel := context.WithCancel(context.Background()) + out := make(chan int) + go pl.SendOrDone(ctx, out, 1234) + cancel() + + Consistently(out).ShouldNot(Receive()) + }) + }) + }) +}) diff --git a/utils/pool/pool.go b/utils/pool/pool.go deleted file mode 100644 index 09ceaeff..00000000 --- a/utils/pool/pool.go +++ /dev/null @@ -1,90 +0,0 @@ -package pool - -import ( - "time" - - "github.com/navidrome/navidrome/log" -) - -type Executor func(workload interface{}) - -type Pool struct { - name string - workers []worker - exec Executor - queue chan work // receives jobs to send to workers - done chan bool // when receives bool stops workers - working bool -} - -// TODO This hardcoded value will go away when the queue is persisted in disk -const bufferSize = 10000 - -func NewPool(name string, workerCount int, exec Executor) (*Pool, error) { - p := &Pool{ - name: name, - exec: exec, - queue: make(chan work, bufferSize), - done: make(chan bool), - working: false, - } - - for i := 0; i < workerCount; i++ { - worker := worker{ - p: p, - id: i, - } - worker.Start() - p.workers = append(p.workers, worker) - } - - go func() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if len(p.queue) > 0 { - log.Debug("Queue status", "poolName", p.name, "items", len(p.queue)) - } else { - if p.working { - log.Info("Queue is empty, all items processed", "poolName", p.name) - } - p.working = false - } - case <-p.done: - close(p.queue) - return - } - } - }() - - return p, nil -} - -func (p *Pool) Submit(workload interface{}) { - p.working = true - p.queue <- work{workload} -} - -func (p *Pool) Stop() { - p.done <- true -} - -type work struct { - workload interface{} -} - -type worker struct { - id int - p *Pool -} - -// start worker -func (w *worker) Start() { - go func() { - for job := range w.p.queue { - w.p.exec(job.workload) // do work - } - }() -} diff --git a/utils/pool/pool_test.go b/utils/pool/pool_test.go deleted file mode 100644 index c6c59e7e..00000000 --- a/utils/pool/pool_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package pool - -import ( - "sync" - "testing" - - "github.com/navidrome/navidrome/log" - "github.com/navidrome/navidrome/tests" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -func TestPool(t *testing.T) { - tests.Init(t, false) - log.SetLevel(log.LevelFatal) - RegisterFailHandler(Fail) - RunSpecs(t, "Pool Suite") -} - -type testItem struct { - ID int -} - -var ( - processed []int - mutex sync.RWMutex -) - -var _ = Describe("Pool", func() { - var pool *Pool - - BeforeEach(func() { - processed = nil - pool, _ = NewPool("test", 2, execute) - }) - - It("processes items", func() { - for i := 0; i < 5; i++ { - pool.Submit(&testItem{ID: i}) - } - Eventually(func() []int { - mutex.RLock() - defer mutex.RUnlock() - return processed - }, "10s").Should(HaveLen(5)) - - Expect(processed).To(ContainElements(0, 1, 2, 3, 4)) - }) -}) - -func execute(workload interface{}) { - mutex.Lock() - defer mutex.Unlock() - item := workload.(*testItem) - processed = append(processed, item.ID) -}