package repositories

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"time"

	"gorm.io/gorm"

	"tercul/config"
	"tercul/logger"
)

// Common repository errors. BaseRepositoryImpl returns these directly
// (validation and not-found cases) or wraps them with %w, so callers can test
// for them with errors.Is.
var (
	// ErrEntityNotFound is returned when no row matches the requested ID.
	ErrEntityNotFound = errors.New("entity not found")
	// ErrInvalidID is returned when an ID of zero is passed.
	ErrInvalidID = errors.New("invalid ID: cannot be zero")
	// ErrInvalidInput is returned for nil entities and other invalid arguments.
	ErrInvalidInput = errors.New("invalid input parameters")
	// ErrDatabaseOperation is wrapped around errors returned by GORM. The
	// underlying error is included in the message text only (formatted with
	// %v), so it cannot itself be unwrapped.
	ErrDatabaseOperation = errors.New("database operation failed")
	// ErrContextRequired is returned when a nil context is passed.
	ErrContextRequired = errors.New("context is required")
	// ErrTransactionFailed is returned when a transaction handle is missing,
	// or when a transaction cannot be started or committed.
	ErrTransactionFailed = errors.New("transaction failed")
)

// PaginatedResult is one page of a listing plus the metadata a client needs
// to navigate to other pages.
type PaginatedResult[T any] struct {
	// Items holds the entities on this page.
	Items []T `json:"items"`
	// TotalCount is the number of entities across all pages.
	TotalCount int64 `json:"totalCount"`
	// Page is the 1-based page number.
	Page int `json:"page"`
	// PageSize is the maximum number of items per page.
	PageSize int `json:"pageSize"`
	// TotalPages is ceil(TotalCount / PageSize).
	TotalPages int `json:"totalPages"`
	// HasNext reports whether Page < TotalPages.
	HasNext bool `json:"hasNext"`
	// HasPrev reports whether Page > 1.
	HasPrev bool `json:"hasPrev"`
}

// QueryOptions customizes the queries issued by the *WithOptions methods.
// Zero values mean "not set" and are skipped.
type QueryOptions struct {
	// Preloads lists association names passed to GORM's Preload.
	Preloads []string
	// OrderBy is passed verbatim to GORM's Order (for example
	// "created_at DESC"). GORM inserts a string order expression as-is, so
	// never populate it from unvalidated user input.
	OrderBy string
	// Where holds filter conditions. Each entry is applied as
	// query.Where(key, value), in sorted key order.
	Where map[string]interface{}
	// Limit caps the number of rows returned when > 0.
	Limit int
	// Offset skips this many rows when > 0.
	Offset int
}

// BaseRepository defines common CRUD operations that all repositories should implement.
type BaseRepository[T any] interface {
	// Create adds a new entity to the database.
	Create(ctx context.Context, entity *T) error
	// CreateInTx creates an entity within a transaction.
	CreateInTx(ctx context.Context, tx *gorm.DB, entity *T) error
	// GetByID retrieves an entity by its ID.
	GetByID(ctx context.Context, id uint) (*T, error)
	// GetByIDWithOptions retrieves an entity by its ID with query options.
	GetByIDWithOptions(ctx context.Context, id uint, options *QueryOptions) (*T, error)
	// Update updates an existing entity.
	Update(ctx context.Context, entity *T) error
	// UpdateInTx updates an entity within a transaction.
	UpdateInTx(ctx context.Context, tx *gorm.DB, entity *T) error
	// Delete removes an entity by its ID.
	Delete(ctx context.Context, id uint) error
	// DeleteInTx removes an entity by its ID within a transaction.
	DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error
	// List returns a paginated list of entities.
	List(ctx context.Context, page, pageSize int) (*PaginatedResult[T], error)
	// ListWithOptions returns entities with query options.
	ListWithOptions(ctx context.Context, options *QueryOptions) ([]T, error)
	// ListAll returns all entities (use with caution for large datasets).
	ListAll(ctx context.Context) ([]T, error)
	// Count returns the total number of entities.
	Count(ctx context.Context) (int64, error)
	// CountWithOptions returns the number of entities matching the options' filters.
	CountWithOptions(ctx context.Context, options *QueryOptions) (int64, error)
	// FindWithPreload retrieves an entity by its ID with preloaded relationships.
	FindWithPreload(ctx context.Context, preloads []string, id uint) (*T, error)
	// GetAllForSync returns entities in batches for synchronization.
	GetAllForSync(ctx context.Context, batchSize, offset int) ([]T, error)
	// Exists checks if an entity exists by ID.
	Exists(ctx context.Context, id uint) (bool, error)
	// BeginTx starts a new transaction.
	BeginTx(ctx context.Context) (*gorm.DB, error)
	// WithTx executes a function within a transaction.
	WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error
}

