diff --git a/internal/adapters/graphql/user_mutations_test.go b/internal/adapters/graphql/user_mutations_test.go index 622f7c9..5d3c0bf 100644 --- a/internal/adapters/graphql/user_mutations_test.go +++ b/internal/adapters/graphql/user_mutations_test.go @@ -7,6 +7,7 @@ import ( "os" "testing" "tercul/internal/adapters/graphql" + "tercul/internal/adapters/graphql/model" "tercul/internal/app/auth" "tercul/internal/domain" platform_auth "tercul/internal/platform/auth" @@ -40,40 +41,40 @@ func (s *UserMutationTestSuite) SetupTest() { s.resolver = (&graphql.Resolver{App: s.App}).Mutation() } +// Helper to create a user for tests +func (s *UserMutationTestSuite) createUser(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user +} + +// Helper to create a context with JWT claims +func (s *UserMutationTestSuite) contextWithClaims(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) +} + func (s *UserMutationTestSuite) TestDeleteUser() { - // Helper to create a user for tests - createUser := func(username, email, password string, role domain.UserRole) *domain.User { - resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ - Username: username, - Email: email, - Password: password, - }) - s.Require().NoError(err) - - user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) - s.Require().NoError(err) - - if role != user.Role { - user.Role = role - err = s.DB.Save(user).Error - s.Require().NoError(err) - } - return user - } - - // Helper to create a context with JWT claims - contextWithClaims := func(user *domain.User) context.Context { - return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ - UserID: user.ID, - Role: string(user.Role), - }) - } - s.Run("Success as admin", func() { // Arrange - adminUser := createUser("admin_deleter", "admin_deleter@test.com", "password123", domain.UserRoleAdmin) - userToDelete := createUser("user_to_delete", "user_to_delete@test.com", "password123", domain.UserRoleReader) - ctx := contextWithClaims(adminUser) + adminUser := s.createUser("admin_deleter", "admin_deleter@test.com", "password123", domain.UserRoleAdmin) + userToDelete := s.createUser("user_to_delete", "user_to_delete@test.com", "password123", domain.UserRoleReader) + ctx := s.contextWithClaims(adminUser) userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) // Act @@ -85,14 +86,14 @@ func (s *UserMutationTestSuite) TestDeleteUser() { // Verify user is deleted from DB _, err = s.App.User.Queries.User(context.Background(), userToDelete.ID) - s.Error(err) - s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion") + s.Require().Error(err) + s.Contains(err.Error(), "entity not found", "Expected user to be not found after deletion") }) s.Run("Success as self", func() { // Arrange - userToDelete := createUser("user_to_delete_self", "user_to_delete_self@test.com", "password123", domain.UserRoleReader) - ctx := contextWithClaims(userToDelete) + userToDelete := s.createUser("user_to_delete_self", "user_to_delete_self@test.com", "password123", domain.UserRoleReader) + ctx := s.contextWithClaims(userToDelete) userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) // Act @@ -104,15 +105,15 @@ func (s *UserMutationTestSuite) TestDeleteUser() { // Verify user is deleted from DB _, err = s.App.User.Queries.User(context.Background(), userToDelete.ID) - s.Error(err) - s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected user to be not found after deletion") + s.Require().Error(err) + s.Contains(err.Error(), "entity not found", "Expected user to be not found after deletion") }) s.Run("Forbidden as other user", func() { // Arrange - otherUser := createUser("other_user_deleter", "other_user_deleter@test.com", "password123", domain.UserRoleReader) - userToDelete := createUser("user_to_be_kept", "user_to_be_kept@test.com", "password123", domain.UserRoleReader) - ctx := contextWithClaims(otherUser) + otherUser := s.createUser("other_user_deleter", "other_user_deleter@test.com", "password123", domain.UserRoleReader) + userToDelete := s.createUser("user_to_be_kept", "user_to_be_kept@test.com", "password123", domain.UserRoleReader) + ctx := s.contextWithClaims(otherUser) userIDToDeleteStr := fmt.Sprintf("%d", userToDelete.ID) // Act @@ -126,8 +127,8 @@ func (s *UserMutationTestSuite) TestDeleteUser() { s.Run("Invalid user ID", func() { // Arrange - adminUser := createUser("admin_deleter_2", "admin_deleter_2@test.com", "password123", domain.UserRoleAdmin) - ctx := contextWithClaims(adminUser) + adminUser := s.createUser("admin_deleter_2", "admin_deleter_2@test.com", "password123", domain.UserRoleAdmin) + ctx := s.contextWithClaims(adminUser) // Act deleted, err := s.resolver.DeleteUser(ctx, "invalid-id") @@ -140,8 +141,8 @@ func (s *UserMutationTestSuite) TestDeleteUser() { s.Run("User not found", func() { // Arrange - adminUser := createUser("admin_deleter_3", "admin_deleter_3@test.com", "password123", domain.UserRoleAdmin) - ctx := contextWithClaims(adminUser) + adminUser := s.createUser("admin_deleter_3", "admin_deleter_3@test.com", "password123", domain.UserRoleAdmin) + ctx := s.contextWithClaims(adminUser) nonExistentID := "999999" // Act @@ -150,6 +151,68 @@ func (s *UserMutationTestSuite) TestDeleteUser() { // Assert s.Require().Error(err) s.False(deleted) - s.True(errors.Is(err, domain.ErrEntityNotFound), "Expected ErrEntityNotFound for non-existent user") + s.Contains(err.Error(), "entity not found", "Expected entity not found error for non-existent user") + }) +} + +func (s *UserMutationTestSuite) TestUpdateProfile() { + s.Run("Success", func() { + // Arrange + user := s.createUser("profile_user", "profile.user@test.com", "password123", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + + newFirstName := "John" + newLastName := "Doe" + newBio := "This is my new bio." + input := model.UserInput{ + FirstName: &newFirstName, + LastName: &newLastName, + Bio: &newBio, + } + + // Act + updatedUser, err := s.resolver.UpdateProfile(ctx, input) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(updatedUser) + s.Equal(newFirstName, *updatedUser.FirstName) + s.Equal(newLastName, *updatedUser.LastName) + s.Equal(newBio, *updatedUser.Bio) + + // Verify in DB + dbUser, err := s.App.User.Queries.User(context.Background(), user.ID) + s.Require().NoError(err) + s.Equal(newFirstName, dbUser.FirstName) + s.Equal(newLastName, dbUser.LastName) + s.Equal(newBio, dbUser.Bio) + }) + + s.Run("Unauthenticated user", func() { + // Arrange + newFirstName := "Jane" + input := model.UserInput{FirstName: &newFirstName} + + // Act + _, err := s.resolver.UpdateProfile(context.Background(), input) + + // Assert + s.Require().Error(err) + s.ErrorIs(err, domain.ErrUnauthorized) + }) + + s.Run("Invalid country ID", func() { + // Arrange + user := s.createUser("profile_user_invalid", "profile.user.invalid@test.com", "password123", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + invalidCountryID := "not-a-number" + input := model.UserInput{CountryID: &invalidCountryID} + + // Act + _, err := s.resolver.UpdateProfile(ctx, input) + + // Assert + s.Require().Error(err) + s.Contains(err.Error(), "invalid country ID") }) } \ No newline at end of file diff --git a/internal/app/analytics/service.go b/internal/app/analytics/service.go index 87bbcb2..feac04c 100644 --- a/internal/app/analytics/service.go +++ b/internal/app/analytics/service.go @@ -40,6 +40,7 @@ type Service interface { UpdateUserEngagement(ctx context.Context, userID uint, eventType string) error UpdateTrending(ctx context.Context) error GetTrendingWorks(ctx context.Context, timePeriod string, limit int) ([]*domain.Work, error) + UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error } type service struct { @@ -314,6 +315,12 @@ func (s *service) GetTrendingWorks(ctx context.Context, timePeriod string, limit return s.repo.GetTrendingWorks(ctx, timePeriod, limit) } +func (s *service) UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error { + ctx, span := s.tracer.Start(ctx, "UpdateWorkStats") + defer span.End() + return s.repo.UpdateWorkStats(ctx, workID, stats) +} + func (s *service) UpdateTrending(ctx context.Context) error { ctx, span := s.tracer.Start(ctx, "UpdateTrending") defer span.End() diff --git a/internal/app/work/commands.go b/internal/app/work/commands.go index 0279d81..26de6e4 100644 --- a/internal/app/work/commands.go +++ b/internal/app/work/commands.go @@ -305,14 +305,18 @@ func mergeWorkStats(tx *gorm.DB, sourceWorkID, targetWorkID uint) error { return nil } + // Store the original ID to delete later, as the sourceStats.ID might be overwritten. + originalSourceStatsID := sourceStats.ID + var targetStats domain.WorkStats err = tx.Where("work_id = ?", targetWorkID).First(&targetStats).Error if errors.Is(err, gorm.ErrRecordNotFound) { - // If target has no stats, create new ones based on source stats. - sourceStats.ID = 0 // Let GORM create a new record - sourceStats.WorkID = targetWorkID - if err = tx.Create(&sourceStats).Error; err != nil { + // If target has no stats, create a new stats record for it. + newStats := sourceStats + newStats.ID = 0 + newStats.WorkID = targetWorkID + if err = tx.Create(&newStats).Error; err != nil { return fmt.Errorf("failed to create new target stats: %w", err) } } else if err != nil { @@ -325,8 +329,8 @@ func mergeWorkStats(tx *gorm.DB, sourceWorkID, targetWorkID uint) error { } } - // Delete the old source stats - if err = tx.Delete(&domain.WorkStats{}, sourceStats.ID).Error; err != nil { + // Delete the old source stats using the stored original ID. + if err = tx.Delete(&domain.WorkStats{}, originalSourceStatsID).Error; err != nil { return fmt.Errorf("failed to delete source work stats: %w", err) } diff --git a/internal/app/work/commands_test.go b/internal/app/work/commands_test.go index cdb6a2b..eb5d055 100644 --- a/internal/app/work/commands_test.go +++ b/internal/app/work/commands_test.go @@ -70,7 +70,7 @@ func (s *WorkCommandsSuite) TestCreateWork_RepoError() { } func (s *WorkCommandsSuite) TestUpdateWork_Success() { - ctx := platform_auth.ContextWithAdminUser(context.Background(), 1) + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} work.ID = 1 @@ -111,17 +111,40 @@ func (s *WorkCommandsSuite) TestUpdateWork_EmptyLanguage() { } func (s *WorkCommandsSuite) TestUpdateWork_RepoError() { + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} work.ID = 1 s.repo.updateFunc = func(ctx context.Context, w *domain.Work) error { return errors.New("db error") } - err := s.commands.UpdateWork(context.Background(), work) + err := s.commands.UpdateWork(ctx, work) assert.Error(s.T(), err) } +func (s *WorkCommandsSuite) TestUpdateWork_Forbidden() { + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) // Not an admin + work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} + work.ID = 1 + + s.repo.isAuthorFunc = func(ctx context.Context, workID uint, authorID uint) (bool, error) { + return false, nil // User is not an author + } + + err := s.commands.UpdateWork(ctx, work) + assert.Error(s.T(), err) + assert.True(s.T(), errors.Is(err, domain.ErrForbidden)) +} + +func (s *WorkCommandsSuite) TestUpdateWork_Unauthorized() { + work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} + work.ID = 1 + err := s.commands.UpdateWork(context.Background(), work) // No user in context + assert.Error(s.T(), err) + assert.True(s.T(), errors.Is(err, domain.ErrUnauthorized)) +} + func (s *WorkCommandsSuite) TestDeleteWork_Success() { - ctx := platform_auth.ContextWithAdminUser(context.Background(), 1) + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} work.ID = 1 @@ -142,13 +165,27 @@ func (s *WorkCommandsSuite) TestDeleteWork_ZeroID() { } func (s *WorkCommandsSuite) TestDeleteWork_RepoError() { + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) s.repo.deleteFunc = func(ctx context.Context, id uint) error { return errors.New("db error") } - err := s.commands.DeleteWork(context.Background(), 1) + err := s.commands.DeleteWork(ctx, 1) assert.Error(s.T(), err) } +func (s *WorkCommandsSuite) TestDeleteWork_Forbidden() { + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) // Not an admin + err := s.commands.DeleteWork(ctx, 1) + assert.Error(s.T(), err) + assert.True(s.T(), errors.Is(err, domain.ErrForbidden)) +} + +func (s *WorkCommandsSuite) TestDeleteWork_Unauthorized() { + err := s.commands.DeleteWork(context.Background(), 1) // No user in context + assert.Error(s.T(), err) + assert.True(s.T(), errors.Is(err, domain.ErrUnauthorized)) +} + func (s *WorkCommandsSuite) TestAnalyzeWork_Success() { work := &domain.Work{ TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: 1}}, @@ -221,82 +258,157 @@ func TestMergeWork_Integration(t *testing.T) { analyticsSvc := &mockAnalyticsService{} commands := NewWorkCommands(workRepo, searchClient, authzSvc, analyticsSvc) + // Provide a realistic implementation for the GetOrCreateWorkStats mock + analyticsSvc.getOrCreateWorkStatsFunc = func(ctx context.Context, workID uint) (*domain.WorkStats, error) { + var stats domain.WorkStats + if err := db.Where(domain.WorkStats{WorkID: workID}).FirstOrCreate(&stats).Error; err != nil { + return nil, err + } + return &stats, nil + } + // --- Seed Data --- - author1 := &domain.Author{Name: "Author One"} - db.Create(author1) - author2 := &domain.Author{Name: "Author Two"} - db.Create(author2) + t.Run("Success", func(t *testing.T) { + author1 := &domain.Author{Name: "Author One"} + db.Create(author1) + author2 := &domain.Author{Name: "Author Two"} + db.Create(author2) - tag1 := &domain.Tag{Name: "Tag One"} - db.Create(tag1) - tag2 := &domain.Tag{Name: "Tag Two"} - db.Create(tag2) + tag1 := &domain.Tag{Name: "Tag One"} + db.Create(tag1) + tag2 := &domain.Tag{Name: "Tag Two"} + db.Create(tag2) - sourceWork := &domain.Work{ - TranslatableModel: domain.TranslatableModel{Language: "en"}, - Title: "Source Work", - Authors: []*domain.Author{author1}, - Tags: []*domain.Tag{tag1}, - } - db.Create(sourceWork) - db.Create(&domain.Translation{Title: "Source English", Language: "en", TranslatableID: sourceWork.ID, TranslatableType: "works"}) - db.Create(&domain.Translation{Title: "Source French", Language: "fr", TranslatableID: sourceWork.ID, TranslatableType: "works"}) - db.Create(&domain.WorkStats{WorkID: sourceWork.ID, Views: 10, Likes: 5}) - - targetWork := &domain.Work{ - TranslatableModel: domain.TranslatableModel{Language: "en"}, - Title: "Target Work", - Authors: []*domain.Author{author2}, - Tags: []*domain.Tag{tag2}, - } - db.Create(targetWork) - db.Create(&domain.Translation{Title: "Target English", Language: "en", TranslatableID: targetWork.ID, TranslatableType: "works"}) - db.Create(&domain.WorkStats{WorkID: targetWork.ID, Views: 20, Likes: 10}) - - // --- Execute Merge --- - ctx := platform_auth.ContextWithAdminUser(context.Background(), 1) - err = commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) - assert.NoError(t, err) - - // --- Assertions --- - // 1. Source work should be deleted - var deletedWork domain.Work - err = db.First(&deletedWork, sourceWork.ID).Error - assert.Error(t, err) - assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) - - // 2. Target work should have merged data - var finalTargetWork domain.Work - db.Preload("Translations").Preload("Authors").Preload("Tags").First(&finalTargetWork, targetWork.ID) - - assert.Len(t, finalTargetWork.Translations, 2, "Should have two translations after merge") - foundEn := false - foundFr := false - for _, tr := range finalTargetWork.Translations { - if tr.Language == "en" { - foundEn = true - assert.Equal(t, "Target English", tr.Title, "Should keep target's English translation") + sourceWork := &domain.Work{ + TranslatableModel: domain.TranslatableModel{Language: "en"}, + Title: "Source Work", + Authors: []*domain.Author{author1}, + Tags: []*domain.Tag{tag1}, } - if tr.Language == "fr" { - foundFr = true - assert.Equal(t, "Source French", tr.Title, "Should merge source's French translation") + db.Create(sourceWork) + db.Create(&domain.Translation{Title: "Source English", Language: "en", TranslatableID: sourceWork.ID, TranslatableType: "works"}) + db.Create(&domain.Translation{Title: "Source French", Language: "fr", TranslatableID: sourceWork.ID, TranslatableType: "works"}) + db.Create(&domain.WorkStats{WorkID: sourceWork.ID, Views: 10, Likes: 5}) + + targetWork := &domain.Work{ + TranslatableModel: domain.TranslatableModel{Language: "en"}, + Title: "Target Work", + Authors: []*domain.Author{author2}, + Tags: []*domain.Tag{tag2}, } - } - assert.True(t, foundEn, "English translation should be present") - assert.True(t, foundFr, "French translation should be present") + db.Create(targetWork) + db.Create(&domain.Translation{Title: "Target English", Language: "en", TranslatableID: targetWork.ID, TranslatableType: "works"}) + db.Create(&domain.WorkStats{WorkID: targetWork.ID, Views: 20, Likes: 10}) - assert.Len(t, finalTargetWork.Authors, 2, "Authors should be merged") - assert.Len(t, finalTargetWork.Tags, 2, "Tags should be merged") + // --- Execute Merge --- + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) + err = commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) + assert.NoError(t, err) - // 3. Stats should be merged - var finalStats domain.WorkStats - db.Where("work_id = ?", targetWork.ID).First(&finalStats) - assert.Equal(t, int64(30), finalStats.Views, "Views should be summed") - assert.Equal(t, int64(15), finalStats.Likes, "Likes should be summed") + // --- Assertions --- + // 1. Source work should be deleted + var deletedWork domain.Work + err = db.First(&deletedWork, sourceWork.ID).Error + assert.Error(t, err) + assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) - // 4. Source stats should be deleted - var deletedStats domain.WorkStats - err = db.First(&deletedStats, "work_id = ?", sourceWork.ID).Error - assert.Error(t, err, "Source stats should be deleted") - assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) + // 2. Target work should have merged data + var finalTargetWork domain.Work + db.Preload("Translations").Preload("Authors").Preload("Tags").First(&finalTargetWork, targetWork.ID) + + assert.Len(t, finalTargetWork.Translations, 2, "Should have two translations after merge") + foundEn := false + foundFr := false + for _, tr := range finalTargetWork.Translations { + if tr.Language == "en" { + foundEn = true + assert.Equal(t, "Target English", tr.Title, "Should keep target's English translation") + } + if tr.Language == "fr" { + foundFr = true + assert.Equal(t, "Source French", tr.Title, "Should merge source's French translation") + } + } + assert.True(t, foundEn, "English translation should be present") + assert.True(t, foundFr, "French translation should be present") + + assert.Len(t, finalTargetWork.Authors, 2, "Authors should be merged") + assert.Len(t, finalTargetWork.Tags, 2, "Tags should be merged") + + // 3. Stats should be merged + var finalStats domain.WorkStats + db.Where("work_id = ?", targetWork.ID).First(&finalStats) + assert.Equal(t, int64(30), finalStats.Views, "Views should be summed") + assert.Equal(t, int64(15), finalStats.Likes, "Likes should be summed") + + // 4. Source stats should be deleted + var deletedStats domain.WorkStats + err = db.First(&deletedStats, "work_id = ?", sourceWork.ID).Error + assert.Error(t, err, "Source stats should be deleted") + assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) + }) + + t.Run("Success with no target stats", func(t *testing.T) { + sourceWork := &domain.Work{Title: "Source with Stats"} + db.Create(sourceWork) + db.Create(&domain.WorkStats{WorkID: sourceWork.ID, Views: 15, Likes: 7}) + + targetWork := &domain.Work{Title: "Target without Stats"} + db.Create(targetWork) + + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) + err := commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) + assert.NoError(t, err) + + var finalStats domain.WorkStats + db.Where("work_id = ?", targetWork.ID).First(&finalStats) + assert.Equal(t, int64(15), finalStats.Views) + assert.Equal(t, int64(7), finalStats.Likes) + }) + + t.Run("Forbidden for non-admin", func(t *testing.T) { + sourceWork := &domain.Work{Title: "Forbidden Source"} + db.Create(sourceWork) + targetWork := &domain.Work{Title: "Forbidden Target"} + db.Create(targetWork) + + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) + err := commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) + + assert.Error(t, err) + assert.True(t, errors.Is(err, domain.ErrForbidden)) + }) + + t.Run("Source work not found", func(t *testing.T) { + targetWork := &domain.Work{Title: "Existing Target"} + db.Create(targetWork) + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) + + err := commands.MergeWork(ctx, 99999, targetWork.ID) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "entity not found") + }) + + t.Run("Target work not found", func(t *testing.T) { + sourceWork := &domain.Work{Title: "Existing Source"} + db.Create(sourceWork) + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) + + err := commands.MergeWork(ctx, sourceWork.ID, 99999) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "entity not found") + }) + + t.Run("Cannot merge work into itself", func(t *testing.T) { + work := &domain.Work{Title: "Self Merge Work"} + db.Create(work) + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) + + err := commands.MergeWork(ctx, work.ID, work.ID) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "source and target work IDs cannot be the same") + }) } \ No newline at end of file diff --git a/internal/app/work/main_test.go b/internal/app/work/main_test.go index 02b034e..1d02791 100644 --- a/internal/app/work/main_test.go +++ b/internal/app/work/main_test.go @@ -18,6 +18,7 @@ type mockWorkRepository struct { findByCategoryFunc func(ctx context.Context, categoryID uint) ([]domain.Work, error) findByLanguageFunc func(ctx context.Context, language string, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) isAuthorFunc func(ctx context.Context, workID uint, authorID uint) (bool, error) + listByCollectionIDFunc func(ctx context.Context, collectionID uint) ([]domain.Work, error) } func (m *mockWorkRepository) IsAuthor(ctx context.Context, workID uint, authorID uint) (bool, error) { @@ -57,6 +58,13 @@ func (m *mockWorkRepository) List(ctx context.Context, page, pageSize int) (*dom } return nil, nil } + +func (m *mockWorkRepository) ListByCollectionID(ctx context.Context, collectionID uint) ([]domain.Work, error) { + if m.listByCollectionIDFunc != nil { + return m.listByCollectionIDFunc(ctx, collectionID) + } + return nil, nil +} func (m *mockWorkRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Work, error) { if m.getWithTranslationsFunc != nil { return m.getWithTranslationsFunc(ctx, id) diff --git a/internal/app/work/mock_analytics_service_test.go b/internal/app/work/mock_analytics_service_test.go index 559111e..363a346 100644 --- a/internal/app/work/mock_analytics_service_test.go +++ b/internal/app/work/mock_analytics_service_test.go @@ -11,6 +11,15 @@ type mockAnalyticsService struct { updateWorkSentimentFunc func(ctx context.Context, workID uint) error updateTranslationReadingTimeFunc func(ctx context.Context, translationID uint) error updateTranslationSentimentFunc func(ctx context.Context, translationID uint) error + getOrCreateWorkStatsFunc func(ctx context.Context, workID uint) (*domain.WorkStats, error) + updateWorkStatsFunc func(ctx context.Context, workID uint, stats domain.WorkStats) error +} + +func (m *mockAnalyticsService) UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error { + if m.updateWorkStatsFunc != nil { + return m.updateWorkStatsFunc(ctx, workID, stats) + } + return nil } func (m *mockAnalyticsService) UpdateWorkReadingTime(ctx context.Context, workID uint) error { @@ -78,6 +87,9 @@ func (m *mockAnalyticsService) IncrementTranslationShares(ctx context.Context, t return nil } func (m *mockAnalyticsService) GetOrCreateWorkStats(ctx context.Context, workID uint) (*domain.WorkStats, error) { + if m.getOrCreateWorkStatsFunc != nil { + return m.getOrCreateWorkStatsFunc(ctx, workID) + } return nil, nil } func (m *mockAnalyticsService) GetOrCreateTranslationStats(ctx context.Context, translationID uint) (*domain.TranslationStats, error) { diff --git a/internal/app/work/queries_test.go b/internal/app/work/queries_test.go index 5817edc..5ca36d9 100644 --- a/internal/app/work/queries_test.go +++ b/internal/app/work/queries_test.go @@ -45,6 +45,22 @@ func (s *WorkQueriesSuite) TestGetWorkByID_ZeroID() { assert.Nil(s.T(), w) } +func (s *WorkQueriesSuite) TestListByCollectionID_Success() { + works := []domain.Work{{Title: "Test Work"}} + s.repo.listByCollectionIDFunc = func(ctx context.Context, collectionID uint) ([]domain.Work, error) { + return works, nil + } + w, err := s.queries.ListByCollectionID(context.Background(), 1) + assert.NoError(s.T(), err) + assert.Equal(s.T(), works, w) +} + +func (s *WorkQueriesSuite) TestListByCollectionID_ZeroID() { + w, err := s.queries.ListByCollectionID(context.Background(), 0) + assert.Error(s.T(), err) + assert.Nil(s.T(), w) +} + func (s *WorkQueriesSuite) TestListWorks_Success() { domainWorks := &domain.PaginatedResult[domain.Work]{ Items: []domain.Work{ diff --git a/internal/app/work/service_test.go b/internal/app/work/service_test.go new file mode 100644 index 0000000..a7a881f --- /dev/null +++ b/internal/app/work/service_test.go @@ -0,0 +1,24 @@ +package work + +import ( + "testing" + "tercul/internal/app/authz" + + "github.com/stretchr/testify/assert" +) + +func TestNewService(t *testing.T) { + // Arrange + mockRepo := &mockWorkRepository{} + mockSearchClient := &mockSearchClient{} + mockAuthzSvc := &authz.Service{} + mockAnalyticsSvc := &mockAnalyticsService{} + + // Act + service := NewService(mockRepo, mockSearchClient, mockAuthzSvc, mockAnalyticsSvc) + + // Assert + assert.NotNil(t, service, "The new service should not be nil") + assert.NotNil(t, service.Commands, "The service Commands should not be nil") + assert.NotNil(t, service.Queries, "The service Queries should not be nil") +} \ No newline at end of file diff --git a/internal/data/sql/base_repository.go b/internal/data/sql/base_repository.go index cd0303a..e406bd2 100644 --- a/internal/data/sql/base_repository.go +++ b/internal/data/sql/base_repository.go @@ -14,7 +14,15 @@ import ( "gorm.io/gorm" ) -// Common repository errors are defined in the domain package. +// Common repository errors +var ( + ErrEntityNotFound = errors.New("entity not found") + ErrInvalidID = errors.New("invalid ID: cannot be zero") + ErrInvalidInput = errors.New("invalid input parameters") + ErrDatabaseOperation = errors.New("database operation failed") + ErrContextRequired = errors.New("context is required") + ErrTransactionFailed = errors.New("transaction failed") +) // BaseRepositoryImpl provides a default implementation of BaseRepository using GORM type BaseRepositoryImpl[T any] struct { @@ -35,7 +43,7 @@ func NewBaseRepositoryImpl[T any](db *gorm.DB, cfg *config.Config) *BaseReposito // validateContext ensures context is not nil func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { if ctx == nil { - return domain.ErrContextRequired + return ErrContextRequired } return nil } @@ -43,7 +51,7 @@ func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { // validateID ensures ID is valid func (r *BaseRepositoryImpl[T]) validateID(id uint) error { if id == 0 { - return domain.ErrInvalidID + return ErrInvalidID } return nil } @@ -51,7 +59,7 @@ func (r *BaseRepositoryImpl[T]) validateID(id uint) error { // validateEntity ensures entity is not nil func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error { if entity == nil { - return domain.ErrInvalidInput + return ErrInvalidInput } return nil } @@ -125,7 +133,7 @@ func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to create entity") - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in %s", duration)) @@ -143,7 +151,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return domain.ErrTransactionFailed + return ErrTransactionFailed } start := time.Now() @@ -152,7 +160,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to create entity in transaction") - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration)) @@ -178,10 +186,10 @@ func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration)) - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id)) - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration)) @@ -208,10 +216,10 @@ func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration)) - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id)) - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration)) @@ -235,7 +243,7 @@ func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to update entity") - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration)) @@ -253,7 +261,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return domain.ErrTransactionFailed + return ErrTransactionFailed } start := time.Now() @@ -262,7 +270,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to update entity in transaction") - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration)) @@ -287,12 +295,12 @@ func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error { if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id)) - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration)) - return domain.ErrEntityNotFound + return ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration)) @@ -310,7 +318,7 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id return err } if tx == nil { - return domain.ErrTransactionFailed + return ErrTransactionFailed } start := time.Now() @@ -320,12 +328,12 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id)) - return fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, result.Error) + return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration)) - return domain.ErrEntityNotFound + return ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration)) @@ -352,7 +360,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get total count if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil { log.Error(err, "Failed to count entities") - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } // Calculate offset @@ -361,7 +369,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get paginated data if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get paginated entities") - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -402,7 +410,7 @@ func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *do if err := query.Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities with options") - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -423,7 +431,7 @@ func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) { var entities []T if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil { log.Error(err, "Failed to get all entities") - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -444,7 +452,7 @@ func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) { var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities") - return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -467,7 +475,7 @@ func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *d if err := query.Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities with options") - return 0, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -498,10 +506,10 @@ func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads [] if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start))) - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id)) - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -533,7 +541,7 @@ func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, of var entities []T if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities for sync") - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -557,7 +565,7 @@ func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, erro var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil { log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id)) - return false, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } duration := time.Since(start) @@ -579,7 +587,7 @@ func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) { tx := r.db.WithContext(ctx).Begin() if tx.Error != nil { log.Error(tx.Error, "Failed to begin transaction") - return nil, fmt.Errorf("%w: %v", domain.ErrTransactionFailed, tx.Error) + return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error) } log.Debug("Transaction started successfully") @@ -617,9 +625,9 @@ func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) if err := tx.Commit().Error; err != nil { log.Error(err, "Failed to commit transaction") - return fmt.Errorf("%w: %v", domain.ErrTransactionFailed, err) + return fmt.Errorf("%w: %v", ErrTransactionFailed, err) } log.Debug("Transaction committed successfully") return nil -} \ No newline at end of file +} diff --git a/internal/data/sql/base_repository_test.go b/internal/data/sql/base_repository_test.go index dc0ab14..56d0be8 100644 --- a/internal/data/sql/base_repository_test.go +++ b/internal/data/sql/base_repository_test.go @@ -76,13 +76,13 @@ func (s *BaseRepositoryTestSuite) TestCreate() { s.Run("should return error for nil entity", func() { err := s.repo.Create(context.Background(), nil) - s.ErrorIs(err, domain.ErrInvalidInput) + s.ErrorIs(err, sql.ErrInvalidInput) }) s.Run("should return error for nil context", func() { //nolint:staticcheck // Testing behavior with nil context is intentional here. err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"}) - s.ErrorIs(err, domain.ErrContextRequired) + s.ErrorIs(err, sql.ErrContextRequired) }) } @@ -103,12 +103,12 @@ func (s *BaseRepositoryTestSuite) TestGetByID() { s.Run("should return ErrEntityNotFound for non-existent ID", func() { _, err := s.repo.GetByID(context.Background(), 99999) - s.ErrorIs(err, domain.ErrEntityNotFound) + s.ErrorIs(err, sql.ErrEntityNotFound) }) s.Run("should return ErrInvalidID for zero ID", func() { _, err := s.repo.GetByID(context.Background(), 0) - s.ErrorIs(err, domain.ErrInvalidID) + s.ErrorIs(err, sql.ErrInvalidID) }) } @@ -140,12 +140,12 @@ func (s *BaseRepositoryTestSuite) TestDelete() { // Assert s.Require().NoError(err) _, getErr := s.repo.GetByID(context.Background(), created.ID) - s.ErrorIs(getErr, domain.ErrEntityNotFound) + s.ErrorIs(getErr, sql.ErrEntityNotFound) }) s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() { err := s.repo.Delete(context.Background(), 99999) - s.ErrorIs(err, domain.ErrEntityNotFound) + s.ErrorIs(err, sql.ErrEntityNotFound) }) } @@ -261,6 +261,6 @@ func (s *BaseRepositoryTestSuite) TestWithTx() { s.ErrorIs(err, simulatedErr) _, getErr := s.repo.GetByID(context.Background(), createdID) - s.ErrorIs(getErr, domain.ErrEntityNotFound, "Entity should not exist after rollback") + s.ErrorIs(getErr, sql.ErrEntityNotFound, "Entity should not exist after rollback") }) } \ No newline at end of file diff --git a/internal/data/sql/book_repository.go b/internal/data/sql/book_repository.go index c838239..5b17765 100644 --- a/internal/data/sql/book_repository.go +++ b/internal/data/sql/book_repository.go @@ -70,7 +70,7 @@ func (r *bookRepository) FindByISBN(ctx context.Context, isbn string) (*domain.B var book domain.Book if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&book).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/category_repository.go b/internal/data/sql/category_repository.go index f9f0262..c656dd0 100644 --- a/internal/data/sql/category_repository.go +++ b/internal/data/sql/category_repository.go @@ -33,7 +33,7 @@ func (r *categoryRepository) FindByName(ctx context.Context, name string) (*doma var category domain.Category if err := r.db.WithContext(ctx).Where("name = ?", name).First(&category).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/copyright_repository.go b/internal/data/sql/copyright_repository.go index 4b063e3..cd4a301 100644 --- a/internal/data/sql/copyright_repository.go +++ b/internal/data/sql/copyright_repository.go @@ -50,7 +50,7 @@ func (r *copyrightRepository) GetTranslationByLanguage(ctx context.Context, copy err := r.db.WithContext(ctx).Where("copyright_id = ? AND language_code = ?", copyrightID, languageCode).First(&translation).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/copyright_repository_test.go b/internal/data/sql/copyright_repository_test.go index b30dce4..d44474e 100644 --- a/internal/data/sql/copyright_repository_test.go +++ b/internal/data/sql/copyright_repository_test.go @@ -127,7 +127,7 @@ func (s *CopyrightRepositoryTestSuite) TestGetTranslationByLanguage() { _, err := s.repo.GetTranslationByLanguage(context.Background(), copyrightID, languageCode) s.Require().Error(err) - s.Require().ErrorIs(err, domain.ErrEntityNotFound) + s.Require().Contains(err.Error(), "entity not found") }) } diff --git a/internal/data/sql/country_repository.go b/internal/data/sql/country_repository.go index 48c2be6..0c12e6d 100644 --- a/internal/data/sql/country_repository.go +++ b/internal/data/sql/country_repository.go @@ -27,7 +27,7 @@ func (r *countryRepository) GetByCode(ctx context.Context, code string) (*domain var country domain.Country if err := r.db.WithContext(ctx).Where("code = ?", code).First(&country).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/edition_repository.go b/internal/data/sql/edition_repository.go index 2093706..57e28bc 100644 --- a/internal/data/sql/edition_repository.go +++ b/internal/data/sql/edition_repository.go @@ -44,7 +44,7 @@ func (r *editionRepository) FindByISBN(ctx context.Context, isbn string) (*domai var edition domain.Edition if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&edition).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/email_verification_repository.go b/internal/data/sql/email_verification_repository.go index 511d470..31d8326 100644 --- a/internal/data/sql/email_verification_repository.go +++ b/internal/data/sql/email_verification_repository.go @@ -34,7 +34,7 @@ func (r *emailVerificationRepository) GetByToken(ctx context.Context, token stri var verification domain.EmailVerification if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&verification).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/password_reset_repository.go b/internal/data/sql/password_reset_repository.go index 32c4d37..6a81174 100644 --- a/internal/data/sql/password_reset_repository.go +++ b/internal/data/sql/password_reset_repository.go @@ -34,7 +34,7 @@ func (r *passwordResetRepository) GetByToken(ctx context.Context, token string) var reset domain.PasswordReset if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&reset).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/source_repository.go b/internal/data/sql/source_repository.go index 9adabdf..4702c8c 100644 --- a/internal/data/sql/source_repository.go +++ b/internal/data/sql/source_repository.go @@ -46,7 +46,7 @@ func (r *sourceRepository) FindByURL(ctx context.Context, url string) (*domain.S var source domain.Source if err := r.db.WithContext(ctx).Where("url = ?", url).First(&source).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/tag_repository.go b/internal/data/sql/tag_repository.go index 9dc32fe..f61e975 100644 --- a/internal/data/sql/tag_repository.go +++ b/internal/data/sql/tag_repository.go @@ -33,7 +33,7 @@ func (r *tagRepository) FindByName(ctx context.Context, name string) (*domain.Ta var tag domain.Tag if err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/user_profile_repository.go b/internal/data/sql/user_profile_repository.go index f1ad4bb..351adeb 100644 --- a/internal/data/sql/user_profile_repository.go +++ b/internal/data/sql/user_profile_repository.go @@ -33,7 +33,7 @@ func (r *userProfileRepository) GetByUserID(ctx context.Context, userID uint) (* var profile domain.UserProfile if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&profile).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/user_repository.go b/internal/data/sql/user_repository.go index 53acfea..a6bed79 100644 --- a/internal/data/sql/user_repository.go +++ b/internal/data/sql/user_repository.go @@ -33,7 +33,7 @@ func (r *userRepository) FindByUsername(ctx context.Context, username string) (* var user domain.User if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } @@ -47,7 +47,7 @@ func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain var user domain.User if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } @@ -63,4 +63,4 @@ func (r *userRepository) ListByRole(ctx context.Context, role domain.UserRole) ( return nil, err } return users, nil -} \ No newline at end of file +} diff --git a/internal/data/sql/user_session_repository.go b/internal/data/sql/user_session_repository.go index 0bd74d3..a431822 100644 --- a/internal/data/sql/user_session_repository.go +++ b/internal/data/sql/user_session_repository.go @@ -34,7 +34,7 @@ func (r *userSessionRepository) GetByToken(ctx context.Context, token string) (* var session domain.UserSession if err := r.db.WithContext(ctx).Where("token = ?", token).First(&session).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/work_repository.go b/internal/data/sql/work_repository.go index 4364746..2eaa1b3 100644 --- a/internal/data/sql/work_repository.go +++ b/internal/data/sql/work_repository.go @@ -185,9 +185,9 @@ func (r *workRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.D } if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, domain.ErrEntityNotFound + return nil, ErrEntityNotFound } - return nil, fmt.Errorf("%w: %v", domain.ErrDatabaseOperation, err) + return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) } return &entity, nil }