package sql import ( "context" "errors" "fmt" "tercul/internal/domain" "tercul/internal/platform/config" "tercul/internal/platform/log" "time" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "gorm.io/gorm" ) // Common repository errors var ( ErrEntityNotFound = errors.New("entity not found") ErrInvalidID = errors.New("invalid ID: cannot be zero") ErrInvalidInput = errors.New("invalid input parameters") ErrDatabaseOperation = errors.New("database operation failed") ErrContextRequired = errors.New("context is required") ErrTransactionFailed = errors.New("transaction failed") ) // BaseRepositoryImpl provides a default implementation of BaseRepository using GORM type BaseRepositoryImpl[T any] struct { db *gorm.DB tracer trace.Tracer } // NewBaseRepositoryImpl creates a new BaseRepositoryImpl func NewBaseRepositoryImpl[T any](db *gorm.DB) domain.BaseRepository[T] { return &BaseRepositoryImpl[T]{ db: db, tracer: otel.Tracer("base.repository"), } } // validateContext ensures context is not nil func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { if ctx == nil { return ErrContextRequired } return nil } // validateID ensures ID is valid func (r *BaseRepositoryImpl[T]) validateID(id uint) error { if id == 0 { return ErrInvalidID } return nil } // validateEntity ensures entity is not nil func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error { if entity == nil { return ErrInvalidInput } return nil } // validatePagination ensures pagination parameters are valid func (r *BaseRepositoryImpl[T]) validatePagination(page, pageSize int) (int, int, error) { if page < 1 { page = 1 } if pageSize < 1 { pageSize = config.Cfg.PageSize if pageSize < 1 { pageSize = 20 // Default page size } } if pageSize > 1000 { return 0, 0, fmt.Errorf("page size too large: %d (max: 1000)", pageSize) } return page, pageSize, nil } // buildQuery applies query options to a GORM query func (r *BaseRepositoryImpl[T]) buildQuery(query *gorm.DB, options *domain.QueryOptions) *gorm.DB { if options == nil { return query } // Apply preloads for _, preload := range options.Preloads { query = query.Preload(preload) } // Apply where conditions for field, value := range options.Where { query = query.Where(field, value) } // Apply ordering if options.OrderBy != "" { query = query.Order(options.OrderBy) } // Apply limit and offset if options.Limit > 0 { query = query.Limit(options.Limit) } if options.Offset > 0 { query = query.Offset(options.Offset) } return query } // Create adds a new entity to the database func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "Create") defer span.End() if err := r.validateEntity(entity); err != nil { return err } start := time.Now() err := r.db.WithContext(ctx).Create(entity).Error duration := time.Since(start) if err != nil { log.Error(err, "Failed to create entity") return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in %s", duration)) return nil } // CreateInTx creates an entity within a transaction func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, entity *T) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "CreateInTx") defer span.End() if err := r.validateEntity(entity); err != nil { return err } if tx == nil { return ErrTransactionFailed } start := time.Now() err := tx.WithContext(ctx).Create(entity).Error duration := time.Since(start) if err != nil { log.Error(err, "Failed to create entity in transaction") return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration)) return nil } // GetByID retrieves an entity by its ID func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "GetByID") defer span.End() if err := r.validateID(id); err != nil { return nil, err } start := time.Now() var entity T err := r.db.WithContext(ctx).First(&entity, id).Error duration := time.Since(start) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration)) return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id)) return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration)) return &entity, nil } // GetByIDWithOptions retrieves an entity by its ID with query options func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "GetByIDWithOptions") defer span.End() if err := r.validateID(id); err != nil { return nil, err } start := time.Now() var entity T query := r.buildQuery(r.db.WithContext(ctx), options) err := query.First(&entity, id).Error duration := time.Since(start) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration)) return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id)) return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration)) return &entity, nil } // Update updates an existing entity func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "Update") defer span.End() if err := r.validateEntity(entity); err != nil { return err } start := time.Now() err := r.db.WithContext(ctx).Save(entity).Error duration := time.Since(start) if err != nil { log.Error(err, "Failed to update entity") return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration)) return nil } // UpdateInTx updates an entity within a transaction func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *T) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "UpdateInTx") defer span.End() if err := r.validateEntity(entity); err != nil { return err } if tx == nil { return ErrTransactionFailed } start := time.Now() err := tx.WithContext(ctx).Save(entity).Error duration := time.Since(start) if err != nil { log.Error(err, "Failed to update entity in transaction") return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration)) return nil } // Delete removes an entity by its ID func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "Delete") defer span.End() if err := r.validateID(id); err != nil { return err } start := time.Now() var entity T result := r.db.WithContext(ctx).Delete(&entity, id) duration := time.Since(start) if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id)) return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration)) return ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration)) return nil } // DeleteInTx removes an entity by its ID within a transaction func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "DeleteInTx") defer span.End() if err := r.validateID(id); err != nil { return err } if tx == nil { return ErrTransactionFailed } start := time.Now() var entity T result := tx.WithContext(ctx).Delete(&entity, id) duration := time.Since(start) if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id)) return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration)) return ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration)) return nil } // List returns a paginated list of entities func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[T], error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "List") defer span.End() page, pageSize, err := r.validatePagination(page, pageSize) if err != nil { return nil, err } start := time.Now() var entities []T var totalCount int64 // Get total count if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil { log.Error(err, "Failed to count entities") return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } // Calculate offset offset := (page - 1) * pageSize // Get paginated data if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get paginated entities") return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) // Calculate total pages and pagination info totalPages := int(totalCount) / pageSize if int(totalCount)%pageSize > 0 { totalPages++ } hasNext := page < totalPages hasPrev := page > 1 log.Debug(fmt.Sprintf("Paginated entities retrieved successfully in %s", duration)) return &domain.PaginatedResult[T]{ Items: entities, TotalCount: totalCount, Page: page, PageSize: pageSize, TotalPages: totalPages, HasNext: hasNext, HasPrev: hasPrev, }, nil } // ListWithOptions returns entities with query options func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "ListWithOptions") defer span.End() start := time.Now() var entities []T query := r.buildQuery(r.db.WithContext(ctx), options) if err := query.Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities with options") return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("Entities retrieved successfully with options in %s", duration)) return entities, nil } // ListAll returns all entities (use with caution for large datasets) func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "ListAll") defer span.End() start := time.Now() var entities []T if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil { log.Error(err, "Failed to get all entities") return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("All entities retrieved successfully in %s", duration)) return entities, nil } // Count returns the total number of entities func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) { if err := r.validateContext(ctx); err != nil { return 0, err } ctx, span := r.tracer.Start(ctx, "Count") defer span.End() start := time.Now() var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities") return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("Entity count retrieved successfully in %s", duration)) return count, nil } // CountWithOptions returns the count with query options func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { if err := r.validateContext(ctx); err != nil { return 0, err } ctx, span := r.tracer.Start(ctx, "CountWithOptions") defer span.End() start := time.Now() var count int64 query := r.buildQuery(r.db.WithContext(ctx), options) if err := query.Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities with options") return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("Entity count retrieved successfully with options in %s", duration)) return count, nil } // FindWithPreload retrieves an entity by its ID with preloaded relationships func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads []string, id uint) (*T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "FindWithPreload") defer span.End() if err := r.validateID(id); err != nil { return nil, err } start := time.Now() var entity T query := r.db.WithContext(ctx) for _, preload := range preloads { query = query.Preload(preload) } if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start))) return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id)) return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with preloads in %s", id, duration)) return &entity, nil } // GetAllForSync returns entities in batches for synchronization func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, offset int) ([]T, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "GetAllForSync") defer span.End() if batchSize <= 0 { batchSize = config.Cfg.BatchSize if batchSize <= 0 { batchSize = 100 // Default batch size } } if batchSize > 1000 { return nil, fmt.Errorf("batch size too large: %d (max: 1000)", batchSize) } start := time.Now() var entities []T if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities for sync") return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) log.Debug(fmt.Sprintf("Entities retrieved successfully for sync in %s", duration)) return entities, nil } // Exists checks if an entity exists by ID func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, error) { if err := r.validateContext(ctx); err != nil { return false, err } ctx, span := r.tracer.Start(ctx, "Exists") defer span.End() if err := r.validateID(id); err != nil { return false, err } start := time.Now() var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil { log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id)) return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) exists := count > 0 log.Debug(fmt.Sprintf("Entity existence checked for id %d in %s", id, duration)) return exists, nil } // BeginTx starts a new transaction func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) { if err := r.validateContext(ctx); err != nil { return nil, err } ctx, span := r.tracer.Start(ctx, "BeginTx") defer span.End() tx := r.db.WithContext(ctx).Begin() if tx.Error != nil { log.Error(tx.Error, "Failed to begin transaction") return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error) } log.Debug("Transaction started successfully") return tx, nil } // WithTx executes a function within a transaction func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { if err := r.validateContext(ctx); err != nil { return err } ctx, span := r.tracer.Start(ctx, "WithTx") defer span.End() tx, err := r.BeginTx(ctx) if err != nil { return err } defer func() { if r := recover(); r != nil { tx.Rollback() log.Error(fmt.Errorf("panic recovered: %v", r), "Transaction panic recovered") } }() if err := fn(tx); err != nil { if rbErr := tx.Rollback().Error; rbErr != nil { log.Error(rbErr, fmt.Sprintf("Failed to rollback transaction after error: %v", err)) return fmt.Errorf("transaction failed and rollback failed: %v (rollback: %v)", err, rbErr) } log.Debug(fmt.Sprintf("Transaction rolled back due to error: %v", err)) return err } if err := tx.Commit().Error; err != nil { log.Error(err, "Failed to commit transaction") return fmt.Errorf("%w: %v", ErrTransactionFailed, err) } log.Debug("Transaction committed successfully") return nil }