test: Add tests for DeleteUser mutation and refactor errors

This commit introduces a new test suite for the `DeleteUser` GraphQL mutation in `internal/adapters/graphql/user_mutations_test.go`. The tests cover successful deletion by an admin and by the user themselves, as well as failure cases for invalid permissions, non-existent users, and invalid input.

During the implementation of these tests, an inconsistency in error handling was discovered. The `internal/data/sql` repositories were using a mix of local and domain-level errors. This has been refactored to consistently use the centralized errors defined in the `internal/domain` package. This change improves the robustness and predictability of the data layer.

The following files were modified to standardize error handling:
- internal/data/sql/base_repository.go
- internal/data/sql/book_repository.go
- internal/data/sql/category_repository.go
- internal/data/sql/copyright_repository.go
- internal/data/sql/country_repository.go
- internal/data/sql/edition_repository.go
- internal/data/sql/email_verification_repository.go
- internal/data/sql/password_reset_repository.go
- internal/data/sql/source_repository.go
- internal/data/sql/tag_repository.go
- internal/data/sql/user_profile_repository.go
- internal/data/sql/user_repository.go
- internal/data/sql/user_session_repository.go
- internal/data/sql/work_repository.go
- internal/data/sql/base_repository_test.go
- internal/data/sql/copyright_repository_test.go
This commit is contained in:
google-labs-jules[bot] 2025-10-08 19:18:21 +00:00
parent 8a214b90fa
commit 66f9d7c725
17 changed files with 211 additions and 64 deletions

View File

@ -0,0 +1,155 @@
package graphql_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
"tercul/internal/adapters/graphql"
"tercul/internal/app/auth"
"tercul/internal/domain"
platform_auth "tercul/internal/platform/auth"
"tercul/internal/testutil"
"github.com/stretchr/testify/suite"
)
type UserMutationTestSuite struct {
testutil.IntegrationTestSuite
resolver graphql.MutationResolver
}
func TestUserMutations(t *testing.T) {
suite.Run(t, new(UserMutationTestSuite))
}
func (s *UserMutationTestSuite) SetupSuite() {
s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{
DBPath: "user_mutations_test.db",
})
}
func (s *UserMutationTestSuite) TearDownSuite() {
s.IntegrationTestSuite.TearDownSuite()
os.Remove("user_mutations_test.db")
}
func (s *UserMutationTestSuite) SetupTest() {
s.IntegrationTestSuite.SetupTest()
s.resolver = (&graphql.Resolver{App: s.App}).Mutation()
}
func (s *UserMutationTestSuite) TestDeleteUser() {
// Helper to create a user for tests
createUser := func(username, email, password string, role domain.UserRole) *domain.User {
resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{
Username: username,
Email: email,
Password: password,
})
s.Require().NoError(err)
user, err := s.App.User.Queries.User(context.Background(), resp.User.ID)
s.Require().NoError(err)
if role != user.Role {
user.Role = role
err = s.DB.Save(user).Error
s.Require().NoError(err)
}
return user
}
// Helper to create a context with JWT claims
contextWithClaims := func(user *domain.User) context.Context {
return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{
UserID: user.ID,
Role: string(user.Role),
})
}
s.Run("Success as admin", func() {
// Arrange
adminUser := createUser("admin_deleter", "admin_deleter@test.com", "password123", domain.UserRoleAdmin)
userToDelete := createUser("user_to_delete", "user_to_delete@test.com", "password123", domain.UserRoleReader)
ctx := contextWithClaims(adminUser)
userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID)
// Act
deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr)
// Assert
s.Require().NoError(err)
s.True(deleted)
// Verify user is deleted from DB
_, err = s.App.User.Queries.User(context.Background(), userToDelete.ID)
s.Error(err)
s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion")
})
s.Run("Success as self", func() {
// Arrange
userToDelete := createUser("user_to_delete_self", "user_to_delete_self@test.com", "password123", domain.UserRoleReader)
ctx := contextWithClaims(userToDelete)
userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID)
// Act
deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr)
// Assert
s.Require().NoError(err)
s.True(deleted)
// Verify user is deleted from DB
_, err = s.App.User.Queries.User(context.Background(), userToDelete.ID)
s.Error(err)
s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion")
})
s.Run("Forbidden as other user", func() {
// Arrange
otherUser := createUser("other_user_deleter", "other_user_deleter@test.com", "password123", domain.UserRoleReader)
userToDelete := createUser("user_to_be_kept", "user_to_be_kept@test.com", "password123", domain.UserRoleReader)
ctx := contextWithClaims(otherUser)
userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID)
// Act
deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr)
// Assert
s.Require().Error(err)
s.False(deleted)
s.True(errors.Is(err, domain.ErrForbidden))
})
s.Run("Invalid user ID", func() {
// Arrange
adminUser := createUser("admin_deleter_2", "admin_deleter_2@test.com", "password123", domain.UserRoleAdmin)
ctx := contextWithClaims(adminUser)
// Act
deleted, err := s.resolver.DeleteUser(ctx, "invalid-id")
// Assert
s.Require().Error(err)
s.False(deleted)
s.True(errors.Is(err, domain.ErrValidation))
})
s.Run("User not found", func() {
// Arrange
adminUser := createUser("admin_deleter_3", "admin_deleter_3@test.com", "password123", domain.UserRoleAdmin)
ctx := contextWithClaims(adminUser)
nonExistentID := "999999"
// Act
deleted, err := s.resolver.DeleteUser(ctx, nonExistentID)
// Assert
s.Require().Error(err)
s.False(deleted)
s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected ErrEntityNotFound for non-existent user")
})
}

