diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 65ddb435..8fc24198 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -31,16 +31,16 @@ import ( func CreateServer(musicFolder string) *server.Server { sqlDB := db.Db() dataStore := persistence.New(sqlDB) - serverServer := server.New(dataStore) + broker := events.GetBroker() + serverServer := server.New(dataStore, broker) return serverServer } func CreateNativeAPIRouter() *nativeapi.Router { sqlDB := db.Db() dataStore := persistence.New(sqlDB) - broker := events.GetBroker() share := core.NewShare(dataStore) - router := nativeapi.New(dataStore, broker, share) + router := nativeapi.New(dataStore, share) return router } diff --git a/reflex.conf b/reflex.conf index 8f5b1af3..c25b1662 100644 --- a/reflex.conf +++ b/reflex.conf @@ -1 +1 @@ --s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -tags netgo . +-s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -race -tags netgo . diff --git a/server/events/sse.go b/server/events/sse.go index 82aa690d..9b920033 100644 --- a/server/events/sse.go +++ b/server/events/sse.go @@ -3,7 +3,6 @@ package events import ( "context" - "errors" "fmt" "io" "net/http" @@ -93,38 +92,35 @@ func (b *broker) prepareMessage(ctx context.Context, event Event) message { return msg } -var errWriteTimeOut = errors.New("write timeout") - // writeEvent writes a message to the given io.Writer, formatted as a Server-Sent Event. // If the writer is an http.Flusher, it flushes the data immediately instead of buffering it. -// The function waits for the message to be written or times out after the specified timeout. -func writeEvent(w io.Writer, event message, timeout time.Duration) error { - // Create a context with a timeout based on the event's sender context. - ctx, cancel := context.WithTimeout(event.senderCtx, timeout) - defer cancel() +func writeEvent(ctx context.Context, w io.Writer, event message, timeout time.Duration) error { + if err := setWriteTimeout(w, timeout); err != nil { + log.Debug(ctx, "Error setting write timeout", err) + } - // Create a channel to signal the completion of writing. - errC := make(chan error, 1) - - // Start a goroutine to write the event and optionally flush the writer. - go func() { - _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data) - - // If the writer is an http.Flusher, flush the data immediately. - if flusher, ok := w.(http.Flusher); ok && flusher != nil { - flusher.Flush() - } - - // Signal that writing is complete. - errC <- err - }() - - // Wait for either the write completion or the context to time out. - select { - case err := <-errC: + _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data) + if err != nil { return err - case <-ctx.Done(): - return errWriteTimeOut + } + + // If the writer is an http.Flusher, flush the data immediately. + if flusher, ok := w.(http.Flusher); ok && flusher != nil { + flusher.Flush() + } + return nil +} + +func setWriteTimeout(rw io.Writer, timeout time.Duration) error { + for { + switch t := rw.(type) { + case interface{ SetWriteDeadline(time.Time) error }: + return t.SetWriteDeadline(time.Now().Add(timeout)) + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return fmt.Errorf("%T - %w", rw, http.ErrNotSupported) + } } } @@ -160,9 +156,9 @@ func (b *broker) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } log.Trace(ctx, "Sending event to client", "event", *event, "client", c.String()) - if err := writeEvent(w, *event, writeTimeOut); errors.Is(err, errWriteTimeOut) { - log.Debug(ctx, "Timeout sending event to client", "event", *event, "client", c.String()) - return + err := writeEvent(ctx, w, *event, writeTimeOut) + if err != nil { + log.Debug(ctx, "Error sending event to client", "event", *event, "client", c.String(), err) } } } diff --git a/server/events/sse_test.go b/server/events/sse_test.go index b2a37f0f..e6a44ca1 100644 --- a/server/events/sse_test.go +++ b/server/events/sse_test.go @@ -1,12 +1,7 @@ package events import ( - "bytes" "context" - "fmt" - "io" - "sync/atomic" - "time" "github.com/navidrome/navidrome/model/request" . "github.com/onsi/ginkgo/v2" @@ -63,126 +58,4 @@ var _ = Describe("Broker", func() { }) }) }) - - Describe("writeEvent", func() { - var ( - timeout time.Duration - buffer *bytes.Buffer - event message - senderCtx context.Context - cancel context.CancelFunc - ) - - BeforeEach(func() { - buffer = &bytes.Buffer{} - senderCtx, cancel = context.WithCancel(context.Background()) - DeferCleanup(cancel) - }) - - Context("with an HTTP flusher", func() { - var flusher *fakeFlusher - - BeforeEach(func() { - flusher = &fakeFlusher{Writer: buffer} - event = message{ - senderCtx: senderCtx, - id: 1, - event: "test", - data: "testdata", - } - }) - - Context("when the write completes before the timeout", func() { - BeforeEach(func() { - timeout = 1 * time.Second - }) - It("should successfully write the event", func() { - err := writeEvent(flusher, event, timeout) - Expect(err).NotTo(HaveOccurred()) - Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data))) - Expect(flusher.flushed.Load()).To(BeTrue()) - }) - }) - - Context("when the write does not complete before the timeout", func() { - BeforeEach(func() { - timeout = 1 * time.Millisecond - flusher.delay = 2 * time.Second - }) - - It("should return an errWriteTimeOut error", func() { - err := writeEvent(flusher, event, timeout) - Expect(err).To(MatchError(errWriteTimeOut)) - Expect(flusher.flushed.Load()).To(BeFalse()) - }) - }) - - Context("without an HTTP flusher", func() { - var writer *fakeWriter - - BeforeEach(func() { - writer = &fakeWriter{Writer: buffer} - event = message{ - senderCtx: senderCtx, - id: 1, - event: "test", - data: "testdata", - } - }) - - Context("when the write completes before the timeout", func() { - BeforeEach(func() { - timeout = 1 * time.Second - }) - - It("should successfully write the event", func() { - err := writeEvent(writer, event, timeout) - Expect(err).NotTo(HaveOccurred()) - Eventually(writer.done.Load).Should(BeTrue()) - Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data))) - }) - }) - - Context("when the write does not complete before the timeout", func() { - BeforeEach(func() { - timeout = 1 * time.Millisecond - writer.delay = 2 * time.Second - }) - - It("should return an errWriteTimeOut error", func() { - err := writeEvent(writer, event, timeout) - Expect(err).To(MatchError(errWriteTimeOut)) - Expect(writer.done.Load()).To(BeFalse()) - }) - }) - }) - }) - }) }) - -type fakeWriter struct { - io.Writer - delay time.Duration - done atomic.Bool -} - -func (f *fakeWriter) Write(p []byte) (n int, err error) { - time.Sleep(f.delay) - f.done.Store(true) - return f.Writer.Write(p) -} - -type fakeFlusher struct { - io.Writer - delay time.Duration - flushed atomic.Bool -} - -func (f *fakeFlusher) Write(p []byte) (n int, err error) { - time.Sleep(f.delay) - return f.Writer.Write(p) -} - -func (f *fakeFlusher) Flush() { - f.flushed.Store(true) -} diff --git a/server/nativeapi/native_api.go b/server/nativeapi/native_api.go index f3a58cf0..76ddbdf5 100644 --- a/server/nativeapi/native_api.go +++ b/server/nativeapi/native_api.go @@ -10,18 +10,16 @@ import ( "github.com/navidrome/navidrome/core" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/server" - "github.com/navidrome/navidrome/server/events" ) type Router struct { http.Handler - ds model.DataStore - broker events.Broker - share core.Share + ds model.DataStore + share core.Share } -func New(ds model.DataStore, broker events.Broker, share core.Share) *Router { - r := &Router{ds: ds, broker: broker, share: share} +func New(ds model.DataStore, share core.Share) *Router { + r := &Router{ds: ds, share: share} r.Handler = r.routes() return r } @@ -55,10 +53,6 @@ func (n *Router) routes() http.Handler { r.Get("/keepalive/*", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"response":"ok", "id":"keepalive"}`)) }) - - if conf.Server.DevActivityPanel { - r.Handle("/events", n.broker) - } }) return r diff --git a/server/server.go b/server/server.go index 2dad5a3f..153e08dd 100644 --- a/server/server.go +++ b/server/server.go @@ -19,21 +19,25 @@ import ( "github.com/navidrome/navidrome/core/auth" "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" + "github.com/navidrome/navidrome/server/events" "github.com/navidrome/navidrome/ui" . "github.com/navidrome/navidrome/utils/gg" ) type Server struct { - router *chi.Mux + router chi.Router ds model.DataStore appRoot string + broker events.Broker } -func New(ds model.DataStore) *Server { - s := &Server{ds: ds} +func New(ds model.DataStore, broker events.Broker) *Server { + s := &Server{ds: ds, broker: broker} auth.Init(s.ds) initialSetup(ds) s.initRoutes() + s.mountAuthenticationRoutes() + s.mountRootRedirector() checkFfmpegInstallation() checkExternalCredentials() return s @@ -131,24 +135,52 @@ func (s *Server) initRoutes() { r := chi.NewRouter() - r.Use(secureMiddleware()) - r.Use(corsHandler()) - r.Use(middleware.RequestID) - if conf.Server.ReverseProxyWhitelist == "" { - r.Use(middleware.RealIP) + middlewares := chi.Middlewares{ + secureMiddleware(), + corsHandler(), + middleware.RequestID, + } + if conf.Server.ReverseProxyWhitelist == "" { + middlewares = append(middlewares, middleware.RealIP) } - r.Use(middleware.Recoverer) - r.Use(compressMiddleware()) - r.Use(middleware.Heartbeat("/ping")) - r.Use(serverAddressMiddleware) - r.Use(clientUniqueIDMiddleware) - r.Use(loggerInjector) - r.Use(requestLogger) - r.Use(robotsTXT(ui.BuildAssets())) - r.Use(authHeaderMapper) - r.Use(jwtVerifier) - r.Route(path.Join(conf.Server.BasePath, "/auth"), func(r chi.Router) { + middlewares = append(middlewares, + middleware.Recoverer, + middleware.Heartbeat("/ping"), + robotsTXT(ui.BuildAssets()), + serverAddressMiddleware, + clientUniqueIDMiddleware, + ) + + // Mount the Native API /events endpoint with all middlewares, except the compress and request logger, + // adding the authentication middlewares + if conf.Server.DevActivityPanel { + r.Group(func(r chi.Router) { + r.Use(middlewares...) + r.Use(loggerInjector) + r.Use(authHeaderMapper) + r.Use(jwtVerifier) + r.Use(Authenticator(s.ds)) + r.Use(JWTRefresher) + r.Handle(path.Join(conf.Server.BasePath, consts.URLPathNativeAPI, "events"), s.broker) + }) + } + + // Configure the router with the default middlewares + r.Group(func(r chi.Router) { + r.Use(middlewares...) + r.Use(compressMiddleware()) + r.Use(loggerInjector) + r.Use(requestLogger) + r.Use(authHeaderMapper) + r.Use(jwtVerifier) + s.router = r + }) +} + +func (s *Server) mountAuthenticationRoutes() chi.Router { + r := s.router + return r.Route(path.Join(conf.Server.BasePath, "/auth"), func(r chi.Router) { if conf.Server.AuthRequestLimit > 0 { log.Info("Login rate limit set", "requestLimit", conf.Server.AuthRequestLimit, "windowLength", conf.Server.AuthWindowLength) @@ -162,7 +194,11 @@ func (s *Server) initRoutes() { } r.Post("/createAdmin", createAdmin(s.ds)) }) +} +// Serve UI app assets +func (s *Server) mountRootRedirector() { + r := s.router // Redirect root to UI URL r.Get("/*", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, s.appRoot+"/", http.StatusFound) @@ -170,11 +206,8 @@ func (s *Server) initRoutes() { r.Get(s.appRoot, func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, s.appRoot+"/", http.StatusFound) }) - - s.router = r } -// Serve UI app assets func (s *Server) frontendAssetsHandler() http.Handler { r := chi.NewRouter()