2020-01-15 00:23:29 +01:00
|
|
|
package persistence
|
2020-01-12 23:32:06 +01:00
|
|
|
|
|
|
|
import (
|
2020-01-28 14:22:17 +01:00
|
|
|
"context"
|
2023-12-09 19:52:17 +01:00
|
|
|
"database/sql"
|
2022-10-01 00:54:25 +02:00
|
|
|
"errors"
|
2020-01-28 14:22:17 +01:00
|
|
|
"fmt"
|
|
|
|
"strings"
|
2020-02-05 14:47:32 +01:00
|
|
|
"time"
|
2020-01-28 14:22:17 +01:00
|
|
|
|
|
|
|
. "github.com/Masterminds/squirrel"
|
2020-01-31 15:53:19 +01:00
|
|
|
"github.com/google/uuid"
|
2023-11-27 19:06:23 +01:00
|
|
|
"github.com/navidrome/navidrome/conf"
|
2020-01-28 14:22:17 +01:00
|
|
|
"github.com/navidrome/navidrome/log"
|
2020-01-24 01:44:08 +01:00
|
|
|
"github.com/navidrome/navidrome/model"
|
2020-05-13 22:49:55 +02:00
|
|
|
"github.com/navidrome/navidrome/model/request"
|
2023-12-09 19:52:17 +01:00
|
|
|
"github.com/pocketbase/dbx"
|
2020-01-12 23:32:06 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
type sqlRepository struct {
|
2020-03-22 01:00:46 +01:00
|
|
|
ctx context.Context
|
|
|
|
tableName string
|
2023-12-09 19:52:17 +01:00
|
|
|
db dbx.Builder
|
2020-03-22 01:00:46 +01:00
|
|
|
sortMappings map[string]string
|
2020-01-12 23:32:06 +01:00
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
// invalidUserId is what userId reports when the context carries no user.
const (
	invalidUserId = "-1"
)
|
|
|
|
|
|
|
|
// userId returns the ID of the user stored in ctx, or invalidUserId when the
// context does not carry a user.
func userId(ctx context.Context) string {
	user, found := request.UserFrom(ctx)
	if !found {
		return invalidUserId
	}
	return user.ID
}
|
|
|
|
|
2020-02-06 04:22:44 +01:00
|
|
|
// loggedUser returns a pointer to a copy of the user stored in ctx. When the
// context carries no user it returns a pointer to an empty model.User, so the
// result is never nil.
func loggedUser(ctx context.Context) *model.User {
	user, found := request.UserFrom(ctx)
	if !found {
		return &model.User{}
	}
	return &user
}
|
|
|
|
|
2020-01-31 22:03:30 +01:00
|
|
|
func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
|
2020-01-28 14:22:17 +01:00
|
|
|
sq := Select().From(r.tableName)
|
|
|
|
sq = r.applyOptions(sq, options...)
|
2020-01-31 20:52:06 +01:00
|
|
|
sq = r.applyFilters(sq, options...)
|
2020-01-28 14:22:17 +01:00
|
|
|
return sq
|
2020-01-12 23:32:06 +01:00
|
|
|
}
|
|
|
|
|
2020-01-31 22:03:30 +01:00
|
|
|
func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
|
2020-01-21 14:49:43 +01:00
|
|
|
if len(options) > 0 {
|
|
|
|
if options[0].Max > 0 {
|
|
|
|
sq = sq.Limit(uint64(options[0].Max))
|
|
|
|
}
|
|
|
|
if options[0].Offset > 0 {
|
2020-01-31 20:52:06 +01:00
|
|
|
sq = sq.Offset(uint64(options[0].Offset))
|
2020-01-21 14:49:43 +01:00
|
|
|
}
|
|
|
|
if options[0].Sort != "" {
|
2020-12-23 17:37:38 +01:00
|
|
|
sq = sq.OrderBy(r.buildSortOrder(options[0].Sort, options[0].Order))
|
2020-01-21 14:49:43 +01:00
|
|
|
}
|
2020-01-31 20:52:06 +01:00
|
|
|
}
|
|
|
|
return sq
|
|
|
|
}
|
|
|
|
|
2020-12-23 17:37:38 +01:00
|
|
|
func (r sqlRepository) buildSortOrder(sort, order string) string {
|
|
|
|
if mapping, ok := r.sortMappings[sort]; ok {
|
|
|
|
sort = mapping
|
|
|
|
}
|
|
|
|
|
|
|
|
sort = toSnakeCase(sort)
|
|
|
|
order = strings.ToLower(strings.TrimSpace(order))
|
|
|
|
var reverseOrder string
|
|
|
|
if order == "desc" {
|
|
|
|
reverseOrder = "asc"
|
|
|
|
} else {
|
|
|
|
order = "asc"
|
|
|
|
reverseOrder = "desc"
|
|
|
|
}
|
|
|
|
|
|
|
|
var newSort []string
|
2021-05-28 23:35:32 +02:00
|
|
|
parts := strings.FieldsFunc(sort, splitFunc(','))
|
2020-12-23 17:37:38 +01:00
|
|
|
for _, p := range parts {
|
2021-05-28 23:35:32 +02:00
|
|
|
f := strings.FieldsFunc(p, splitFunc(' '))
|
2020-12-23 17:37:38 +01:00
|
|
|
newField := []string{f[0]}
|
|
|
|
if len(f) == 1 {
|
|
|
|
newField = append(newField, order)
|
|
|
|
} else {
|
|
|
|
if f[1] == "asc" {
|
|
|
|
newField = append(newField, order)
|
|
|
|
} else {
|
|
|
|
newField = append(newField, reverseOrder)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
newSort = append(newSort, strings.Join(newField, " "))
|
|
|
|
}
|
|
|
|
return strings.Join(newSort, ", ")
|
|
|
|
}
|
|
|
|
|
2021-05-28 23:35:32 +02:00
|
|
|
// splitFunc returns a predicate for strings.FieldsFunc. The predicate reports
// true for delimiter only when it appears outside parentheses; nested
// parentheses are supported.
//
// The predicate keeps the parenthesis depth between calls. It therefore
// assumes runes are fed left to right, each exactly once, and a new predicate
// must be created for every string being split. NOTE(review):
// strings.FieldsFunc does not document its call order, so this relies on the
// current single-pass implementation.
func splitFunc(delimiter rune) func(c rune) bool {
	depth := 0
	return func(c rune) bool {
		switch {
		case c == '(':
			depth++
			return false
		case depth > 0:
			if c == ')' {
				depth--
			}
			return false
		default:
			return c == delimiter
		}
	}
}
|
|
|
|
|
2020-01-31 22:03:30 +01:00
|
|
|
func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
|
2020-01-31 20:52:06 +01:00
|
|
|
if len(options) > 0 && options[0].Filters != nil {
|
|
|
|
sq = sq.Where(options[0].Filters)
|
2020-01-21 14:49:43 +01:00
|
|
|
}
|
|
|
|
return sq
|
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
// executeSQL runs the statement built by sq (no rows are read back) and
// returns the number of rows it affected. Every execution is logged through
// logSQL, including failed ones.
//
// An execution error whose text is exactly "LastInsertId is not supported by
// this driver" is not treated as a failure, but only when the driver still
// returned a result to read the row count from. Without a result the error
// is returned, instead of dereferencing a nil result.
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
	const unsupportedLastInsertID = "LastInsertId is not supported by this driver"

	query, args, err := r.toSQL(sq)
	if err != nil {
		return 0, err
	}
	start := time.Now()
	res, err := r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Execute()

	// Read the row count once; it is used both for logging and as the result.
	var rowsAffected int64
	var rowsErr error
	if res != nil {
		rowsAffected, rowsErr = res.RowsAffected()
	}
	r.logSQL(query, args, err, rowsAffected, start)

	if err != nil && (res == nil || err.Error() != unsupportedLastInsertID) {
		return 0, err
	}
	return rowsAffected, rowsErr
}
|
|
|
|
|
2023-12-09 19:52:17 +01:00
|
|
|
func (r sqlRepository) toSQL(sq Sqlizer) (string, dbx.Params, error) {
|
|
|
|
query, args, err := sq.ToSql()
|
|
|
|
if err != nil {
|
|
|
|
return "", nil, err
|
|
|
|
}
|
|
|
|
// Replace query placeholders with named params
|
|
|
|
params := dbx.Params{}
|
|
|
|
for i, arg := range args {
|
|
|
|
p := fmt.Sprintf("p%d", i)
|
|
|
|
query = strings.Replace(query, "?", "{:"+p+"}", 1)
|
|
|
|
params[p] = arg
|
|
|
|
}
|
|
|
|
return query, params, nil
|
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
|
2023-12-09 19:52:17 +01:00
|
|
|
query, args, err := r.toSQL(sq)
|
2020-01-28 14:22:17 +01:00
|
|
|
if err != nil {
|
2020-01-15 01:20:47 +01:00
|
|
|
return err
|
|
|
|
}
|
2020-02-05 14:47:32 +01:00
|
|
|
start := time.Now()
|
2023-12-14 23:13:09 +01:00
|
|
|
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).One(response)
|
2023-12-09 19:52:17 +01:00
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
2021-06-24 06:01:05 +02:00
|
|
|
r.logSQL(query, args, nil, 0, start)
|
2020-01-28 14:22:17 +01:00
|
|
|
return model.ErrNotFound
|
|
|
|
}
|
2020-02-05 14:47:32 +01:00
|
|
|
r.logSQL(query, args, err, 1, start)
|
2020-01-28 14:22:17 +01:00
|
|
|
return err
|
2020-01-15 01:20:47 +01:00
|
|
|
}
|
|
|
|
|
2023-11-27 19:06:23 +01:00
|
|
|
func (r sqlRepository) queryAll(sq SelectBuilder, response interface{}, options ...model.QueryOptions) error {
|
|
|
|
if len(options) > 0 && options[0].Offset > 0 {
|
|
|
|
sq = r.optimizePagination(sq, options[0])
|
|
|
|
}
|
2023-12-09 19:52:17 +01:00
|
|
|
query, args, err := r.toSQL(sq)
|
2020-01-13 21:28:55 +01:00
|
|
|
if err != nil {
|
2020-01-12 23:32:06 +01:00
|
|
|
return err
|
2020-01-13 21:28:55 +01:00
|
|
|
}
|
2020-02-05 14:47:32 +01:00
|
|
|
start := time.Now()
|
2023-12-14 23:13:09 +01:00
|
|
|
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).All(response)
|
2023-12-09 19:52:17 +01:00
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
|
|
r.logSQL(query, args, nil, -1, start)
|
2020-01-28 14:22:17 +01:00
|
|
|
return model.ErrNotFound
|
2020-01-13 21:28:55 +01:00
|
|
|
}
|
2023-12-09 19:52:17 +01:00
|
|
|
r.logSQL(query, args, err, -1, start)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// queryAllSlice is a helper function to query a single column and return the result in a slice
|
|
|
|
func (r sqlRepository) queryAllSlice(sq SelectBuilder, response interface{}) error {
|
|
|
|
query, args, err := r.toSQL(sq)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
start := time.Now()
|
2023-12-14 23:13:09 +01:00
|
|
|
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Column(response)
|
2023-12-09 19:52:17 +01:00
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
|
|
r.logSQL(query, args, nil, -1, start)
|
|
|
|
return model.ErrNotFound
|
|
|
|
}
|
|
|
|
r.logSQL(query, args, err, -1, start)
|
2020-01-13 21:28:55 +01:00
|
|
|
return err
|
2020-01-12 23:32:06 +01:00
|
|
|
}
|
|
|
|
|
2023-11-27 19:06:23 +01:00
|
|
|
// optimizePagination uses a less inefficient pagination, by not using OFFSET.
|
|
|
|
// See https://gist.github.com/ssokolow/262503
|
|
|
|
func (r sqlRepository) optimizePagination(sq SelectBuilder, options model.QueryOptions) SelectBuilder {
|
|
|
|
if options.Offset > conf.Server.DevOffsetOptimize {
|
|
|
|
sq = sq.RemoveOffset()
|
|
|
|
oidSq := sq.RemoveColumns().Columns(r.tableName + ".oid")
|
|
|
|
oidSq = oidSq.Limit(uint64(options.Offset))
|
|
|
|
oidSql, args, _ := oidSq.ToSql()
|
|
|
|
sq = sq.Where(r.tableName+".oid not in ("+oidSql+")", args...)
|
|
|
|
}
|
|
|
|
return sq
|
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
func (r sqlRepository) exists(existsQuery SelectBuilder) (bool, error) {
|
2020-02-03 03:29:27 +01:00
|
|
|
existsQuery = existsQuery.Columns("count(*) as exist").From(r.tableName)
|
|
|
|
var res struct{ Exist int64 }
|
|
|
|
err := r.queryOne(existsQuery, &res)
|
|
|
|
return res.Exist > 0, err
|
2020-01-13 06:04:11 +01:00
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) {
|
2023-11-26 04:29:05 +01:00
|
|
|
countQuery = countQuery.
|
2023-11-26 05:08:20 +01:00
|
|
|
RemoveColumns().Columns("count(distinct " + r.tableName + ".id) as count").
|
2024-01-21 04:02:05 +01:00
|
|
|
RemoveOffset().RemoveLimit().
|
2023-11-26 05:08:20 +01:00
|
|
|
From(r.tableName)
|
2020-01-31 20:52:06 +01:00
|
|
|
countQuery = r.applyFilters(countQuery, options...)
|
2020-01-28 14:22:17 +01:00
|
|
|
var res struct{ Count int64 }
|
2020-02-01 20:48:22 +01:00
|
|
|
err := r.queryOne(countQuery, &res)
|
|
|
|
return res.Count, err
|
2020-01-28 14:22:17 +01:00
|
|
|
}
|
2020-01-13 06:04:11 +01:00
|
|
|
|
2021-06-03 00:40:29 +02:00
|
|
|
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
|
2023-12-09 19:52:17 +01:00
|
|
|
values, _ := toSQLArgs(m)
|
2021-06-03 00:40:29 +02:00
|
|
|
// If there's an ID, try to update first
|
2020-01-31 15:53:19 +01:00
|
|
|
if id != "" {
|
2021-06-03 00:40:29 +02:00
|
|
|
updateValues := map[string]interface{}{}
|
2021-11-01 18:55:47 +01:00
|
|
|
|
|
|
|
// This is a map of the columns that need to be updated, if specified
|
|
|
|
c2upd := map[string]struct{}{}
|
|
|
|
for _, c := range colsToUpdate {
|
|
|
|
c2upd[toSnakeCase(c)] = struct{}{}
|
|
|
|
}
|
2021-06-03 00:40:29 +02:00
|
|
|
for k, v := range values {
|
2021-11-01 18:55:47 +01:00
|
|
|
if _, found := c2upd[k]; len(c2upd) == 0 || found {
|
2021-06-03 00:40:29 +02:00
|
|
|
updateValues[k] = v
|
|
|
|
}
|
|
|
|
}
|
2021-11-01 18:55:47 +01:00
|
|
|
|
2021-06-03 00:40:29 +02:00
|
|
|
delete(updateValues, "created_at")
|
|
|
|
update := Update(r.tableName).Where(Eq{"id": id}).SetMap(updateValues)
|
2020-01-31 15:53:19 +01:00
|
|
|
count, err := r.executeSQL(update)
|
|
|
|
if err != nil {
|
2020-01-31 21:35:06 +01:00
|
|
|
return "", err
|
2020-01-31 15:53:19 +01:00
|
|
|
}
|
|
|
|
if count > 0 {
|
2021-06-03 00:40:29 +02:00
|
|
|
return id, nil
|
2020-01-31 15:53:19 +01:00
|
|
|
}
|
|
|
|
}
|
2021-11-01 18:55:47 +01:00
|
|
|
// If it does not have an ID OR the ID was not found (when it is a new record with predefined id)
|
2020-01-31 15:53:19 +01:00
|
|
|
if id == "" {
|
2021-02-01 07:22:31 +01:00
|
|
|
id = uuid.NewString()
|
2020-01-31 21:35:06 +01:00
|
|
|
values["id"] = id
|
2020-01-31 15:53:19 +01:00
|
|
|
}
|
|
|
|
insert := Insert(r.tableName).SetMap(values)
|
2020-01-31 21:35:06 +01:00
|
|
|
_, err = r.executeSQL(insert)
|
|
|
|
return id, err
|
2020-01-31 15:53:19 +01:00
|
|
|
}
|
|
|
|
|
2020-01-28 14:22:17 +01:00
|
|
|
func (r sqlRepository) delete(cond Sqlizer) error {
|
|
|
|
del := Delete(r.tableName).Where(cond)
|
|
|
|
_, err := r.executeSQL(del)
|
2023-12-09 19:52:17 +01:00
|
|
|
if errors.Is(err, sql.ErrNoRows) {
|
2020-01-28 14:22:17 +01:00
|
|
|
return model.ErrNotFound
|
2020-01-13 06:04:11 +01:00
|
|
|
}
|
2020-01-28 14:22:17 +01:00
|
|
|
return err
|
2020-01-13 06:04:11 +01:00
|
|
|
}
|
|
|
|
|
2023-12-09 19:52:17 +01:00
|
|
|
func (r sqlRepository) logSQL(sql string, args dbx.Params, err error, rowsAffected int64, start time.Time) {
|
2020-02-24 03:41:10 +01:00
|
|
|
elapsed := time.Since(start)
|
2023-12-09 19:52:17 +01:00
|
|
|
//var fmtArgs []string
|
|
|
|
//for name, val := range args {
|
|
|
|
// var f string
|
|
|
|
// switch a := args[val].(type) {
|
|
|
|
// case string:
|
|
|
|
// f = `'` + a + `'`
|
|
|
|
// default:
|
|
|
|
// f = fmt.Sprintf("%v", a)
|
|
|
|
// }
|
|
|
|
// fmtArgs = append(fmtArgs, f)
|
|
|
|
//}
|
2020-02-01 21:06:49 +01:00
|
|
|
if err != nil {
|
2023-12-09 19:52:17 +01:00
|
|
|
log.Error(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
|
2020-02-01 21:06:49 +01:00
|
|
|
} else {
|
2023-12-09 19:52:17 +01:00
|
|
|
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
|
2020-01-28 14:22:17 +01:00
|
|
|
}
|
2020-01-16 21:56:24 +01:00
|
|
|
}
|