navidrome/persistence/sql_base_repository.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

255 lines
6.3 KiB
Go
Raw Normal View History

package persistence
2020-01-12 23:32:06 +01:00
import (
"context"
2022-10-01 00:54:25 +02:00
"errors"
"fmt"
"strings"
"time"
. "github.com/Masterminds/squirrel"
2022-07-30 18:43:48 +02:00
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/log"
2020-01-24 01:44:08 +01:00
"github.com/navidrome/navidrome/model"
2020-05-13 22:49:55 +02:00
"github.com/navidrome/navidrome/model/request"
2020-01-12 23:32:06 +01:00
)
type sqlRepository struct {
ctx context.Context
tableName string
2022-07-30 18:43:48 +02:00
ormer orm.QueryExecutor
sortMappings map[string]string
2020-01-12 23:32:06 +01:00
}
// invalidUserId is the placeholder ID that userId returns when the context
// carries no user.
const invalidUserId = "-1"
func userId(ctx context.Context) string {
2020-05-13 22:49:55 +02:00
if user, ok := request.UserFrom(ctx); !ok {
return invalidUserId
2020-05-13 22:49:55 +02:00
} else {
return user.ID
2020-01-13 00:36:19 +01:00
}
}
func loggedUser(ctx context.Context) *model.User {
2020-05-13 22:49:55 +02:00
if user, ok := request.UserFrom(ctx); !ok {
return &model.User{}
2020-05-13 22:49:55 +02:00
} else {
return &user
}
}
2020-01-31 22:03:30 +01:00
func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
sq := Select().From(r.tableName)
sq = r.applyOptions(sq, options...)
2020-01-31 20:52:06 +01:00
sq = r.applyFilters(sq, options...)
return sq
2020-01-12 23:32:06 +01:00
}
2020-01-31 22:03:30 +01:00
func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
if len(options) > 0 {
if options[0].Max > 0 {
sq = sq.Limit(uint64(options[0].Max))
}
if options[0].Offset > 0 {
2020-01-31 20:52:06 +01:00
sq = sq.Offset(uint64(options[0].Offset))
}
if options[0].Sort != "" {
2020-12-23 17:37:38 +01:00
sq = sq.OrderBy(r.buildSortOrder(options[0].Sort, options[0].Order))
}
2020-01-31 20:52:06 +01:00
}
return sq
}
2020-12-23 17:37:38 +01:00
func (r sqlRepository) buildSortOrder(sort, order string) string {
if mapping, ok := r.sortMappings[sort]; ok {
sort = mapping
}
sort = toSnakeCase(sort)
order = strings.ToLower(strings.TrimSpace(order))
var reverseOrder string
if order == "desc" {
reverseOrder = "asc"
} else {
order = "asc"
reverseOrder = "desc"
}
var newSort []string
parts := strings.FieldsFunc(sort, splitFunc(','))
2020-12-23 17:37:38 +01:00
for _, p := range parts {
f := strings.FieldsFunc(p, splitFunc(' '))
2020-12-23 17:37:38 +01:00
newField := []string{f[0]}
if len(f) == 1 {
newField = append(newField, order)
} else {
if f[1] == "asc" {
newField = append(newField, order)
} else {
newField = append(newField, reverseOrder)
}
}
newSort = append(newSort, strings.Join(newField, " "))
}
return strings.Join(newSort, ", ")
}
// splitFunc returns a stateful predicate for strings.FieldsFunc that reports
// true for delimiter only while outside parentheses. The returned closure
// tracks the current nesting depth, so it must be fed the runes of a single
// string in order and must not be reused for another string.
func splitFunc(delimiter rune) func(c rune) bool {
	depth := 0
	return func(c rune) bool {
		switch {
		case c == '(':
			// Entering a (possibly nested) parenthesized group.
			depth++
			return false
		case depth > 0:
			// Inside parentheses nothing splits; only track the closing paren.
			if c == ')' {
				depth--
			}
			return false
		default:
			return c == delimiter
		}
	}
}
2020-01-31 22:03:30 +01:00
func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
2020-01-31 20:52:06 +01:00
if len(options) > 0 && options[0].Filters != nil {
sq = sq.Where(options[0].Filters)
}
return sq
}
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
2020-02-01 21:06:49 +01:00
query, args, err := sq.ToSql()
if err != nil {
return 0, err
}
start := time.Now()
2020-02-28 17:02:38 +01:00
var c int64
res, err := r.ormer.Raw(query, args...).Exec()
2020-02-28 17:02:38 +01:00
if res != nil {
c, _ = res.RowsAffected()
}
r.logSQL(query, args, err, c, start)
if err != nil {
if err.Error() != "LastInsertId is not supported by this driver" {
return 0, err
}
}
return res.RowsAffected()
2020-01-12 23:32:06 +01:00
}
// Note: Due to a bug in the QueryRow method, this function does not map any embedded structs (ex: annotations)
// In this case, use the queryAll method and get the first item of the returned list
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
2020-02-01 21:06:49 +01:00
query, args, err := sq.ToSql()
if err != nil {
2020-01-15 01:20:47 +01:00
return err
}
start := time.Now()
err = r.ormer.Raw(query, args...).QueryRow(response)
2022-10-01 00:54:25 +02:00
if errors.Is(err, orm.ErrNoRows) {
2021-06-24 06:01:05 +02:00
r.logSQL(query, args, nil, 0, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, 1, start)
return err
2020-01-15 01:20:47 +01:00
}
func (r sqlRepository) queryAll(sq Sqlizer, response interface{}) error {
2020-02-01 21:06:49 +01:00
query, args, err := sq.ToSql()
if err != nil {
2020-01-12 23:32:06 +01:00
return err
}
start := time.Now()
2020-02-01 21:06:49 +01:00
c, err := r.ormer.Raw(query, args...).QueryRows(response)
2022-10-01 00:54:25 +02:00
if errors.Is(err, orm.ErrNoRows) {
r.logSQL(query, args, nil, c, start)
return model.ErrNotFound
}
r.logSQL(query, args, nil, c, start)
return err
2020-01-12 23:32:06 +01:00
}
func (r sqlRepository) exists(existsQuery SelectBuilder) (bool, error) {
existsQuery = existsQuery.Columns("count(*) as exist").From(r.tableName)
var res struct{ Exist int64 }
err := r.queryOne(existsQuery, &res)
return res.Exist > 0, err
2020-01-13 06:04:11 +01:00
}
func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) {
2021-07-17 02:20:33 +02:00
countQuery = countQuery.Columns("count(distinct " + r.tableName + ".id) as count").From(r.tableName)
2020-01-31 20:52:06 +01:00
countQuery = r.applyFilters(countQuery, options...)
var res struct{ Count int64 }
err := r.queryOne(countQuery, &res)
return res.Count, err
}
2020-01-13 06:04:11 +01:00
2021-06-03 00:40:29 +02:00
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
values, _ := toSqlArgs(m)
2021-06-03 00:40:29 +02:00
// If there's an ID, try to update first
if id != "" {
2021-06-03 00:40:29 +02:00
updateValues := map[string]interface{}{}
// This is a map of the columns that need to be updated, if specified
c2upd := map[string]struct{}{}
for _, c := range colsToUpdate {
c2upd[toSnakeCase(c)] = struct{}{}
}
2021-06-03 00:40:29 +02:00
for k, v := range values {
if _, found := c2upd[k]; len(c2upd) == 0 || found {
2021-06-03 00:40:29 +02:00
updateValues[k] = v
}
}
2021-06-03 00:40:29 +02:00
delete(updateValues, "created_at")
update := Update(r.tableName).Where(Eq{"id": id}).SetMap(updateValues)
count, err := r.executeSQL(update)
if err != nil {
return "", err
}
if count > 0 {
2021-06-03 00:40:29 +02:00
return id, nil
}
}
// If it does not have an ID OR the ID was not found (when it is a new record with predefined id)
if id == "" {
id = uuid.NewString()
values["id"] = id
}
insert := Insert(r.tableName).SetMap(values)
_, err = r.executeSQL(insert)
return id, err
}
func (r sqlRepository) delete(cond Sqlizer) error {
del := Delete(r.tableName).Where(cond)
_, err := r.executeSQL(del)
2022-10-01 00:54:25 +02:00
if errors.Is(err, orm.ErrNoRows) {
return model.ErrNotFound
2020-01-13 06:04:11 +01:00
}
return err
2020-01-13 06:04:11 +01:00
}
func (r sqlRepository) logSQL(sql string, args []interface{}, err error, rowsAffected int64, start time.Time) {
2020-02-24 03:41:10 +01:00
elapsed := time.Since(start)
2020-02-01 21:06:49 +01:00
var fmtArgs []string
for i := range args {
var f string
switch a := args[i].(type) {
case string:
f = `'` + a + `'`
default:
f = fmt.Sprintf("%v", a)
}
2020-02-01 21:06:49 +01:00
fmtArgs = append(fmtArgs, f)
}
if err != nil {
2020-02-24 03:41:10 +01:00
log.Error(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
2020-02-01 21:06:49 +01:00
} else {
2020-02-24 03:41:10 +01:00
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
}
2020-01-16 21:56:24 +01:00
}