tercul-backend/internal/data/sql/base_repository.go
google-labs-jules[bot] c2e9a118e2 feat(testing): Increase test coverage and fix authz bugs
This commit significantly increases the test coverage across the application and fixes several underlying bugs that were discovered while writing the new tests.

The key changes include:

- **New Tests:** Added extensive integration and unit tests for GraphQL resolvers, application services, and data repositories, substantially increasing the test coverage for packages like `graphql`, `user`, `translation`, and `analytics`.

- **Authorization Bug Fixes:**
  - Fixed a critical bug where a user creating a `Work` was not correctly associated as its author, causing subsequent permission failures.
  - Corrected the authorization logic in `authz.Service` to properly check for entity ownership by non-admin users.

- **Test Refactoring:**
  - Refactored numerous test suites to use `testify/mock` instead of manual mocks, improving test clarity and maintainability.
  - Isolated integration tests by creating a fresh admin user and token for each test run, eliminating test pollution.
  - Centralized domain errors into `internal/domain/errors.go` and updated repositories to use them, making error handling more consistent.

- **Code Quality Improvements:**
  - Replaced manual mock implementations with `testify/mock` for better consistency.
  - Cleaned up redundant and outdated test files.

These changes stabilize the test suite, improve the overall quality of the codebase, and move the project closer to the goal of 80% test coverage.
2025-10-09 07:03:45 +00:00

623 lines
17 KiB
Go