// BaseRepositoryImpl provides a default implementation of BaseRepository using GORM.
type BaseRepositoryImpl[T any] struct {
	db *gorm.DB
}

// Compile-time check that BaseRepositoryImpl satisfies BaseRepository.
var _ BaseRepository[struct{}] = (*BaseRepositoryImpl[struct{}])(nil)

// NewBaseRepositoryImpl creates a BaseRepositoryImpl that issues all
// non-transactional queries through db.
func NewBaseRepositoryImpl[T any](db *gorm.DB) *BaseRepositoryImpl[T] {
	return &BaseRepositoryImpl[T]{db: db}
}

// validateContext returns ErrContextRequired when ctx is nil.
func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error {
	if ctx == nil {
		return ErrContextRequired
	}
	return nil
}

// validateID returns ErrInvalidID when id is zero.
func (r *BaseRepositoryImpl[T]) validateID(id uint) error {
	if id == 0 {
		return ErrInvalidID
	}
	return nil
}

// validateEntity returns ErrInvalidInput when entity is nil.
func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error {
	if entity == nil {
		return ErrInvalidInput
	}
	return nil
}

// validatePagination normalizes pagination parameters and returns them.
//
// A page below 1 becomes 1. A pageSize below 1 falls back to
// config.Cfg.PageSize, or to 20 if that is also below 1. A pageSize above
// 1000 is rejected with an error.
func (r *BaseRepositoryImpl[T]) validatePagination(page, pageSize int) (int, int, error) {
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = config.Cfg.PageSize
		if pageSize < 1 {
			pageSize = 20 // Default page size
		}
	}
	if pageSize > 1000 {
		return 0, 0, fmt.Errorf("page size too large: %d (max: 1000)", pageSize)
	}
	return page, pageSize, nil
}

// applyPreloads adds a GORM Preload for each named association.
func (r *BaseRepositoryImpl[T]) applyPreloads(query *gorm.DB, preloads []string) *gorm.DB {
	for _, preload := range preloads {
		query = query.Preload(preload)
	}
	return query
}

// applyWhere adds one query.Where(field, value) per entry in where. Keys are
// applied in sorted order because Go map iteration order is random: sorting
// makes the generated SQL identical on every call.
func (r *BaseRepositoryImpl[T]) applyWhere(query *gorm.DB, where map[string]interface{}) *gorm.DB {
	if len(where) == 0 {
		return query
	}
	fields := make([]string, 0, len(where))
	for field := range where {
		fields = append(fields, field)
	}
	sort.Strings(fields)
	for _, field := range fields {
		query = query.Where(field, where[field])
	}
	return query
}

// buildQuery applies every field of options (preloads, where, ordering,
// limit and offset) to query. A nil options returns query unchanged.
func (r *BaseRepositoryImpl[T]) buildQuery(query *gorm.DB, options *QueryOptions) *gorm.DB {
	if options == nil {
		return query
	}
	query = r.applyPreloads(query, options.Preloads)
	query = r.applyWhere(query, options.Where)
	if options.OrderBy != "" {
		query = query.Order(options.OrderBy)
	}
	if options.Limit > 0 {
		query = query.Limit(options.Limit)
	}
	if options.Offset > 0 {
		query = query.Offset(options.Offset)
	}
	return query
}

// Create inserts entity using the repository's connection.
//
// It returns ErrContextRequired or ErrInvalidInput for invalid arguments, and
// an error wrapping ErrDatabaseOperation if GORM reports a failure.
func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateEntity(entity); err != nil {
		return err
	}
	start := time.Now()
	err := r.db.WithContext(ctx).Create(entity).Error
	duration := time.Since(start)
	if err != nil {
		logger.LogError("Failed to create entity", logger.F("error", err), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity created successfully", logger.F("duration", duration))
	return nil
}

// CreateInTx inserts entity through the caller-supplied transaction tx.
//
// A nil tx yields ErrTransactionFailed. Other errors are as for Create.
func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateEntity(entity); err != nil {
		return err
	}
	if tx == nil {
		return ErrTransactionFailed
	}
	start := time.Now()
	err := tx.WithContext(ctx).Create(entity).Error
	duration := time.Since(start)
	if err != nil {
		logger.LogError("Failed to create entity in transaction", logger.F("error", err), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity created successfully in transaction", logger.F("duration", duration))
	return nil
}

// GetByID loads the entity whose primary key equals id.
//
// It returns ErrEntityNotFound when no row matches, and an error wrapping
// ErrDatabaseOperation for any other GORM failure.
func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	if err := r.validateID(id); err != nil {
		return nil, err
	}
	start := time.Now()
	var entity T
	err := r.db.WithContext(ctx).First(&entity, id).Error
	duration := time.Since(start)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			logger.LogDebug("Entity not found", logger.F("id", id), logger.F("duration", duration))
			return nil, ErrEntityNotFound
		}
		logger.LogError("Failed to get entity by ID", logger.F("id", id), logger.F("error", err), logger.F("duration", duration))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity retrieved successfully", logger.F("id", id), logger.F("duration", duration))
	return &entity, nil
}