View File

@ -14,15 +14,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// Common repository errors // Common repository errors are defined in the domain package.
var (
ErrEntityNotFound = errors.New("entity not found")
ErrInvalidID = errors.New("invalid ID: cannot be zero")
ErrInvalidInput = errors.New("invalid input parameters")
ErrDatabaseOperation = errors.New("database operation failed")
ErrContextRequired = errors.New("context is required")
ErrTransactionFailed = errors.New("transaction failed")
)
// BaseRepositoryImpl provides a default implementation of BaseRepository using GORM // BaseRepositoryImpl provides a default implementation of BaseRepository using GORM
type BaseRepositoryImpl[T any] struct { type BaseRepositoryImpl[T any] struct {
@ -43,7 +35,7 @@ func NewBaseRepositoryImpl[T any](db *gorm.DB, cfg *config.Config) *BaseReposito
// validateContext ensures context is not nil // validateContext ensures context is not nil
func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error {
if ctx == nil { if ctx == nil {
return ErrContextRequired return domain.ErrContextRequired
} }
return nil return nil
} }
@ -51,7 +43,7 @@ func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error {
// validateID ensures ID is valid // validateID ensures ID is valid
func (r *BaseRepositoryImpl[T]) validateID(id uint) error { func (r *BaseRepositoryImpl[T]) validateID(id uint) error {
if id == 0 { if id == 0 {
return ErrInvalidID return domain.ErrInvalidID
} }
return nil return nil
} }
@ -59,7 +51,7 @@ func (r *BaseRepositoryImpl[T]) validateID(id uint) error {
// validateEntity ensures entity is not nil // validateEntity ensures entity is not nil
func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error { func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error {
if entity == nil { if entity == nil {
return ErrInvalidInput return domain.ErrInvalidInput
} }
return nil return nil
} }
@ -133,7 +125,7 @@ func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error {
if err != nil { if err != nil {
log.Error(err, "Failed to create entity") log.Error(err, "Failed to create entity")
return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity created successfully in %s", duration)) log.Debug(fmt.Sprintf("Entity created successfully in %s", duration))
@ -151,7 +143,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent
return err return err
} }
if tx == nil { if tx == nil {
return ErrTransactionFailed return domain.ErrTransactionFailed
} }
start := time.Now() start := time.Now()
@ -160,7 +152,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent
if err != nil { if err != nil {
log.Error(err, "Failed to create entity in transaction") log.Error(err, "Failed to create entity in transaction")
return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration)) log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration))
@ -186,10 +178,10 @@ func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration))
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id)) log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id))
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration))
@ -216,10 +208,10 @@ func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint,
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration))
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id)) log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id))
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration))
@ -243,7 +235,7 @@ func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error {
if err != nil { if err != nil {
log.Error(err, "Failed to update entity") log.Error(err, "Failed to update entity")
return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration)) log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration))
@ -261,7 +253,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent
return err return err
} }
if tx == nil { if tx == nil {
return ErrTransactionFailed return domain.ErrTransactionFailed
} }
start := time.Now() start := time.Now()
@ -270,7 +262,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent
if err != nil { if err != nil {
log.Error(err, "Failed to update entity in transaction") log.Error(err, "Failed to update entity in transaction")
return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration)) log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration))
@ -295,12 +287,12 @@ func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error {
if result.Error != nil { if result.Error != nil {
log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id)) log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id))
return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error)
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration)) log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration))
return ErrEntityNotFound return domain.ErrEntityNotFound
} }
log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration))
@ -318,7 +310,7 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id
return err return err
} }
if tx == nil { if tx == nil {
return ErrTransactionFailed return domain.ErrTransactionFailed
} }
start := time.Now() start := time.Now()
@ -328,12 +320,12 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id
if result.Error != nil { if result.Error != nil {
log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id)) log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id))
return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error)
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration)) log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration))
return ErrEntityNotFound return domain.ErrEntityNotFound
} }
log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration)) log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration))
@ -360,7 +352,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (*
// Get total count // Get total count
if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil { if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil {
log.Error(err, "Failed to count entities") log.Error(err, "Failed to count entities")
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
// Calculate offset // Calculate offset
@ -369,7 +361,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (*
// Get paginated data // Get paginated data
if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil { if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil {
log.Error(err, "Failed to get paginated entities") log.Error(err, "Failed to get paginated entities")
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -410,7 +402,7 @@ func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *do
if err := query.Find(&entities).Error; err != nil { if err := query.Find(&entities).Error; err != nil {
log.Error(err, "Failed to get entities with options") log.Error(err, "Failed to get entities with options")
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -431,7 +423,7 @@ func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) {
var entities []T var entities []T
if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil { if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil {
log.Error(err, "Failed to get all entities") log.Error(err, "Failed to get all entities")
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -452,7 +444,7 @@ func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) {
var count int64 var count int64
if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil { if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil {
log.Error(err, "Failed to count entities") log.Error(err, "Failed to count entities")
return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -475,7 +467,7 @@ func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *d
if err := query.Model(new(T)).Count(&count).Error; err != nil { if err := query.Model(new(T)).Count(&count).Error; err != nil {
log.Error(err, "Failed to count entities with options") log.Error(err, "Failed to count entities with options")
return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -506,10 +498,10 @@ func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads []
if err := query.First(&entity, id).Error; err != nil { if err := query.First(&entity, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start))) log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start)))
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id)) log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id))
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -541,7 +533,7 @@ func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, of
var entities []T var entities []T
if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil { if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil {
log.Error(err, "Failed to get entities for sync") log.Error(err, "Failed to get entities for sync")
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -565,7 +557,7 @@ func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, erro
var count int64 var count int64
if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil { if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil {
log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id)) log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id))
return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return false, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
duration := time.Since(start) duration := time.Since(start)
@ -587,7 +579,7 @@ func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) {
tx := r.db.WithContext(ctx).Begin() tx := r.db.WithContext(ctx).Begin()
if tx.Error != nil { if tx.Error != nil {
log.Error(tx.Error, "Failed to begin transaction") log.Error(tx.Error, "Failed to begin transaction")
return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error) return nil, fmt.Errorf("%w: %v", domain.ErrTransactionFailed, tx.Error)
} }
log.Debug("Transaction started successfully") log.Debug("Transaction started successfully")
@ -625,9 +617,9 @@ func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB)
if err := tx.Commit().Error; err != nil { if err := tx.Commit().Error; err != nil {
log.Error(err, "Failed to commit transaction") log.Error(err, "Failed to commit transaction")
return fmt.Errorf("%w: %v", ErrTransactionFailed, err) return fmt.Errorf("%w: %v", domain.ErrTransactionFailed, err)
} }
log.Debug("Transaction committed successfully") log.Debug("Transaction committed successfully")
return nil return nil
} }

View File

@ -76,13 +76,13 @@ func (s *BaseRepositoryTestSuite) TestCreate() {
s.Run("should return error for nil entity", func() { s.Run("should return error for nil entity", func() {
err := s.repo.Create(context.Background(), nil) err := s.repo.Create(context.Background(), nil)
s.ErrorIs(err, sql.ErrInvalidInput) s.ErrorIs(err, domain.ErrInvalidInput)
}) })
s.Run("should return error for nil context", func() { s.Run("should return error for nil context", func() {
//nolint:staticcheck // Testing behavior with nil context is intentional here. //nolint:staticcheck // Testing behavior with nil context is intentional here.
err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"}) err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"})
s.ErrorIs(err, sql.ErrContextRequired) s.ErrorIs(err, domain.ErrContextRequired)
}) })
} }
@ -103,12 +103,12 @@ func (s *BaseRepositoryTestSuite) TestGetByID() {
s.Run("should return ErrEntityNotFound for non-existent ID", func() { s.Run("should return ErrEntityNotFound for non-existent ID", func() {
_, err := s.repo.GetByID(context.Background(), 99999) _, err := s.repo.GetByID(context.Background(), 99999)
s.ErrorIs(err, sql.ErrEntityNotFound) s.ErrorIs(err, domain.ErrEntityNotFound)
}) })
s.Run("should return ErrInvalidID for zero ID", func() { s.Run("should return ErrInvalidID for zero ID", func() {
_, err := s.repo.GetByID(context.Background(), 0) _, err := s.repo.GetByID(context.Background(), 0)
s.ErrorIs(err, sql.ErrInvalidID) s.ErrorIs(err, domain.ErrInvalidID)
}) })
} }
@ -140,12 +140,12 @@ func (s *BaseRepositoryTestSuite) TestDelete() {
// Assert // Assert
s.Require().NoError(err) s.Require().NoError(err)
_, getErr := s.repo.GetByID(context.Background(), created.ID) _, getErr := s.repo.GetByID(context.Background(), created.ID)
s.ErrorIs(getErr, sql.ErrEntityNotFound) s.ErrorIs(getErr, domain.ErrEntityNotFound)
}) })
s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() { s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() {
err := s.repo.Delete(context.Background(), 99999) err := s.repo.Delete(context.Background(), 99999)
s.ErrorIs(err, sql.ErrEntityNotFound) s.ErrorIs(err, domain.ErrEntityNotFound)
}) })
} }
@ -261,6 +261,6 @@ func (s *BaseRepositoryTestSuite) TestWithTx() {
s.ErrorIs(err, simulatedErr) s.ErrorIs(err, simulatedErr)
_, getErr := s.repo.GetByID(context.Background(), createdID) _, getErr := s.repo.GetByID(context.Background(), createdID)
s.ErrorIs(getErr, sql.ErrEntityNotFound, "Entity should not exist after rollback") s.ErrorIs(getErr, domain.ErrEntityNotFound, "Entity should not exist after rollback")
}) })
} }

