Files
repo/repo.go

150 lines
3.2 KiB
Go
Raw Normal View History

2025-05-03 21:35:32 +03:00
package repo
import (
"context"
"database/sql"
"encoding/json"
"errors"
"time"
)
// Sentinel errors reported by this package. Match them with errors.Is:
// most of them are joined with the underlying cause via errors.Join.
var (
// ErrInitRepo carries the message "error creating table".
// NOTE(review): presumably returned by initDB, which is not visible here — confirm.
ErrInitRepo = errors.New("error creating table")
// ErrMarshal is joined with the encoding/json error when a payload cannot
// be encoded (marshal) or decoded (unmarshal).
ErrMarshal = errors.New("error marshal")
// ErrExecQuery is joined with the database error when executing a
// statement, or reading its affected-row count, fails.
ErrExecQuery = errors.New("error executing DB query")
// ErrNotFound is returned when Read finds no row for the id, or when an
// executed statement affects zero rows.
ErrNotFound = errors.New("not found")
)
// Repo is a minimal CRUD store for values of type T addressed by a string id.
type Repo[T any] interface {
	// Create saves object to the repository. If id already exists
	// in the database it is an error.
	Create(ctx context.Context, id string, v *T) error
	// Read returns object with specified id or ErrNotFound.
	Read(ctx context.Context, id string) (*T, error)
	// Update updates object with id.
	Update(ctx context.Context, id string, v *T) error
	// Delete deletes object from the database.
	Delete(ctx context.Context, id string) error
}
// OpenOrCreate takes a *sql.DB and a table name, asks initDB to create the
// named table if it does not exist yet, and returns a Repo[T] that keeps
// using the same DB handle for all of its queries.
//
// If initDB fails, its error is returned unchanged and no repository is built.
func OpenOrCreate[T any](ctx context.Context, db *sql.DB, tablename string) (Repo[T], error) {
	err := initDB(ctx, db, tablename)
	if err != nil {
		return nil, err
	}

	store := &repo[T]{table: tablename, db: db}
	return store, nil
}
// repo is the Repo implementation backed by a *sql.DB. Each object is stored
// as a JSON-encoded payload in one row of the configured table.
type repo[T any] struct {
// db is the handle every query of this repository runs against.
db *sql.DB
// table is the table name; it is concatenated directly into the SQL text
// of each query.
table string
}
// Create inserts a row for id whose payload is v encoded as JSON. The same
// timestamp, taken at the start of the call, is written to both created_at
// and updated_at.
//
// It returns the error from marshal if v cannot be encoded, otherwise
// whatever execContext reports for the INSERT.
func (r *repo[T]) Create(ctx context.Context, id string, v *T) error {
	stamp := time.Now()

	payload, err := marshal(v)
	if err != nil {
		return err
	}

	stmt := "INSERT INTO " + r.table + " (id, created_at, updated_at, payload) VALUES ($1, $2, $3, $4)"
	return r.execContext(ctx, stmt, id, stamp, stamp, string(payload))
}
func (r *repo[T]) Read(ctx context.Context, id string) (*T, error) {
2025-08-30 10:30:53 +03:00
query := "SELECT payload FROM " + r.table + " WHERE id = $1"
2025-05-03 21:35:32 +03:00
row := r.db.QueryRowContext(ctx, query, id)
var s string
err := row.Scan(&s)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, errors.Join(ErrNotFound, err)
} else if err != nil {
return nil, err
}
v, err := unmarshal[T]([]byte(s))
if err != nil {
return nil, err
}
return v, nil
}
func (r *repo[T]) Update(ctx context.Context, id string, v *T) error {
now := time.Now()
b, err := marshal(v)
if err != nil {
return err
}
2025-08-30 10:30:53 +03:00
query := "UPDATE " + r.table + " SET updated_at = $1, payload = $2 WHERE id = $3"
2025-05-03 21:35:32 +03:00
if err := r.execContext(ctx, query, now, string(b), id); err != nil {
return err
}
return nil
}
2025-08-30 10:30:53 +03:00
// Delete deletes record with cpecified id.
2025-05-03 21:35:32 +03:00
func (r *repo[T]) Delete(ctx context.Context, id string) error {
query := "DELETE FROM " + r.table + " WHERE id = $1"
if err := r.execContext(ctx, query, id); err != nil {
return err
}
return nil
}
func (r *repo[T]) execContext(ctx context.Context, query string, args ...any) error {
res, err := r.db.ExecContext(ctx, query, args...)
if err != nil {
return errors.Join(ErrExecQuery, err)
}
2025-08-30 10:30:53 +03:00
2025-05-03 21:35:32 +03:00
affected, err := res.RowsAffected()
if err != nil {
return errors.Join(ErrExecQuery, err)
}
if affected == 0 {
return ErrNotFound
}
return nil
}
// marshal encodes v as JSON. On failure it returns a nil slice and
// ErrMarshal joined with the encoder's error.
func marshal(v any) ([]byte, error) {
	encoded, encErr := json.Marshal(v)
	if encErr == nil {
		return encoded, nil
	}
	return nil, errors.Join(ErrMarshal, encErr)
}
// unmarshal decodes the JSON document b into a new value of type T and
// returns a pointer to it. On failure it returns nil and ErrMarshal joined
// with the decoder's error.
func unmarshal[T any](b []byte) (*T, error) {
	var out T
	if decErr := json.Unmarshal(b, &out); decErr != nil {
		return nil, errors.Join(ErrMarshal, decErr)
	}
	return &out, nil
}