// GetByIDWithOptions loads the entity with the given id, applying
// options.Preloads and options.Where.
//
// OrderBy, Limit and Offset are ignored. They are meaningless for a
// primary-key lookup, and an Offset would skip the single matching row and
// cause a spurious ErrEntityNotFound. Errors are as for GetByID.
func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, options *QueryOptions) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	if err := r.validateID(id); err != nil {
		return nil, err
	}
	start := time.Now()
	var entity T
	query := r.db.WithContext(ctx)
	if options != nil {
		query = r.applyPreloads(query, options.Preloads)
		query = r.applyWhere(query, options.Where)
	}
	err := query.First(&entity, id).Error
	duration := time.Since(start)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			logger.LogDebug("Entity not found with options", logger.F("id", id), logger.F("duration", duration))
			return nil, ErrEntityNotFound
		}
		logger.LogError("Failed to get entity by ID with options", logger.F("id", id), logger.F("error", err), logger.F("duration", duration))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity retrieved successfully with options", logger.F("id", id), logger.F("duration", duration))
	return &entity, nil
}

// Update persists entity with GORM's Save.
//
// NOTE(review): Save writes every column and, per GORM docs, inserts when
// the primary key is zero. Confirm upsert semantics are intended here.
func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateEntity(entity); err != nil {
		return err
	}
	start := time.Now()
	err := r.db.WithContext(ctx).Save(entity).Error
	duration := time.Since(start)
	if err != nil {
		logger.LogError("Failed to update entity", logger.F("error", err), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity updated successfully", logger.F("duration", duration))
	return nil
}

// UpdateInTx persists entity with GORM's Save through the caller-supplied
// transaction tx. A nil tx yields ErrTransactionFailed.
func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *T) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateEntity(entity); err != nil {
		return err
	}
	if tx == nil {
		return ErrTransactionFailed
	}
	start := time.Now()
	err := tx.WithContext(ctx).Save(entity).Error
	duration := time.Since(start)
	if err != nil {
		logger.LogError("Failed to update entity in transaction", logger.F("error", err), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	logger.LogDebug("Entity updated successfully in transaction", logger.F("duration", duration))
	return nil
}

// Delete removes the entity with the given id.
//
// It returns ErrEntityNotFound when no row was affected, and an error
// wrapping ErrDatabaseOperation if GORM reports a failure.
func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateID(id); err != nil {
		return err
	}
	start := time.Now()
	var entity T
	result := r.db.WithContext(ctx).Delete(&entity, id)
	duration := time.Since(start)
	if result.Error != nil {
		logger.LogError("Failed to delete entity", logger.F("id", id), logger.F("error", result.Error), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error)
	}
	if result.RowsAffected == 0 {
		logger.LogDebug("No entity found to delete", logger.F("id", id), logger.F("duration", duration))
		return ErrEntityNotFound
	}
	logger.LogDebug("Entity deleted successfully", logger.F("id", id), logger.F("rowsAffected", result.RowsAffected), logger.F("duration", duration))
	return nil
}

// DeleteInTx removes the entity with the given id through the
// caller-supplied transaction tx.
//
// A nil tx yields ErrTransactionFailed. Other errors are as for Delete.
func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if err := r.validateID(id); err != nil {
		return err
	}
	if tx == nil {
		return ErrTransactionFailed
	}
	start := time.Now()
	var entity T
	result := tx.WithContext(ctx).Delete(&entity, id)
	duration := time.Since(start)
	if result.Error != nil {
		logger.LogError("Failed to delete entity in transaction", logger.F("id", id), logger.F("error", result.Error), logger.F("duration", duration))
		return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error)
	}
	if result.RowsAffected == 0 {
		logger.LogDebug("No entity found to delete in transaction", logger.F("id", id), logger.F("duration", duration))
		return ErrEntityNotFound
	}
	logger.LogDebug("Entity deleted successfully in transaction", logger.F("id", id), logger.F("rowsAffected", result.RowsAffected), logger.F("duration", duration))
	return nil
}