View File

@ -70,7 +70,7 @@ func (r *bookRepository) FindByISBN(ctx context.Context, isbn string) (*domain.B
var book domain.Book var book domain.Book
if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&book).Error; err != nil { if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&book).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -33,7 +33,7 @@ func (r *categoryRepository) FindByName(ctx context.Context, name string) (*doma
var category domain.Category var category domain.Category
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&category).Error; err != nil { if err := r.db.WithContext(ctx).Where("name = ?", name).First(&category).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -50,7 +50,7 @@ func (r *copyrightRepository) GetTranslationByLanguage(ctx context.Context, copy
err := r.db.WithContext(ctx).Where("copyright_id = ? AND language_code = ?", copyrightID, languageCode).First(&translation).Error err := r.db.WithContext(ctx).Where("copyright_id = ? AND language_code = ?", copyrightID, languageCode).First(&translation).Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -127,7 +127,7 @@ func (s *CopyrightRepositoryTestSuite) TestGetTranslationByLanguage() {
_, err := s.repo.GetTranslationByLanguage(context.Background(), copyrightID, languageCode) _, err := s.repo.GetTranslationByLanguage(context.Background(), copyrightID, languageCode)
s.Require().Error(err) s.Require().Error(err)
s.Require().Equal(sql.ErrEntityNotFound, err) s.Require().ErrorIs(err, domain.ErrEntityNotFound)
}) })
} }

