diff --git a/internal/adapters/graphql/user_mutations_test.go b/internal/adapters/graphql/user_mutations_test.go new file mode 100644 index 0000000..622f7c9 --- /dev/null +++ b/internal/adapters/graphql/user_mutations_test.go @@ -0,0 +1,155 @@ +package graphql_test + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + "tercul/internal/adapters/graphql" + "tercul/internal/app/auth" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type UserMutationTestSuite struct { + testutil.IntegrationTestSuite + resolver graphql.MutationResolver +} + +func TestUserMutations(t *testing.T) { + suite.Run(t, new(UserMutationTestSuite)) +} + +func (s *UserMutationTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{ + DBPath: "user_mutations_test.db", + }) +} + +func (s *UserMutationTestSuite) TearDownSuite() { + s.IntegrationTestSuite.TearDownSuite() + os.Remove("user_mutations_test.db") +} + +func (s *UserMutationTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + s.resolver = (&graphql.Resolver{App: s.App}).Mutation() +} + +func (s *UserMutationTestSuite) TestDeleteUser() { + // Helper to create a user for tests + createUser := func(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user + } + + // Helper to create a context with JWT claims + contextWithClaims := func(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) + } + + s.Run("Success as admin", func() { + // Arrange + adminUser := createUser("admin_deleter", "admin_deleter@test.com", "password123", domain.UserRoleAdmin) + userToDelete := createUser("user_to_delete", "user_to_delete@test.com", "password123", domain.UserRoleReader) + ctx := contextWithClaims(adminUser) + userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) + + // Act + deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr) + + // Assert + s.Require().NoError(err) + s.True(deleted) + + // Verify user is deleted from DB + _, err = s.App.User.Queries.User(context.Background(), userToDelete.ID) + s.Error(err) + s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion") + }) + + s.Run("Success as self", func() { + // Arrange + userToDelete := createUser("user_to_delete_self", "user_to_delete_self@test.com", "password123", domain.UserRoleReader) + ctx := contextWithClaims(userToDelete) + userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) + + // Act + deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr) + + // Assert + s.Require().NoError(err) + s.True(deleted) + + // Verify user is deleted from DB + _, err = s.App.User.Queries.User(context.Background(), userToDelete.ID) + s.Error(err) + s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion") + }) + + s.Run("Forbidden as other user", func() { + // Arrange + otherUser := createUser("other_user_deleter", "other_user_deleter@test.com", "password123", domain.UserRoleReader) + userToDelete := createUser("user_to_be_kept", "user_to_be_kept@test.com", "password123", domain.UserRoleReader) + ctx := contextWithClaims(otherUser) + userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) + + // Act + deleted, err := s.resolver.DeleteUser(ctx, userIDToDeleteStr) + + // Assert + s.Require().Error(err) + s.False(deleted) + s.True(errors.Is(err, domain.ErrForbidden)) + }) + + s.Run("Invalid user ID", func() { + // Arrange + adminUser := createUser("admin_deleter_2", "admin_deleter_2@test.com", "password123", domain.UserRoleAdmin) + ctx := contextWithClaims(adminUser) + + // Act + deleted, err := s.resolver.DeleteUser(ctx, "invalid-id") + + // Assert + s.Require().Error(err) + s.False(deleted) + s.True(errors.Is(err, domain.ErrValidation)) + }) + + s.Run("User not found", func() { + // Arrange + adminUser := createUser("admin_deleter_3", "admin_deleter_3@test.com", "password123", domain.UserRoleAdmin) + ctx := contextWithClaims(adminUser) + nonExistentID := "999999" + + // Act + deleted, err := s.resolver.DeleteUser(ctx, nonExistentID) + + // Assert + s.Require().Error(err) + s.False(deleted) + s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected ErrEntityNotFound for non-existent user") + }) +} \ No newline at end of file diff --git a/internal/data/sql/base_repository.go b/internal/data/sql/base_repository.go index e406bd2..cd0303a 100644 --- a/internal/data/sql/base_repository.go +++ b/internal/data/sql/base_repository.go @@ -14,15 +14,7 @@ import ( "gorm.io/gorm" ) -// Common repository errors -var ( - ErrEntityNotFound = errors.New("entity not found") - ErrInvalidID = errors.New("invalid ID: cannot be zero") - ErrInvalidInput = errors.New("invalid input parameters") - ErrDatabaseOperation = errors.New("database operation failed") - ErrContextRequired = errors.New("context is required") - ErrTransactionFailed = errors.New("transaction failed") -) +// Common repository errors are defined in the domain package. // BaseRepositoryImpl provides a default implementation of BaseRepository using GORM type BaseRepositoryImpl[T any] struct { @@ -43,7 +35,7 @@ func NewBaseRepositoryImpl[T any](db *gorm.DB, cfg *config.Config) *BaseReposito // validateContext ensures context is not nil func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { if ctx == nil { - return ErrContextRequired + return domain.ErrContextRequired } return nil } @@ -51,7 +43,7 @@ func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { // validateID ensures ID is valid func (r *BaseRepositoryImpl[T]) validateID(id uint) error { if id == 0 { - return ErrInvalidID + return domain.ErrInvalidID } return nil } @@ -59,7 +51,7 @@ func (r *BaseRepositoryImpl[T]) validateID(id uint) error { // validateEntity ensures entity is not nil func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error { if entity == nil { - return ErrInvalidInput + return domain.ErrInvalidInput } return nil } @@ -133,7 +125,7 @@ func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to create entity") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in %s", duration)) @@ -151,7 +143,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrTransactionFailed } start := time.Now() @@ -160,7 +152,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to create entity in transaction") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration)) @@ -186,10 +178,10 @@ func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration)) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration)) @@ -216,10 +208,10 @@ func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration)) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration)) @@ -243,7 +235,7 @@ func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to update entity") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration)) @@ -261,7 +253,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrTransactionFailed } start := time.Now() @@ -270,7 +262,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to update entity in transaction") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration)) @@ -295,12 +287,12 @@ func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error { if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id)) - return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration)) - return ErrEntityNotFound + return domain.ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration)) @@ -318,7 +310,7 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrTransactionFailed } start := time.Now() @@ -328,12 +320,12 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id)) - return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) + return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration)) - return ErrEntityNotFound + return domain.ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration)) @@ -360,7 +352,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get total count if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil { log.Error(err, "Failed to count entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } // Calculate offset @@ -369,7 +361,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get paginated data if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get paginated entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -410,7 +402,7 @@ func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *do if err := query.Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities with options") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -431,7 +423,7 @@ func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) { var entities []T if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil { log.Error(err, "Failed to get all entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -452,7 +444,7 @@ func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) { var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities") - return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -475,7 +467,7 @@ func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *d if err := query.Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities with options") - return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -506,10 +498,10 @@ func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads [] if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start))) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -541,7 +533,7 @@ func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, of var entities []T if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities for sync") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -565,7 +557,7 @@ func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, erro var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil { log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id)) - return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return false, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } duration := time.Since(start) @@ -587,7 +579,7 @@ func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) { tx := r.db.WithContext(ctx).Begin() if tx.Error != nil { log.Error(tx.Error, "Failed to begin transaction") - return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error) + return nil, fmt.Errorf("%w: %v", domain.ErrTransactionFailed, tx.Error) } log.Debug("Transaction started successfully") @@ -625,9 +617,9 @@ func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) if err := tx.Commit().Error; err != nil { log.Error(err, "Failed to commit transaction") - return fmt.Errorf("%w: %v", ErrTransactionFailed, err) + return fmt.Errorf("%w: %v", domain.ErrTransactionFailed, err) } log.Debug("Transaction committed successfully") return nil -} +} \ No newline at end of file diff --git a/internal/data/sql/base_repository_test.go b/internal/data/sql/base_repository_test.go index 56d0be8..dc0ab14 100644 --- a/internal/data/sql/base_repository_test.go +++ b/internal/data/sql/base_repository_test.go @@ -76,13 +76,13 @@ func (s *BaseRepositoryTestSuite) TestCreate() { s.Run("should return error for nil entity", func() { err := s.repo.Create(context.Background(), nil) - s.ErrorIs(err, sql.ErrInvalidInput) + s.ErrorIs(err, domain.ErrInvalidInput) }) s.Run("should return error for nil context", func() { //nolint:staticcheck // Testing behavior with nil context is intentional here. err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"}) - s.ErrorIs(err, sql.ErrContextRequired) + s.ErrorIs(err, domain.ErrContextRequired) }) } @@ -103,12 +103,12 @@ func (s *BaseRepositoryTestSuite) TestGetByID() { s.Run("should return ErrEntityNotFound for non-existent ID", func() { _, err := s.repo.GetByID(context.Background(), 99999) - s.ErrorIs(err, sql.ErrEntityNotFound) + s.ErrorIs(err, domain.ErrEntityNotFound) }) s.Run("should return ErrInvalidID for zero ID", func() { _, err := s.repo.GetByID(context.Background(), 0) - s.ErrorIs(err, sql.ErrInvalidID) + s.ErrorIs(err, domain.ErrInvalidID) }) } @@ -140,12 +140,12 @@ func (s *BaseRepositoryTestSuite) TestDelete() { // Assert s.Require().NoError(err) _, getErr := s.repo.GetByID(context.Background(), created.ID) - s.ErrorIs(getErr, sql.ErrEntityNotFound) + s.ErrorIs(getErr, domain.ErrEntityNotFound) }) s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() { err := s.repo.Delete(context.Background(), 99999) - s.ErrorIs(err, sql.ErrEntityNotFound) + s.ErrorIs(err, domain.ErrEntityNotFound) }) } @@ -261,6 +261,6 @@ func (s *BaseRepositoryTestSuite) TestWithTx() { s.ErrorIs(err, simulatedErr) _, getErr := s.repo.GetByID(context.Background(), createdID) - s.ErrorIs(getErr, sql.ErrEntityNotFound, "Entity should not exist after rollback") + s.ErrorIs(getErr, domain.ErrEntityNotFound, "Entity should not exist after rollback") }) } \ No newline at end of file diff --git a/internal/data/sql/book_repository.go b/internal/data/sql/book_repository.go index 5b17765..c838239 100644 --- a/internal/data/sql/book_repository.go +++ b/internal/data/sql/book_repository.go @@ -70,7 +70,7 @@ func (r *bookRepository) FindByISBN(ctx context.Context, isbn string) (*domain.B var book domain.Book if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&book).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/category_repository.go b/internal/data/sql/category_repository.go index c656dd0..f9f0262 100644 --- a/internal/data/sql/category_repository.go +++ b/internal/data/sql/category_repository.go @@ -33,7 +33,7 @@ func (r *categoryRepository) FindByName(ctx context.Context, name string) (*doma var category domain.Category if err := r.db.WithContext(ctx).Where("name = ?", name).First(&category).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/copyright_repository.go b/internal/data/sql/copyright_repository.go index cd4a301..4b063e3 100644 --- a/internal/data/sql/copyright_repository.go +++ b/internal/data/sql/copyright_repository.go @@ -50,7 +50,7 @@ func (r *copyrightRepository) GetTranslationByLanguage(ctx context.Context, copy err := r.db.WithContext(ctx).Where("copyright_id = ? AND language_code = ?", copyrightID, languageCode).First(&translation).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/copyright_repository_test.go b/internal/data/sql/copyright_repository_test.go index 78468b7..b30dce4 100644 --- a/internal/data/sql/copyright_repository_test.go +++ b/internal/data/sql/copyright_repository_test.go @@ -127,7 +127,7 @@ func (s *CopyrightRepositoryTestSuite) TestGetTranslationByLanguage() { _, err := s.repo.GetTranslationByLanguage(context.Background(), copyrightID, languageCode) s.Require().Error(err) - s.Require().Equal(sql.ErrEntityNotFound, err) + s.Require().ErrorIs(err, domain.ErrEntityNotFound) }) } diff --git a/internal/data/sql/country_repository.go b/internal/data/sql/country_repository.go index 0c12e6d..48c2be6 100644 --- a/internal/data/sql/country_repository.go +++ b/internal/data/sql/country_repository.go @@ -27,7 +27,7 @@ func (r *countryRepository) GetByCode(ctx context.Context, code string) (*domain var country domain.Country if err := r.db.WithContext(ctx).Where("code = ?", code).First(&country).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/edition_repository.go b/internal/data/sql/edition_repository.go index 57e28bc..2093706 100644 --- a/internal/data/sql/edition_repository.go +++ b/internal/data/sql/edition_repository.go @@ -44,7 +44,7 @@ func (r *editionRepository) FindByISBN(ctx context.Context, isbn string) (*domai var edition domain.Edition if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&edition).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/email_verification_repository.go b/internal/data/sql/email_verification_repository.go index 31d8326..511d470 100644 --- a/internal/data/sql/email_verification_repository.go +++ b/internal/data/sql/email_verification_repository.go @@ -34,7 +34,7 @@ func (r *emailVerificationRepository) GetByToken(ctx context.Context, token stri var verification domain.EmailVerification if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&verification).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/password_reset_repository.go b/internal/data/sql/password_reset_repository.go index 6a81174..32c4d37 100644 --- a/internal/data/sql/password_reset_repository.go +++ b/internal/data/sql/password_reset_repository.go @@ -34,7 +34,7 @@ func (r *passwordResetRepository) GetByToken(ctx context.Context, token string) var reset domain.PasswordReset if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&reset).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/source_repository.go b/internal/data/sql/source_repository.go index 4702c8c..9adabdf 100644 --- a/internal/data/sql/source_repository.go +++ b/internal/data/sql/source_repository.go @@ -46,7 +46,7 @@ func (r *sourceRepository) FindByURL(ctx context.Context, url string) (*domain.S var source domain.Source if err := r.db.WithContext(ctx).Where("url = ?", url).First(&source).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/tag_repository.go b/internal/data/sql/tag_repository.go index f61e975..9dc32fe 100644 --- a/internal/data/sql/tag_repository.go +++ b/internal/data/sql/tag_repository.go @@ -33,7 +33,7 @@ func (r *tagRepository) FindByName(ctx context.Context, name string) (*domain.Ta var tag domain.Tag if err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/user_profile_repository.go b/internal/data/sql/user_profile_repository.go index 351adeb..f1ad4bb 100644 --- a/internal/data/sql/user_profile_repository.go +++ b/internal/data/sql/user_profile_repository.go @@ -33,7 +33,7 @@ func (r *userProfileRepository) GetByUserID(ctx context.Context, userID uint) (* var profile domain.UserProfile if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&profile).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/user_repository.go b/internal/data/sql/user_repository.go index a6bed79..53acfea 100644 --- a/internal/data/sql/user_repository.go +++ b/internal/data/sql/user_repository.go @@ -33,7 +33,7 @@ func (r *userRepository) FindByUsername(ctx context.Context, username string) (* var user domain.User if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -47,7 +47,7 @@ func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain var user domain.User if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -63,4 +63,4 @@ func (r *userRepository) ListByRole(ctx context.Context, role domain.UserRole) ( return nil, err } return users, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/user_session_repository.go b/internal/data/sql/user_session_repository.go index a431822..0bd74d3 100644 --- a/internal/data/sql/user_session_repository.go +++ b/internal/data/sql/user_session_repository.go @@ -34,7 +34,7 @@ func (r *userSessionRepository) GetByToken(ctx context.Context, token string) (* var session domain.UserSession if err := r.db.WithContext(ctx).Where("token = ?", token).First(&session).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/work_repository.go b/internal/data/sql/work_repository.go index 2eaa1b3..4364746 100644 --- a/internal/data/sql/work_repository.go +++ b/internal/data/sql/work_repository.go @@ -185,9 +185,9 @@ func (r *workRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.D } if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) } return &entity, nil }