// List returns one page of entities plus pagination metadata. page and
// pageSize are normalized by validatePagination.
//
// NOTE(review): the page query has no ORDER BY, so rows may shift between
// pages on databases that do not guarantee a stable scan order. Consider
// ordering by the primary key.
func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (*PaginatedResult[T], error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	page, pageSize, err := r.validatePagination(page, pageSize)
	if err != nil {
		return nil, err
	}
	start := time.Now()
	var entities []T
	var totalCount int64

	if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil {
		logger.LogError("Failed to count entities", logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}

	offset := (page - 1) * pageSize
	if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil {
		logger.LogError("Failed to get paginated entities", logger.F("page", page), logger.F("pageSize", pageSize), logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)

	// Ceiling division done in int64, so a large count is not truncated to a
	// 32-bit int before dividing.
	totalPages := int((totalCount + int64(pageSize) - 1) / int64(pageSize))
	hasNext := page < totalPages
	hasPrev := page > 1

	logger.LogDebug("Paginated entities retrieved successfully", logger.F("page", page), logger.F("pageSize", pageSize), logger.F("totalCount", totalCount), logger.F("totalPages", totalPages), logger.F("hasNext", hasNext), logger.F("hasPrev", hasPrev), logger.F("duration", duration))
	return &PaginatedResult[T]{
		Items:      entities,
		TotalCount: totalCount,
		Page:       page,
		PageSize:   pageSize,
		TotalPages: totalPages,
		HasNext:    hasNext,
		HasPrev:    hasPrev,
	}, nil
}

// ListWithOptions returns the entities selected by applying every field of
// options (see buildQuery). A nil options lists all entities.
func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *QueryOptions) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	start := time.Now()
	var entities []T
	query := r.buildQuery(r.db.WithContext(ctx), options)
	if err := query.Find(&entities).Error; err != nil {
		logger.LogError("Failed to get entities with options", logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("Entities retrieved successfully with options", logger.F("count", len(entities)), logger.F("duration", duration))
	return entities, nil
}

// ListAll returns every entity in one query (use with caution for large datasets).
func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	start := time.Now()
	var entities []T
	if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil {
		logger.LogError("Failed to get all entities", logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("All entities retrieved successfully", logger.F("count", len(entities)), logger.F("duration", duration))
	return entities, nil
}

// Count returns the total number of rows for T's model.
func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) {
	if err := r.validateContext(ctx); err != nil {
		return 0, err
	}
	start := time.Now()
	var count int64
	if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil {
		logger.LogError("Failed to count entities", logger.F("error", err), logger.F("duration", time.Since(start)))
		return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("Entity count retrieved successfully", logger.F("count", count), logger.F("duration", duration))
	return count, nil
}

// CountWithOptions returns the number of rows matching options.Where.
//
// Preloads, OrderBy, Limit and Offset are deliberately not applied. Callers
// often reuse the options they pass to ListWithOptions, and an OFFSET on a
// COUNT query skips its single result row, yielding 0.
func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *QueryOptions) (int64, error) {
	if err := r.validateContext(ctx); err != nil {
		return 0, err
	}
	start := time.Now()
	var count int64
	query := r.db.WithContext(ctx).Model(new(T))
	if options != nil {
		query = r.applyWhere(query, options.Where)
	}
	if err := query.Count(&count).Error; err != nil {
		logger.LogError("Failed to count entities with options", logger.F("error", err), logger.F("duration", time.Since(start)))
		return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("Entity count retrieved successfully with options", logger.F("count", count), logger.F("duration", duration))
	return count, nil
}

// FindWithPreload loads the entity with the given id and preloads the named
// associations. Errors are as for GetByID.
func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads []string, id uint) (*T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	if err := r.validateID(id); err != nil {
		return nil, err
	}
	start := time.Now()
	var entity T
	query := r.applyPreloads(r.db.WithContext(ctx), preloads)
	if err := query.First(&entity, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			logger.LogDebug("Entity not found with preloads", logger.F("id", id), logger.F("preloads", preloads), logger.F("duration", time.Since(start)))
			return nil, ErrEntityNotFound
		}
		logger.LogError("Failed to get entity with preloads", logger.F("id", id), logger.F("preloads", preloads), logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("Entity retrieved successfully with preloads", logger.F("id", id), logger.F("preloads", preloads), logger.F("duration", duration))
	return &entity, nil
}

// GetAllForSync returns up to batchSize entities starting at offset.
//
// A batchSize <= 0 falls back to config.Cfg.BatchSize, or to 100 if that is
// also <= 0. A batchSize above 1000 is rejected with an error.
//
// NOTE(review): no ORDER BY is applied, so offset-based batching may skip or
// repeat rows if the scan order is not stable. Consider ordering by the
// primary key.
func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, offset int) ([]T, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	if batchSize <= 0 {
		batchSize = config.Cfg.BatchSize
		if batchSize <= 0 {
			batchSize = 100 // Default batch size
		}
	}
	if batchSize > 1000 {
		return nil, fmt.Errorf("batch size too large: %d (max: 1000)", batchSize)
	}
	start := time.Now()
	var entities []T
	if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil {
		logger.LogError("Failed to get entities for sync", logger.F("batchSize", batchSize), logger.F("offset", offset), logger.F("error", err), logger.F("duration", time.Since(start)))
		return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	logger.LogDebug("Entities retrieved successfully for sync", logger.F("batchSize", batchSize), logger.F("offset", offset), logger.F("count", len(entities)), logger.F("duration", duration))
	return entities, nil
}

// Exists reports whether a row with the given id exists. It filters on a
// column literally named "id".
func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, error) {
	if err := r.validateContext(ctx); err != nil {
		return false, err
	}
	if err := r.validateID(id); err != nil {
		return false, err
	}
	start := time.Now()
	var count int64
	if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil {
		logger.LogError("Failed to check entity existence", logger.F("id", id), logger.F("error", err), logger.F("duration", time.Since(start)))
		return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err)
	}
	duration := time.Since(start)
	exists := count > 0
	logger.LogDebug("Entity existence checked", logger.F("id", id), logger.F("exists", exists), logger.F("duration", duration))
	return exists, nil
}

// BeginTx starts a transaction bound to ctx and returns its handle. The
// caller must Commit or Rollback it. A failure to begin is wrapped in
// ErrTransactionFailed.
func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) {
	if err := r.validateContext(ctx); err != nil {
		return nil, err
	}
	tx := r.db.WithContext(ctx).Begin()
	if tx.Error != nil {
		logger.LogError("Failed to begin transaction", logger.F("error", tx.Error))
		return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error)
	}
	logger.LogDebug("Transaction started successfully")
	return tx, nil
}