package sql
import (
"context"
"errors"
"fmt"
"tercul/internal/domain"
"tercul/internal/platform/config"
"tercul/internal/platform/log"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
)
// BaseRepositoryImpl provides a default implementation of BaseRepository using GORM.
// It is generic over the entity type T: queries are issued against the model
// GORM derives from T (e.g. Model(new(T)), Find(&[]T{})).
type BaseRepositoryImpl[T any] struct {
	// db is the GORM handle used for every non-transactional query.
	db *gorm.DB
	// tracer opens one OpenTelemetry span per public repository method.
	tracer trace.Tracer
	// cfg supplies defaults such as PageSize (pagination) and BatchSize (sync).
	cfg *config.Config
}
// NewBaseRepositoryImpl builds a repository for entity type T.
//
// db is used for all queries, cfg provides paging/batching defaults, and the
// repository traces its operations with the OpenTelemetry tracer named
// "base.repository".
func NewBaseRepositoryImpl[T any](db *gorm.DB, cfg *config.Config) *BaseRepositoryImpl[T] {
	repo := new(BaseRepositoryImpl[T])
	repo.db = db
	repo.cfg = cfg
	repo.tracer = otel.Tracer("base.repository")
	return repo
}
// validateContext reports domain.ErrValidation for a nil context and nil
// otherwise. It must run before the context is handed to the tracer.
func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error {
	if ctx != nil {
		return nil
	}
	return domain.ErrValidation
}
// validateID reports domain.ErrValidation for the zero ID and nil for any
// other value.
func (r *BaseRepositoryImpl[T]) validateID(id uint) error {
	switch id {
	case 0:
		return domain.ErrValidation
	default:
		return nil
	}
}
// validateEntity reports domain.ErrValidation when entity is a nil pointer
// and nil otherwise.
func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error {
	if entity != nil {
		return nil
	}
	return domain.ErrValidation
}
// validatePagination normalizes pagination parameters.
//
// A page below 1 becomes page 1. A pageSize below 1 falls back to
// cfg.PageSize, or to 20 when no positive configured value is available
// (including when the repository was constructed with a nil config).
// A pageSize above 1000 is rejected with an error.
//
// Returns the normalized page and pageSize, or (0, 0, err) when the size is
// too large.
func (r *BaseRepositoryImpl[T]) validatePagination(page, pageSize int) (int, int, error) {
	const (
		defaultPageSize = 20
		maxPageSize     = 1000
	)
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = defaultPageSize
		// Guard the config pointer: dereferencing a nil cfg here used to panic
		// for any caller that relied on the default page size.
		if r.cfg != nil && r.cfg.PageSize > 0 {
			pageSize = r.cfg.PageSize
		}
	}
	if pageSize > maxPageSize {
		return 0, 0, fmt.Errorf("page size too large: %d (max: 1000)", pageSize)
	}
	return page, pageSize, nil
}
// buildQuery layers options onto query and returns the extended chain.
//
// A nil options leaves query untouched. Otherwise the options are applied in
// this order: preloads, where conditions, ordering, then limit and offset
// (each only when positive). Where is ranged over; if it is a map (presumably
// field -> value — confirm in domain.QueryOptions) the order in which the
// conditions are appended is unspecified.
func (r *BaseRepositoryImpl[T]) buildQuery(query *gorm.DB, options *domain.QueryOptions) *gorm.DB {
	if options == nil {
		return query
	}

	for _, relation := range options.Preloads {
		query = query.Preload(relation)
	}
	for condition, arg := range options.Where {
		query = query.Where(condition, arg)
	}
	if order := options.OrderBy; order != "" {
		query = query.Order(order)
	}
	if limit := options.Limit; limit > 0 {
		query = query.Limit(limit)
	}
	if offset := options.Offset; offset > 0 {
		query = query.Offset(offset)
	}
	return query
}
// Create persists a new entity with GORM's Create.
//
// ctx is validated before the tracing span is opened and entity after it;
// either validation error is returned unchanged. A database failure is logged
// and returned wrapped with "database operation failed". On success the
// elapsed query time is logged at debug level.
func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "Create")
	defer span.End()

	if err := r.validateEntity(entity); err != nil {
		return err
	}

	began := time.Now()
	dbErr := r.db.WithContext(ctx).Create(entity).Error
	elapsed := time.Since(began)

	if dbErr == nil {
		log.Debug(fmt.Sprintf("Entity created successfully in %s", elapsed))
		return nil
	}
	log.Error(dbErr, "Failed to create entity")
	return fmt.Errorf("database operation failed: %w", dbErr)
}
// CreateInTx persists a new entity using the caller-supplied transaction tx.
//
// Validation runs in this order: ctx (before the span is opened), entity, then
// tx. A nil tx yields domain.ErrInvalidOperation. A database failure is logged
// and returned wrapped with "database operation failed". The caller remains
// responsible for committing or rolling back tx.
func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "CreateInTx")
	defer span.End()

	switch err := r.validateEntity(entity); {
	case err != nil:
		return err
	case tx == nil:
		return domain.ErrInvalidOperation
	}

	began := time.Now()
	dbErr := tx.WithContext(ctx).Create(entity).Error
	elapsed := time.Since(began)

	if dbErr == nil {
		log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", elapsed))
		return nil
	}
	log.Error(dbErr, "Failed to create entity in transaction")
	return fmt.Errorf("database operation failed: %w", dbErr)
}
// GetByID loads the entity whose primary key is id via GORM's First.
//
// gorm.ErrRecordNotFound is translated to domain.ErrEntityNotFound; any other
// database error is logged and wrapped with "database operation failed".
// Validation errors for a nil ctx or a zero id are returned unchanged.
func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "GetByID")
	defer span.End()

	if err := r.validateID(id); err != nil {
		return nil, err
	}

	var found T
	began := time.Now()
	queryErr := r.db.WithContext(ctx).First(&found, id).Error
	elapsed := time.Since(began)

	switch {
	case queryErr == nil:
		log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, elapsed))
		return &found, nil
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, elapsed))
		return nil, domain.ErrEntityNotFound
	default:
		log.Error(queryErr, fmt.Sprintf("Failed to get entity by ID %d", id))
		return nil, fmt.Errorf("database operation failed: %w", queryErr)
	}
}
// GetByIDWithOptions loads the entity with primary key id after applying
// options (preloads, where, order, limit, offset) through buildQuery.
//
// gorm.ErrRecordNotFound is translated to domain.ErrEntityNotFound; any other
// database error is logged and wrapped with "database operation failed".
// Validation errors for a nil ctx or a zero id are returned unchanged.
func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "GetByIDWithOptions")
	defer span.End()

	if err := r.validateID(id); err != nil {
		return nil, err
	}

	var found T
	began := time.Now()
	queryErr := r.buildQuery(r.db.WithContext(ctx), options).First(&found, id).Error
	elapsed := time.Since(began)

	switch {
	case queryErr == nil:
		log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, elapsed))
		return &found, nil
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, elapsed))
		return nil, domain.ErrEntityNotFound
	default:
		log.Error(queryErr, fmt.Sprintf("Failed to get entity by ID %d with options", id))
		return nil, fmt.Errorf("database operation failed: %w", queryErr)
	}
}
// Update writes entity back to the database using GORM's Save.
//
// ctx is validated before the tracing span is opened and entity after it;
// either validation error is returned unchanged. A database failure is logged
// and returned wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "Update")
	defer span.End()

	if err := r.validateEntity(entity); err != nil {
		return err
	}

	began := time.Now()
	dbErr := r.db.WithContext(ctx).Save(entity).Error
	elapsed := time.Since(began)

	if dbErr == nil {
		log.Debug(fmt.Sprintf("Entity updated successfully in %s", elapsed))
		return nil
	}
	log.Error(dbErr, "Failed to update entity")
	return fmt.Errorf("database operation failed: %w", dbErr)
}
// UpdateInTx writes entity back using GORM's Save on the caller-supplied
// transaction tx.
//
// Validation runs in this order: ctx (before the span is opened), entity, then
// tx. A nil tx yields domain.ErrInvalidOperation. A database failure is logged
// and returned wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "UpdateInTx")
	defer span.End()

	switch err := r.validateEntity(entity); {
	case err != nil:
		return err
	case tx == nil:
		return domain.ErrInvalidOperation
	}

	began := time.Now()
	dbErr := tx.WithContext(ctx).Save(entity).Error
	elapsed := time.Since(began)

	if dbErr == nil {
		log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", elapsed))
		return nil
	}
	log.Error(dbErr, "Failed to update entity in transaction")
	return fmt.Errorf("database operation failed: %w", dbErr)
}
// Delete removes the entity with primary key id using GORM's Delete.
//
// A database failure is logged and wrapped with "database operation failed".
// A delete that affects no rows is reported as domain.ErrEntityNotFound.
// Validation errors for a nil ctx or a zero id are returned unchanged.
func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "Delete")
	defer span.End()

	if err := r.validateID(id); err != nil {
		return err
	}

	var model T
	began := time.Now()
	res := r.db.WithContext(ctx).Delete(&model, id)
	elapsed := time.Since(began)

	switch {
	case res.Error != nil:
		log.Error(res.Error, fmt.Sprintf("Failed to delete entity with id %d", id))
		return fmt.Errorf("database operation failed: %w", res.Error)
	case res.RowsAffected == 0:
		log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, elapsed))
		return domain.ErrEntityNotFound
	}

	log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, elapsed))
	return nil
}
// DeleteInTx removes the entity with primary key id using the caller-supplied
// transaction tx.
//
// Validation runs in this order: ctx (before the span is opened), id, then tx.
// A nil tx yields domain.ErrInvalidOperation. A database failure is logged and
// wrapped with "database operation failed". A delete that affects no rows is
// reported as domain.ErrEntityNotFound.
func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}

	ctx, span := r.tracer.Start(ctx, "DeleteInTx")
	defer span.End()

	switch err := r.validateID(id); {
	case err != nil:
		return err
	case tx == nil:
		return domain.ErrInvalidOperation
	}

	var model T
	began := time.Now()
	res := tx.WithContext(ctx).Delete(&model, id)
	elapsed := time.Since(began)

	switch {
	case res.Error != nil:
		log.Error(res.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id))
		return fmt.Errorf("database operation failed: %w", res.Error)
	case res.RowsAffected == 0:
		log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, elapsed))
		return domain.ErrEntityNotFound
	}

	log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, elapsed))
	return nil
}
// List returns one page of entities together with pagination metadata.
//
// page and pageSize are normalized by validatePagination (which may reject an
// oversized pageSize). The total row count is read first, then the requested
// page is fetched with LIMIT/OFFSET.
//
// Rows are ordered by the "id" column. Without an ORDER BY, SQL does not
// guarantee row order, so successive LIMIT/OFFSET pages could overlap or skip
// rows. This repository already assumes an "id" column (see Exists).
//
// Database failures are logged and wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[T], error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	ctx, span := r.tracer.Start(ctx, "List")
	defer span.End()

	page, pageSize, err := r.validatePagination(page, pageSize)
	if err != nil {
		return nil, err
	}

	start := time.Now()

	var totalCount int64
	if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil {
		log.Error(err, "Failed to count entities")
		return nil, fmt.Errorf("database operation failed: %w", err)
	}

	offset := (page - 1) * pageSize
	var entities []T
	// A stable ORDER BY is required for deterministic pagination.
	if err := r.db.WithContext(ctx).Order("id").Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil {
		log.Error(err, "Failed to get paginated entities")
		return nil, fmt.Errorf("database operation failed: %w", err)
	}
	duration := time.Since(start)

	// Ceiling division in int64 so a partial final page counts as a page and
	// large counts are not truncated before dividing.
	size := int64(pageSize)
	totalPages := int((totalCount + size - 1) / size)

	log.Debug(fmt.Sprintf("Paginated entities retrieved successfully in %s", duration))
	return &domain.PaginatedResult[T]{
		Items:      entities,
		TotalCount: totalCount,
		Page:       page,
		PageSize:   pageSize,
		TotalPages: totalPages,
		HasNext:    page < totalPages,
		HasPrev:    page > 1,
	}, nil
}
// ListWithOptions returns every entity matching options, applied through
// buildQuery (preloads, where, order, limit, offset). A nil options applies
// no constraints. Database failures are logged and wrapped with
// "database operation failed".
func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "ListWithOptions")
	defer span.End()

	began := time.Now()
	var rows []T
	findErr := r.buildQuery(r.db.WithContext(ctx), options).Find(&rows).Error
	if findErr != nil {
		log.Error(findErr, "Failed to get entities with options")
		return nil, fmt.Errorf("database operation failed: %w", findErr)
	}

	log.Debug(fmt.Sprintf("Entities retrieved successfully with options in %s", time.Since(began)))
	return rows, nil
}
// ListAll loads every row of T's table in a single query with no limit, so
// use it with caution on large tables. Database failures are logged and
// wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "ListAll")
	defer span.End()

	began := time.Now()
	var rows []T
	findErr := r.db.WithContext(ctx).Find(&rows).Error
	if findErr != nil {
		log.Error(findErr, "Failed to get all entities")
		return nil, fmt.Errorf("database operation failed: %w", findErr)
	}

	log.Debug(fmt.Sprintf("All entities retrieved successfully in %s", time.Since(began)))
	return rows, nil
}
// Count returns the number of rows in T's table. On a database failure it
// returns 0 and the error, wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) {
	if err := r.validateContext(ctx); err != nil {
		return 0, err
	}

	ctx, span := r.tracer.Start(ctx, "Count")
	defer span.End()

	began := time.Now()
	var total int64
	countErr := r.db.WithContext(ctx).Model(new(T)).Count(&total).Error
	if countErr != nil {
		log.Error(countErr, "Failed to count entities")
		return 0, fmt.Errorf("database operation failed: %w", countErr)
	}

	log.Debug(fmt.Sprintf("Entity count retrieved successfully in %s", time.Since(began)))
	return total, nil
}
// CountWithOptions returns how many entities match the Where conditions in
// options. A nil options counts every row.
//
// Only options.Where is applied. Preloads, OrderBy, Limit and Offset do not
// change which rows match, so they are deliberately ignored. Applying them
// (as buildQuery would) is harmful: a positive Offset renders
// "SELECT count(*) ... OFFSET n", which skips the single aggregate row and
// makes the count 0.
//
// On a database failure it returns 0 and the error, wrapped with
// "database operation failed".
func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) {
	if err := r.validateContext(ctx); err != nil {
		return 0, err
	}
	ctx, span := r.tracer.Start(ctx, "CountWithOptions")
	defer span.End()

	start := time.Now()
	query := r.db.WithContext(ctx).Model(new(T))
	if options != nil {
		for field, value := range options.Where {
			query = query.Where(field, value)
		}
	}

	var count int64
	if err := query.Count(&count).Error; err != nil {
		log.Error(err, "Failed to count entities with options")
		return 0, fmt.Errorf("database operation failed: %w", err)
	}
	duration := time.Since(start)
	log.Debug(fmt.Sprintf("Entity count retrieved successfully with options in %s", duration))
	return count, nil
}
// FindWithPreload loads the entity with primary key id, eagerly loading each
// relation named in preloads (in the given order).
//
// gorm.ErrRecordNotFound is translated to domain.ErrEntityNotFound; any other
// database error is logged and wrapped with "database operation failed".
// Validation errors for a nil ctx or a zero id are returned unchanged.
func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads []string, id uint) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "FindWithPreload")
	defer span.End()

	if err := r.validateID(id); err != nil {
		return nil, err
	}

	began := time.Now()
	q := r.db.WithContext(ctx)
	for i := range preloads {
		q = q.Preload(preloads[i])
	}

	var found T
	queryErr := q.First(&found, id).Error
	elapsed := time.Since(began)

	switch {
	case queryErr == nil:
		log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with preloads in %s", id, elapsed))
		return &found, nil
	case errors.Is(queryErr, gorm.ErrRecordNotFound):
		log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, elapsed))
		return nil, domain.ErrEntityNotFound
	default:
		log.Error(queryErr, fmt.Sprintf("Failed to get entity with id %d with preloads", id))
		return nil, fmt.Errorf("database operation failed: %w", queryErr)
	}
}
// GetAllForSync returns up to batchSize entities starting at offset, for
// callers that walk the whole table in successive batches.
//
// A non-positive batchSize falls back to cfg.BatchSize, or to 100 when no
// positive configured value is available (including when the repository was
// constructed with a nil config). Batches larger than 1000 are rejected.
//
// Rows are ordered by the "id" column so that successive offsets partition
// the table consistently. Without an ORDER BY the database may return rows in
// a different order for each query, causing a sync to skip or repeat rows.
// This repository already assumes an "id" column (see Exists).
//
// Database failures are logged and wrapped with "database operation failed".
func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, offset int) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	ctx, span := r.tracer.Start(ctx, "GetAllForSync")
	defer span.End()

	const (
		defaultBatchSize = 100
		maxBatchSize     = 1000
	)
	if batchSize <= 0 {
		batchSize = defaultBatchSize
		// Guard the config pointer: a nil cfg used to cause a nil-pointer panic.
		if r.cfg != nil && r.cfg.BatchSize > 0 {
			batchSize = r.cfg.BatchSize
		}
	}
	if batchSize > maxBatchSize {
		return nil, fmt.Errorf("batch size too large: %d (max: 1000)", batchSize)
	}

	start := time.Now()
	var entities []T
	if err := r.db.WithContext(ctx).Order("id").Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil {
		log.Error(err, "Failed to get entities for sync")
		return nil, fmt.Errorf("database operation failed: %w", err)
	}
	duration := time.Since(start)
	log.Debug(fmt.Sprintf("Entities retrieved successfully for sync in %s", duration))
	return entities, nil
}
// Exists reports whether a row with "id = id" exists in T's table, by counting
// matching rows. On a database failure it returns false and the error, wrapped
// with "database operation failed". Validation errors for a nil ctx or a zero
// id are returned unchanged.
func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, error) {
	if err := r.validateContext(ctx); err != nil {
		return false, err
	}

	ctx, span := r.tracer.Start(ctx, "Exists")
	defer span.End()

	if err := r.validateID(id); err != nil {
		return false, err
	}

	began := time.Now()
	var matches int64
	countErr := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&matches).Error
	if countErr != nil {
		log.Error(countErr, fmt.Sprintf("Failed to check entity existence for id %d", id))
		return false, fmt.Errorf("database operation failed: %w", countErr)
	}

	log.Debug(fmt.Sprintf("Entity existence checked for id %d in %s", id, time.Since(began)))
	return matches > 0, nil
}
// BeginTx opens a new GORM transaction bound to ctx and returns its handle.
// The caller owns the transaction and must commit or roll it back. A failure
// to begin is logged and wrapped with "transaction failed".
func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}

	ctx, span := r.tracer.Start(ctx, "BeginTx")
	defer span.End()

	tx := r.db.WithContext(ctx).Begin()
	if beginErr := tx.Error; beginErr != nil {
		log.Error(beginErr, "Failed to begin transaction")
		return nil, fmt.Errorf("transaction failed: %w", beginErr)
	}

	log.Debug("Transaction started successfully")
	return tx, nil
}
// WithTx runs fn inside a new database transaction started by BeginTx.
//
// The transaction is committed when fn returns nil. It is rolled back when fn
// returns an error (that error is returned to the caller) or when fn panics.
// A panic is recovered, logged, and reported as a non-nil error, so the
// caller can never mistake a panicked transaction for a successful one. If
// the rollback after an fn error also fails, the returned error wraps fn's
// error (still matchable with errors.Is/As) and mentions the rollback error.
// A commit failure is wrapped with "transaction failed".
func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) (err error) {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	ctx, span := r.tracer.Start(ctx, "WithTx")
	defer span.End()

	// err here is the named result; BeginTx's error is returned as-is.
	tx, err := r.BeginTx(ctx)
	if err != nil {
		return err
	}

	defer func() {
		// Use p rather than r so the receiver is not shadowed.
		if p := recover(); p != nil {
			tx.Rollback()
			log.Error(fmt.Errorf("panic recovered: %v", p), "Transaction panic recovered")
			// Overwrite the named result. Previously the function returned nil
			// here, which silently reported the aborted transaction as a success.
			err = fmt.Errorf("transaction panicked: %v", p)
		}
	}()

	if fnErr := fn(tx); fnErr != nil {
		if rbErr := tx.Rollback().Error; rbErr != nil {
			log.Error(rbErr, fmt.Sprintf("Failed to rollback transaction after error: %v", fnErr))
			return fmt.Errorf("transaction failed and rollback failed: %w (rollback: %v)", fnErr, rbErr)
		}
		log.Debug(fmt.Sprintf("Transaction rolled back due to error: %v", fnErr))
		return fnErr
	}

	if commitErr := tx.Commit().Error; commitErr != nil {
		log.Error(commitErr, "Failed to commit transaction")
		return fmt.Errorf("transaction failed: %w", commitErr)
	}
	log.Debug("Transaction committed successfully")
	return nil
}