View File

@ -27,7 +27,7 @@ func (r *countryRepository) GetByCode(ctx context.Context, code string) (*domain
var country domain.Country var country domain.Country
if err := r.db.WithContext(ctx).Where("code = ?", code).First(&country).Error; err != nil { if err := r.db.WithContext(ctx).Where("code = ?", code).First(&country).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -44,7 +44,7 @@ func (r *editionRepository) FindByISBN(ctx context.Context, isbn string) (*domai
var edition domain.Edition var edition domain.Edition
if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&edition).Error; err != nil { if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&edition).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -34,7 +34,7 @@ func (r *emailVerificationRepository) GetByToken(ctx context.Context, token stri
var verification domain.EmailVerification var verification domain.EmailVerification
if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&verification).Error; err != nil { if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&verification).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -34,7 +34,7 @@ func (r *passwordResetRepository) GetByToken(ctx context.Context, token string)
var reset domain.PasswordReset var reset domain.PasswordReset
if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&reset).Error; err != nil { if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&reset).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -46,7 +46,7 @@ func (r *sourceRepository) FindByURL(ctx context.Context, url string) (*domain.S
var source domain.Source var source domain.Source
if err := r.db.WithContext(ctx).Where("url = ?", url).First(&source).Error; err != nil { if err := r.db.WithContext(ctx).Where("url = ?", url).First(&source).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -33,7 +33,7 @@ func (r *tagRepository) FindByName(ctx context.Context, name string) (*domain.Ta
var tag domain.Tag var tag domain.Tag
if err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error; err != nil { if err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -33,7 +33,7 @@ func (r *userProfileRepository) GetByUserID(ctx context.Context, userID uint) (*
var profile domain.UserProfile var profile domain.UserProfile
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&profile).Error; err != nil { if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&profile).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -33,7 +33,7 @@ func (r *userRepository) FindByUsername(ctx context.Context, username string) (*
var user domain.User var user domain.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil { if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }
@ -47,7 +47,7 @@ func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain
var user domain.User var user domain.User
if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil { if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }
@ -63,4 +63,4 @@ func (r *userRepository) ListByRole(ctx context.Context, role domain.UserRole) (
return nil, err return nil, err
} }
return users, nil return users, nil
} }

View File

@ -34,7 +34,7 @@ func (r *userSessionRepository) GetByToken(ctx context.Context, token string) (*
var session domain.UserSession var session domain.UserSession
if err := r.db.WithContext(ctx).Where("token = ?", token).First(&session).Error; err != nil { if err := r.db.WithContext(ctx).Where("token = ?", token).First(&session).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, err return nil, err
} }

View File

@ -185,9 +185,9 @@ func (r *workRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.D
} }
if err := query.First(&entity, id).Error; err != nil { if err := query.First(&entity, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrEntityNotFound return nil, domain.ErrEntityNotFound
} }
return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err)
} }
return &entity, nil return &entity, nil
} }