// WithTx runs fn inside a new transaction. The transaction is committed when
// fn returns nil, and rolled back when fn returns an error or panics.
//
// Return values:
//   - ErrInvalidInput if fn is nil.
//   - Whatever BeginTx returns if the transaction cannot be started.
//   - fn's error, unchanged, after a successful rollback.
//   - An error wrapping fn's error (so errors.Is still matches it) if the
//     rollback also fails.
//   - An error wrapping ErrTransactionFailed if the commit fails.
//
// A panic in fn is re-raised after the rollback, so it is never reported as
// success.
func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error {
	if err := r.validateContext(ctx); err != nil {
		return err
	}
	if fn == nil {
		return ErrInvalidInput
	}

	tx, err := r.BeginTx(ctx)
	if err != nil {
		return err
	}

	// Roll back on panic, then re-panic. Swallowing the panic here would make
	// WithTx return nil, so the caller would treat an aborted transaction as
	// committed.
	defer func() {
		if p := recover(); p != nil {
			if rbErr := tx.Rollback().Error; rbErr != nil {
				logger.LogError("Failed to rollback transaction", logger.F("panic", p), logger.F("rollbackError", rbErr))
			}
			logger.LogError("Transaction panic recovered", logger.F("panic", p))
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback().Error; rbErr != nil {
			logger.LogError("Failed to rollback transaction", logger.F("originalError", err), logger.F("rollbackError", rbErr))
			// %w keeps fn's error matchable with errors.Is / errors.As.
			return fmt.Errorf("transaction failed and rollback failed: %w (rollback: %v)", err, rbErr)
		}
		logger.LogDebug("Transaction rolled back due to error", logger.F("error", err))
		return err
	}

	if err := tx.Commit().Error; err != nil {
		logger.LogError("Failed to commit transaction", logger.F("error", err))
		return fmt.Errorf("%w: %v", ErrTransactionFailed, err)
	}
	logger.LogDebug("Transaction committed successfully")
	return nil
}