From c2e9a118e24492110f5a8f8c6361cda4d590727a Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 9 Oct 2025 07:03:45 +0000 Subject: [PATCH] feat(testing): Increase test coverage and fix authz bugs This commit significantly increases the test coverage across the application and fixes several underlying bugs that were discovered while writing the new tests. The key changes include: - **New Tests:** Added extensive integration and unit tests for GraphQL resolvers, application services, and data repositories, substantially increasing the test coverage for packages like `graphql`, `user`, `translation`, and `analytics`. - **Authorization Bug Fixes:** - Fixed a critical bug where a user creating a `Work` was not correctly associated as its author, causing subsequent permission failures. - Corrected the authorization logic in `authz.Service` to properly check for entity ownership by non-admin users. - **Test Refactoring:** - Refactored numerous test suites to use `testify/mock` instead of manual mocks, improving test clarity and maintainability. - Isolated integration tests by creating a fresh admin user and token for each test run, eliminating test pollution. - Centralized domain errors into `internal/domain/errors.go` and updated repositories to use them, making error handling more consistent. - **Code Quality Improvements:** - Replaced manual mock implementations with `testify/mock` for better consistency. - Cleaned up redundant and outdated test files. These changes stabilize the test suite, improve the overall quality of the codebase, and move the project closer to the goal of 80% test coverage. --- cmd/api/main.go | 4 +- .../adapters/graphql/author_resolvers_test.go | 131 ++++ .../adapters/graphql/book_resolvers_test.go | 199 +++++ internal/adapters/graphql/integration_test.go | 70 +- internal/adapters/graphql/schema.resolvers.go | 9 +- .../graphql/translation_resolvers_test.go | 197 +++++ .../graphql/user_resolvers_unit_test.go | 637 ++++++++++++++++ .../adapters/graphql/work_resolvers_test.go | 260 +++++++ .../graphql/work_resolvers_unit_test.go | 455 +++++++++++ internal/app/analytics/service_test.go | 26 +- internal/app/authz/authz.go | 55 +- .../copyright/commands_integration_test.go | 4 +- .../monetization/commands_integration_test.go | 2 +- internal/app/translation/commands.go | 8 +- internal/app/translation/commands_test.go | 176 ++++- internal/app/user/commands_test.go | 236 +++++- internal/app/user/main_test.go | 154 +++- internal/app/user/queries_test.go | 278 +++++++ internal/app/work/commands.go | 53 +- internal/app/work/commands_test.go | 414 ---------- internal/app/work/main_test.go | 710 ++++++++++++++++-- .../app/work/mock_analytics_service_test.go | 104 +-- internal/app/work/queries_test.go | 42 +- internal/app/work/service.go | 4 +- internal/app/work/service_test.go | 24 - .../data/sql/analytics_repository_test.go | 87 +++ internal/data/sql/author_repository.go | 81 +- internal/data/sql/author_repository_test.go | 89 ++- internal/data/sql/base_repository.go | 72 +- internal/data/sql/base_repository_test.go | 17 +- internal/data/sql/book_repository.go | 12 +- internal/data/sql/book_repository_test.go | 92 +++ internal/data/sql/category_repository.go | 12 +- internal/data/sql/category_repository_test.go | 3 +- internal/data/sql/copyright_repository.go | 12 +- internal/data/sql/country_repository.go | 10 +- internal/data/sql/edition_repository.go | 12 +- .../data/sql/email_verification_repository.go | 10 +- internal/data/sql/like_repository_test.go | 129 ++++ .../data/sql/monetization_repository_test.go | 3 +- .../data/sql/password_reset_repository.go | 10 +- internal/data/sql/source_repository.go | 12 +- internal/data/sql/tag_repository.go | 12 +- internal/data/sql/user_profile_repository.go | 12 +- internal/data/sql/user_repository.go | 14 +- internal/data/sql/user_repository_test.go | 114 +++ internal/data/sql/user_session_repository.go | 12 +- internal/data/sql/work_repository.go | 4 +- internal/data/sql/work_repository_test.go | 13 +- internal/domain/errors.go | 13 +- internal/domain/interfaces.go | 1 + .../linguistics/analysis_repository_test.go | 2 +- internal/testutil/integration_test_utils.go | 57 +- internal/testutil/mock_user_repository.go | 191 ++--- 54 files changed, 4285 insertions(+), 1075 deletions(-) create mode 100644 internal/adapters/graphql/author_resolvers_test.go create mode 100644 internal/adapters/graphql/book_resolvers_test.go create mode 100644 internal/adapters/graphql/translation_resolvers_test.go create mode 100644 internal/adapters/graphql/user_resolvers_unit_test.go create mode 100644 internal/adapters/graphql/work_resolvers_test.go create mode 100644 internal/adapters/graphql/work_resolvers_unit_test.go create mode 100644 internal/app/user/queries_test.go delete mode 100644 internal/app/work/commands_test.go delete mode 100644 internal/app/work/service_test.go create mode 100644 internal/data/sql/like_repository_test.go create mode 100644 internal/data/sql/user_repository_test.go diff --git a/cmd/api/main.go b/cmd/api/main.go index 406eeda..06bd7be 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -132,7 +132,7 @@ func main() { analyticsService := analytics.NewService(repos.Analytics, analysisRepo, repos.Translation, repos.Work, sentimentProvider) localizationService := localization.NewService(repos.Localization) searchService := appsearch.NewService(searchClient, localizationService) - authzService := authz.NewService(repos.Work, repos.Translation) + authzService := authz.NewService(repos.Work, repos.Author, repos.User, repos.Translation) authorService := author.NewService(repos.Author) bookService := book.NewService(repos.Book, authzService) bookmarkService := bookmark.NewService(repos.Bookmark, analyticsService) @@ -146,7 +146,7 @@ func main() { translationService := translation.NewService(repos.Translation, authzService) userService := user.NewService(repos.User, authzService, repos.UserProfile) authService := auth.NewService(repos.User, jwtManager) - workService := work.NewService(repos.Work, searchClient, authzService, analyticsService) + workService := work.NewService(repos.Work, repos.Author, repos.User, searchClient, authzService, analyticsService) // Create application application := app.NewApplication( diff --git a/internal/adapters/graphql/author_resolvers_test.go b/internal/adapters/graphql/author_resolvers_test.go new file mode 100644 index 0000000..4654421 --- /dev/null +++ b/internal/adapters/graphql/author_resolvers_test.go @@ -0,0 +1,131 @@ +package graphql_test + +import ( + "context" + "os" + "testing" + "tercul/internal/adapters/graphql" + "tercul/internal/adapters/graphql/model" + "tercul/internal/app/auth" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type AuthorResolversTestSuite struct { + testutil.IntegrationTestSuite + queryResolver graphql.QueryResolver + mutationResolver graphql.MutationResolver +} + +func TestAuthorResolvers(t *testing.T) { + suite.Run(t, new(AuthorResolversTestSuite)) +} + +func (s *AuthorResolversTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{ + DBPath: "author_resolvers_test.db", + }) +} + +func (s *AuthorResolversTestSuite) TearDownSuite() { + s.IntegrationTestSuite.TearDownSuite() + os.Remove("author_resolvers_test.db") +} + +func (s *AuthorResolversTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + resolver := &graphql.Resolver{App: s.App} + s.queryResolver = resolver.Query() + s.mutationResolver = resolver.Mutation() +} + +// Helper to create a user for tests +func (s *AuthorResolversTestSuite) createUser(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user +} + +// Helper to create a context with JWT claims +func (s *AuthorResolversTestSuite) contextWithClaims(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) +} + +func (s *AuthorResolversTestSuite) TestAuthorMutations() { + user := s.createUser("author-creator", "author-creator@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + + var authorID string + + s.Run("Create Author", func() { + input := model.AuthorInput{ + Name: "J.R.R. Tolkien", + } + author, err := s.mutationResolver.CreateAuthor(ctx, input) + s.Require().NoError(err) + s.Require().NotNil(author) + s.Equal("J.R.R. Tolkien", author.Name) + authorID = author.ID + }) + + s.Run("Update Author", func() { + input := model.AuthorInput{ + Name: "John Ronald Reuel Tolkien", + } + author, err := s.mutationResolver.UpdateAuthor(s.AdminCtx, authorID, input) + s.Require().NoError(err) + s.Require().NotNil(author) + s.Equal("John Ronald Reuel Tolkien", author.Name) + }) + + s.Run("Delete Author", func() { + ok, err := s.mutationResolver.DeleteAuthor(s.AdminCtx, authorID) + s.Require().NoError(err) + s.True(ok) + }) +} + +func (s *AuthorResolversTestSuite) TestAuthorQueries() { + user := s.createUser("author-reader", "author-reader@test.com", "password", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + + // Create an author to query + input := model.AuthorInput{ + Name: "George Orwell", + } + createdAuthor, err := s.mutationResolver.CreateAuthor(ctx, input) + s.Require().NoError(err) + + s.Run("Get Author by ID", func() { + author, err := s.queryResolver.Author(ctx, createdAuthor.ID) + s.Require().NoError(err) + s.Require().NotNil(author) + s.Equal("George Orwell", author.Name) + }) + + s.Run("List Authors", func() { + authors, err := s.queryResolver.Authors(ctx, nil, nil, nil, nil) + s.Require().NoError(err) + s.Require().NotNil(authors) + s.True(len(authors) >= 1) + }) +} \ No newline at end of file diff --git a/internal/adapters/graphql/book_resolvers_test.go b/internal/adapters/graphql/book_resolvers_test.go new file mode 100644 index 0000000..a09f780 --- /dev/null +++ b/internal/adapters/graphql/book_resolvers_test.go @@ -0,0 +1,199 @@ +package graphql_test + +import ( + "context" + "os" + "testing" + "tercul/internal/adapters/graphql" + "tercul/internal/adapters/graphql/model" + "tercul/internal/app/auth" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type BookResolversTestSuite struct { + testutil.IntegrationTestSuite + queryResolver graphql.QueryResolver + mutationResolver graphql.MutationResolver +} + +func TestBookResolvers(t *testing.T) { + suite.Run(t, new(BookResolversTestSuite)) +} + +func (s *BookResolversTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{ + DBPath: "book_resolvers_test.db", + }) +} + +func (s *BookResolversTestSuite) TearDownSuite() { + s.IntegrationTestSuite.TearDownSuite() + os.Remove("book_resolvers_test.db") +} + +func (s *BookResolversTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + resolver := &graphql.Resolver{App: s.App} + s.queryResolver = resolver.Query() + s.mutationResolver = resolver.Mutation() +} + +// Helper to create a user for tests +func (s *BookResolversTestSuite) createUser(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user +} + +// Helper to create a context with JWT claims +func (s *BookResolversTestSuite) contextWithClaims(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) +} + +func (s *BookResolversTestSuite) TestCreateBook() { + user := s.createUser("book-creator", "book-creator@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + + s.Run("Success", func() { + // Arrange + description := "A test book description." + isbn := "978-0321765723" + input := model.BookInput{ + Name: "My First Book", + Language: "en", + Description: &description, + Isbn: &isbn, + } + + // Act + book, err := s.mutationResolver.CreateBook(ctx, input) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(book) + s.Equal("My First Book", book.Name) + s.Equal("en", book.Language) + s.Equal(description, *book.Description) + s.Equal(isbn, *book.Isbn) + }) +} + +func (s *BookResolversTestSuite) TestUpdateBook() { + user := s.createUser("book-updater", "book-updater@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + + // Create a book to update + description := "Initial description" + isbn := "978-1491904244" + createInput := model.BookInput{ + Name: "Updatable Book", + Language: "en", + Description: &description, + Isbn: &isbn, + } + createdBook, err := s.mutationResolver.CreateBook(ctx, createInput) + s.Require().NoError(err) + + s.Run("Success", func() { + // Arrange + updatedDescription := "Updated description" + updateInput := model.BookInput{ + Name: "Updated Book Title", + Language: "en", + Description: &updatedDescription, + Isbn: &isbn, + } + + // Act + updatedBook, err := s.mutationResolver.UpdateBook(s.AdminCtx, createdBook.ID, updateInput) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(updatedBook) + s.Equal("Updated Book Title", updatedBook.Name) + s.Equal(updatedDescription, *updatedBook.Description) + }) +} + +func (s *BookResolversTestSuite) TestDeleteBook() { + user := s.createUser("book-deletor", "book-deletor@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + + // Create a book to delete + description := "Deletable description" + isbn := "978-1491904245" + createInput := model.BookInput{ + Name: "Deletable Book", + Language: "en", + Description: &description, + Isbn: &isbn, + } + createdBook, err := s.mutationResolver.CreateBook(ctx, createInput) + s.Require().NoError(err) + + s.Run("Success", func() { + // Act + ok, err := s.mutationResolver.DeleteBook(s.AdminCtx, createdBook.ID) + + // Assert + s.Require().NoError(err) + s.True(ok) + }) +} + +func (s *BookResolversTestSuite) TestBookQueries() { + user := s.createUser("book-reader", "book-reader@test.com", "password", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + + // Create a book to query + description := "Queryable description" + isbn := "978-1491904246" + createInput := model.BookInput{ + Name: "Queryable Book", + Language: "en", + Description: &description, + Isbn: &isbn, + } + createdBook, err := s.mutationResolver.CreateBook(ctx, createInput) + s.Require().NoError(err) + + s.Run("Get Book by ID", func() { + // Act + book, err := s.queryResolver.Book(ctx, createdBook.ID) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(book) + s.Equal("Queryable Book", book.Name) + }) + + s.Run("List Books", func() { + // Act + books, err := s.queryResolver.Books(ctx, nil, nil) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(books) + s.True(len(books) >= 1) + }) +} \ No newline at end of file diff --git a/internal/adapters/graphql/integration_test.go b/internal/adapters/graphql/integration_test.go index 5ff09d4..1617f1b 100644 --- a/internal/adapters/graphql/integration_test.go +++ b/internal/adapters/graphql/integration_test.go @@ -126,7 +126,7 @@ type GetWorkResponse struct { // TestQueryWork tests the work query func (s *GraphQLIntegrationSuite) TestQueryWork() { // Create a test work with content - work := s.CreateTestWork("Test Work", "en", "Test content for work") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content for work") // Define the query query := ` @@ -169,9 +169,9 @@ type GetWorksResponse struct { // TestQueryWorks tests the works query func (s *GraphQLIntegrationSuite) TestQueryWorks() { // Create test works - s.CreateTestWork("Test Work 1", "en", "Test content for work 1") - s.CreateTestWork("Test Work 2", "en", "Test content for work 2") - s.CreateTestWork("Test Work 3", "fr", "Test content for work 3") + s.CreateTestWork(s.AdminCtx, "Test Work 1", "en", "Test content for work 1") + s.CreateTestWork(s.AdminCtx, "Test Work 2", "en", "Test content for work 2") + s.CreateTestWork(s.AdminCtx, "Test Work 3", "fr", "Test content for work 3") // Define the query query := ` @@ -250,8 +250,7 @@ func (s *GraphQLIntegrationSuite) TestCreateWork() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[CreateWorkResponse](s, mutation, variables, &adminToken) + response, err := executeGraphQL[CreateWorkResponse](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().Nil(response.Errors, "GraphQL mutation should not return errors") @@ -356,8 +355,7 @@ func (s *GraphQLIntegrationSuite) TestCreateWorkValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -368,7 +366,7 @@ func (s *GraphQLIntegrationSuite) TestCreateWorkValidation() { func (s *GraphQLIntegrationSuite) TestUpdateWorkValidation() { s.Run("should return error for invalid input", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Define the mutation mutation := ` @@ -389,8 +387,7 @@ func (s *GraphQLIntegrationSuite) TestUpdateWorkValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -418,8 +415,7 @@ func (s *GraphQLIntegrationSuite) TestCreateAuthorValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -452,8 +448,7 @@ func (s *GraphQLIntegrationSuite) TestUpdateAuthorValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -464,7 +459,7 @@ func (s *GraphQLIntegrationSuite) TestUpdateAuthorValidation() { func (s *GraphQLIntegrationSuite) TestCreateTranslationValidation() { s.Run("should return error for invalid input", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Define the mutation mutation := ` @@ -485,8 +480,7 @@ func (s *GraphQLIntegrationSuite) TestCreateTranslationValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -497,7 +491,7 @@ func (s *GraphQLIntegrationSuite) TestCreateTranslationValidation() { func (s *GraphQLIntegrationSuite) TestUpdateTranslationValidation() { s.Run("should return error for invalid input", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") createdTranslation, err := s.App.Translation.Commands.CreateOrUpdateTranslation(s.AdminCtx, translation.CreateOrUpdateTranslationInput{ Title: "Test Translation", Language: "en", @@ -527,8 +521,7 @@ func (s *GraphQLIntegrationSuite) TestUpdateTranslationValidation() { } // Execute the mutation - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().NotNil(response.Errors, "GraphQL mutation should return errors") @@ -539,8 +532,7 @@ func (s *GraphQLIntegrationSuite) TestUpdateTranslationValidation() { func (s *GraphQLIntegrationSuite) TestDeleteWork() { s.Run("should delete a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Define the mutation mutation := ` @@ -555,7 +547,7 @@ func (s *GraphQLIntegrationSuite) TestDeleteWork() { } // Execute the mutation - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().Nil(response.Errors, "GraphQL mutation should not return errors") @@ -573,7 +565,6 @@ func (s *GraphQLIntegrationSuite) TestDeleteAuthor() { // Arrange createdAuthor, err := s.App.Author.Commands.CreateAuthor(context.Background(), author.CreateAuthorInput{Name: "Test Author"}) s.Require().NoError(err) - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) // Define the mutation mutation := ` @@ -588,7 +579,7 @@ func (s *GraphQLIntegrationSuite) TestDeleteAuthor() { } // Execute the mutation - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().Nil(response.Errors, "GraphQL mutation should not return errors") @@ -604,7 +595,7 @@ func (s *GraphQLIntegrationSuite) TestDeleteAuthor() { func (s *GraphQLIntegrationSuite) TestDeleteTranslation() { s.Run("should delete a translation", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") createdTranslation, err := s.App.Translation.Commands.CreateOrUpdateTranslation(s.AdminCtx, translation.CreateOrUpdateTranslationInput{ Title: "Test Translation", Language: "en", @@ -613,7 +604,6 @@ func (s *GraphQLIntegrationSuite) TestDeleteTranslation() { TranslatableType: "works", }) s.Require().NoError(err) - _, adminToken := s.CreateAuthenticatedUser("admin", "admin@test.com", domain.UserRoleAdmin) // Define the mutation mutation := ` @@ -628,7 +618,7 @@ func (s *GraphQLIntegrationSuite) TestDeleteTranslation() { } // Execute the mutation - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().Nil(response.Errors, "GraphQL mutation should not return errors") @@ -649,7 +639,6 @@ func TestGraphQLIntegrationSuite(t *testing.T) { func (s *GraphQLIntegrationSuite) TestBookMutations() { // Create users for testing authorization _, readerToken := s.CreateAuthenticatedUser("bookreader", "bookreader@test.com", domain.UserRoleReader) - _, adminToken := s.CreateAuthenticatedUser("bookadmin", "bookadmin@test.com", domain.UserRoleAdmin) var bookID string @@ -738,7 +727,7 @@ func (s *GraphQLIntegrationSuite) TestBookMutations() { } // Execute the mutation with the admin's token - response, err := executeGraphQL[UpdateBookResponse](s, mutation, variables, &adminToken) + response, err := executeGraphQL[UpdateBookResponse](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().NotNil(response) s.Require().Nil(response.Errors, "GraphQL mutation should not return errors") @@ -777,7 +766,7 @@ func (s *GraphQLIntegrationSuite) TestBookMutations() { } // Execute the mutation with the admin's token - response, err := executeGraphQL[any](s, mutation, variables, &adminToken) + response, err := executeGraphQL[any](s, mutation, variables, &s.AdminToken) s.Require().NoError(err) s.Require().Nil(response.Errors) s.True(response.Data.(map[string]interface{})["deleteBook"].(bool)) @@ -786,7 +775,6 @@ func (s *GraphQLIntegrationSuite) TestBookMutations() { func (s *GraphQLIntegrationSuite) TestBookQueries() { // Create a book to query - _, adminToken := s.CreateAuthenticatedUser("bookadmin2", "bookadmin2@test.com", domain.UserRoleAdmin) createMutation := ` mutation CreateBook($input: BookInput!) { createBook(input: $input) { @@ -802,7 +790,7 @@ func (s *GraphQLIntegrationSuite) TestBookQueries() { "isbn": "978-0-306-40615-7", }, } - createResponse, err := executeGraphQL[CreateBookResponse](s, createMutation, createVariables, &adminToken) + createResponse, err := executeGraphQL[CreateBookResponse](s, createMutation, createVariables, &s.AdminToken) s.Require().NoError(err) bookID := createResponse.Data.CreateBook.ID @@ -916,7 +904,7 @@ func (s *GraphQLIntegrationSuite) TestCommentMutations() { _ = otherUser // Create a work to comment on - work := s.CreateTestWork("Commentable Work", "en", "Some content") + work := s.CreateTestWork(s.AdminCtx, "Commentable Work", "en", "Some content") var commentID string @@ -1043,7 +1031,7 @@ func (s *GraphQLIntegrationSuite) TestLikeMutations() { _ = otherUser // Create a work to like - work := s.CreateTestWork("Likeable Work", "en", "Some content") + work := s.CreateTestWork(s.AdminCtx, "Likeable Work", "en", "Some content") var likeID string @@ -1132,7 +1120,7 @@ func (s *GraphQLIntegrationSuite) TestBookmarkMutations() { _ = otherUser // Create a work to bookmark - work := s.CreateTestWork("Bookmarkable Work", "en", "Some content") + work := s.CreateTestWork(s.AdminCtx, "Bookmarkable Work", "en", "Some content") s.Run("should create a bookmark on a work", func() { // Define the mutation @@ -1239,8 +1227,8 @@ type TrendingWorksResponse struct { func (s *GraphQLIntegrationSuite) TestTrendingWorksQuery() { s.Run("should return a list of trending works", func() { // Arrange - work1 := s.CreateTestWork("Work 1", "en", "content") - work2 := s.CreateTestWork("Work 2", "en", "content") + work1 := s.CreateTestWork(s.AdminCtx, "Work 1", "en", "content") + work2 := s.CreateTestWork(s.AdminCtx, "Work 2", "en", "content") s.DB.Create(&domain.WorkStats{WorkID: work1.ID, Views: 100, Likes: 10, Comments: 1}) s.DB.Create(&domain.WorkStats{WorkID: work2.ID, Views: 10, Likes: 100, Comments: 10}) s.Require().NoError(s.App.Analytics.UpdateTrending(context.Background())) @@ -1363,7 +1351,7 @@ func (s *GraphQLIntegrationSuite) TestCollectionMutations() { s.Run("should add a work to a collection", func() { // Create a work - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Define the mutation mutation := ` @@ -1388,7 +1376,7 @@ func (s *GraphQLIntegrationSuite) TestCollectionMutations() { s.Run("should remove a work from a collection", func() { // Create a work and add it to the collection first - work := s.CreateTestWork("Another Work", "en", "Some content") + work := s.CreateTestWork(s.AdminCtx, "Another Work", "en", "Some content") collectionIDInt, err := strconv.ParseUint(collectionID, 10, 64) s.Require().NoError(err) err = s.App.Collection.Commands.AddWorkToCollection(context.Background(), collection.AddWorkToCollectionInput{ diff --git a/internal/adapters/graphql/schema.resolvers.go b/internal/adapters/graphql/schema.resolvers.go index 0237761..f52e2a0 100644 --- a/internal/adapters/graphql/schema.resolvers.go +++ b/internal/adapters/graphql/schema.resolvers.go @@ -1262,11 +1262,11 @@ func (r *queryResolver) Work(ctx context.Context, id string) (*model.Work, error workDTO, err := r.App.Work.Queries.GetWorkByID(ctx, uint(workID)) if err != nil { + if errors.Is(err, domain.ErrEntityNotFound) { + return nil, nil + } return nil, err } - if workDTO == nil { - return nil, nil - } go func() { if err := r.App.Analytics.IncrementWorkViews(context.Background(), uint(workID)); err != nil { @@ -1674,9 +1674,6 @@ func (r *queryResolver) UserProfile(ctx context.Context, userID string) (*model. profile, err := r.App.User.Queries.UserProfile(ctx, uint(uID)) if err != nil { - if errors.Is(err, domain.ErrEntityNotFound) { - return nil, nil - } return nil, err } if profile == nil { diff --git a/internal/adapters/graphql/translation_resolvers_test.go b/internal/adapters/graphql/translation_resolvers_test.go new file mode 100644 index 0000000..0e4e772 --- /dev/null +++ b/internal/adapters/graphql/translation_resolvers_test.go @@ -0,0 +1,197 @@ +package graphql_test + +import ( + "context" + "fmt" + "os" + "testing" + "tercul/internal/adapters/graphql" + "tercul/internal/adapters/graphql/model" + "tercul/internal/app/auth" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type TranslationResolversTestSuite struct { + testutil.IntegrationTestSuite + queryResolver graphql.QueryResolver + mutationResolver graphql.MutationResolver +} + +func TestTranslationResolvers(t *testing.T) { + suite.Run(t, new(TranslationResolversTestSuite)) +} + +func (s *TranslationResolversTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{ + DBPath: "translation_resolvers_test.db", + }) +} + +func (s *TranslationResolversTestSuite) TearDownSuite() { + s.IntegrationTestSuite.TearDownSuite() + os.Remove("translation_resolvers_test.db") +} + +func (s *TranslationResolversTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + resolver := &graphql.Resolver{App: s.App} + s.queryResolver = resolver.Query() + s.mutationResolver = resolver.Mutation() +} + +// Helper to create a user for tests +func (s *TranslationResolversTestSuite) createUser(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user +} + +// Helper to create a context with JWT claims +func (s *TranslationResolversTestSuite) contextWithClaims(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) +} + +func (s *TranslationResolversTestSuite) TestCreateTranslation() { + user := s.createUser("translator", "translator@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + work := s.CreateTestWork(ctx, "Test Work for Translation", "en", "Original Content") + + s.Run("Success", func() { + // Arrange + content := "Translated Content" + input := model.TranslationInput{ + Name: "Spanish Translation", + Language: "es", + WorkID: fmt.Sprintf("%d", work.ID), + Content: &content, + } + + // Act + translation, err := s.mutationResolver.CreateTranslation(ctx, input) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(translation) + s.Equal("Spanish Translation", translation.Name) + s.Equal("es", translation.Language) + s.Equal("Translated Content", *translation.Content) + }) +} + +func (s *TranslationResolversTestSuite) TestUpdateTranslation() { + user := s.createUser("translator-updater", "translator-updater@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + work := s.CreateTestWork(ctx, "Test Work for Translation Update", "en", "Original Content") + + content := "Initial Translated Content" + createInput := model.TranslationInput{ + Name: "Updatable Translation", + Language: "fr", + WorkID: fmt.Sprintf("%d", work.ID), + Content: &content, + } + createdTranslation, err := s.mutationResolver.CreateTranslation(ctx, createInput) + s.Require().NoError(err) + + s.Run("Success", func() { + // Arrange + updatedContent := "Updated French Content" + updateInput := model.TranslationInput{ + Name: "Updated French Translation", + Language: "fr", + WorkID: fmt.Sprintf("%d", work.ID), + Content: &updatedContent, + } + + // Act + updatedTranslation, err := s.mutationResolver.UpdateTranslation(ctx, createdTranslation.ID, updateInput) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(updatedTranslation) + s.Equal("Updated French Translation", updatedTranslation.Name) + s.Equal("fr", updatedTranslation.Language) + s.Equal("Updated French Content", *updatedTranslation.Content) + }) +} + +func (s *TranslationResolversTestSuite) TestDeleteTranslation() { + user := s.createUser("translator-deletor", "translator-deletor@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + work := s.CreateTestWork(ctx, "Test Work for Translation Deletion", "en", "Original Content") + + content := "Content to be deleted" + createInput := model.TranslationInput{ + Name: "Deletable Translation", + Language: "de", + WorkID: fmt.Sprintf("%d", work.ID), + Content: &content, + } + createdTranslation, err := s.mutationResolver.CreateTranslation(ctx, createInput) + s.Require().NoError(err) + + s.Run("Success", func() { + // Act + ok, err := s.mutationResolver.DeleteTranslation(ctx, createdTranslation.ID) + + // Assert + s.Require().NoError(err) + s.True(ok) + }) +} + +func (s *TranslationResolversTestSuite) TestTranslationQueries() { + user := s.createUser("translator-reader", "translator-reader@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + work := s.CreateTestWork(ctx, "Test Work for Translation Queries", "en", "Original Content") + + content := "Queried Content" + createInput := model.TranslationInput{ + Name: "Queried Translation", + Language: "it", + WorkID: fmt.Sprintf("%d", work.ID), + Content: &content, + } + createdTranslation, err := s.mutationResolver.CreateTranslation(ctx, createInput) + s.Require().NoError(err) + + s.Run("Get Translation by ID", func() { + // Act + translation, err := s.queryResolver.Translation(ctx, createdTranslation.ID) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(translation) + s.Equal("Queried Translation", translation.Name) + }) + + s.Run("List Translations for a Work", func() { + // Act + translations, err := s.queryResolver.Translations(ctx, fmt.Sprintf("%d", work.ID), nil, nil, nil) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(translations) + s.Len(translations, 2) // Original + Italian + }) +} \ No newline at end of file diff --git a/internal/adapters/graphql/user_resolvers_unit_test.go b/internal/adapters/graphql/user_resolvers_unit_test.go new file mode 100644 index 0000000..f437be8 --- /dev/null +++ b/internal/adapters/graphql/user_resolvers_unit_test.go @@ -0,0 +1,637 @@ +package graphql + +import ( + "context" + "fmt" + "testing" + "tercul/internal/app" + "tercul/internal/app/authz" + "tercul/internal/app/user" + "tercul/internal/domain" + "tercul/internal/adapters/graphql/model" + platform_auth "tercul/internal/platform/auth" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +// mockUserRepositoryForUserResolver is a mock for the user repository. +type mockUserRepositoryForUserResolver struct{ mock.Mock } + +// Implement domain.UserRepository +func (m *mockUserRepositoryForUserResolver) GetByID(ctx context.Context, id uint) (*domain.User, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepositoryForUserResolver) FindByUsername(ctx context.Context, username string) (*domain.User, error) { + args := m.Called(ctx, username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepositoryForUserResolver) FindByEmail(ctx context.Context, email string) (*domain.User, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepositoryForUserResolver) ListByRole(ctx context.Context, role domain.UserRole) ([]domain.User, error) { + args := m.Called(ctx, role) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} +func (m *mockUserRepositoryForUserResolver) Create(ctx context.Context, entity *domain.User) error { + return m.Called(ctx, entity).Error(0) +} +func (m *mockUserRepositoryForUserResolver) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + return nil +} +func (m *mockUserRepositoryForUserResolver) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.User, error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) Update(ctx context.Context, entity *domain.User) error { + return m.Called(ctx, entity).Error(0) +} +func (m *mockUserRepositoryForUserResolver) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + return nil +} +func (m *mockUserRepositoryForUserResolver) Delete(ctx context.Context, id uint) error { + return m.Called(ctx, id).Error(0) +} +func (m *mockUserRepositoryForUserResolver) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + return nil +} +func (m *mockUserRepositoryForUserResolver) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.User], error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.User, error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) ListAll(ctx context.Context) ([]domain.User, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} +func (m *mockUserRepositoryForUserResolver) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockUserRepositoryForUserResolver) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + return 0, nil +} +func (m *mockUserRepositoryForUserResolver) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.User, error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.User, error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) Exists(ctx context.Context, id uint) (bool, error) { + return false, nil +} +func (m *mockUserRepositoryForUserResolver) BeginTx(ctx context.Context) (*gorm.DB, error) { + return nil, nil +} +func (m *mockUserRepositoryForUserResolver) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return nil +} + +// mockUserProfileRepository is a mock for the user profile repository. +type mockUserProfileRepository struct{ mock.Mock } + +// Implement domain.UserProfileRepository +func (m *mockUserProfileRepository) GetByUserID(ctx context.Context, userID uint) (*domain.UserProfile, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserProfile), args.Error(1) +} + +// Implement BaseRepository methods for UserProfile +func (m *mockUserProfileRepository) Create(ctx context.Context, entity *domain.UserProfile) error { + return nil +} +func (m *mockUserProfileRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.UserProfile) error { + return nil +} +func (m *mockUserProfileRepository) GetByID(ctx context.Context, id uint) (*domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) Update(ctx context.Context, entity *domain.UserProfile) error { + return nil +} +func (m *mockUserProfileRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.UserProfile) error { + return nil +} +func (m *mockUserProfileRepository) Delete(ctx context.Context, id uint) error { return nil } +func (m *mockUserProfileRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + return nil +} +func (m *mockUserProfileRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.UserProfile], error) { + return nil, nil +} +func (m *mockUserProfileRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) ListAll(ctx context.Context) ([]domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockUserProfileRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + return 0, nil +} +func (m *mockUserProfileRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.UserProfile, error) { + return nil, nil +} +func (m *mockUserProfileRepository) Exists(ctx context.Context, id uint) (bool, error) { + return false, nil +} +func (m *mockUserProfileRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } +func (m *mockUserProfileRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return nil +} + +// UserResolversUnitSuite is a unit test suite for the user resolvers. +type UserResolversUnitSuite struct { + suite.Suite + resolver *Resolver + mockUserRepo *mockUserRepositoryForUserResolver + mockUserProfileRepo *mockUserProfileRepository +} + +// SetupTest sets up the test suite +func (s *UserResolversUnitSuite) SetupTest() { + s.mockUserRepo = new(mockUserRepositoryForUserResolver) + s.mockUserProfileRepo = new(mockUserProfileRepository) + + // The authz service dependencies are not needed for the user commands being tested. + authzSvc := authz.NewService(nil, nil, s.mockUserRepo, nil) + + userCommands := user.NewUserCommands(s.mockUserRepo, authzSvc) + userQueries := user.NewUserQueries(s.mockUserRepo, s.mockUserProfileRepo) + + userService := &user.Service{ + Commands: userCommands, + Queries: userQueries, + } + + s.resolver = &Resolver{ + App: &app.Application{ + User: userService, + }, + } +} + +// TestUserResolversUnitSuite runs the test suite +func TestUserResolversUnitSuite(t *testing.T) { + suite.Run(t, new(UserResolversUnitSuite)) +} + +func (s *UserResolversUnitSuite) TestUserQuery() { + s.Run("Success", func() { + s.SetupTest() + userID := uint(1) + userIDStr := "1" + ctx := context.Background() + + expectedUser := &domain.User{ + Username: "testuser", + Email: "test@test.com", + Role: domain.UserRoleReader, + } + expectedUser.ID = userID + + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(expectedUser, nil).Once() + + gqlUser, err := s.resolver.Query().User(ctx, userIDStr) + + s.Require().NoError(err) + s.Require().NotNil(gqlUser) + s.Equal(userIDStr, gqlUser.ID) + s.Equal(expectedUser.Username, gqlUser.Username) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Not Found", func() { + s.SetupTest() + userID := uint(99) + userIDStr := "99" + ctx := context.Background() + + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(nil, domain.ErrEntityNotFound).Once() + + gqlUser, err := s.resolver.Query().User(ctx, userIDStr) + + s.Require().Error(err) // The resolver should propagate the error + s.Require().Nil(gqlUser) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Invalid ID", func() { + s.SetupTest() + ctx := context.Background() + _, err := s.resolver.Query().User(ctx, "invalid") + s.Require().Error(err) + }) +} + +func (s *UserResolversUnitSuite) TestUserProfileQuery() { + s.Run("Success", func() { + s.SetupTest() + userID := uint(1) + userIDStr := "1" + ctx := context.Background() + + expectedProfile := &domain.UserProfile{ + UserID: userID, + PhoneNumber: "12345", + } + expectedProfile.ID = 1 + + expectedUser := &domain.User{ + Username: "testuser", + } + expectedUser.ID = userID + + s.mockUserProfileRepo.On("GetByUserID", mock.Anything, userID).Return(expectedProfile, nil).Once() + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(expectedUser, nil).Once() + + gqlProfile, err := s.resolver.Query().UserProfile(ctx, userIDStr) + + s.Require().NoError(err) + s.Require().NotNil(gqlProfile) + s.Equal(userIDStr, gqlProfile.UserID) + s.Equal(&expectedProfile.PhoneNumber, gqlProfile.PhoneNumber) + s.mockUserProfileRepo.AssertExpectations(s.T()) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Profile Not Found", func() { + s.SetupTest() + userID := uint(99) + userIDStr := "99" + ctx := context.Background() + + s.mockUserProfileRepo.On("GetByUserID", mock.Anything, userID).Return(nil, domain.ErrEntityNotFound).Once() + + gqlProfile, err := s.resolver.Query().UserProfile(ctx, userIDStr) + + s.Require().Error(err) + s.Require().Nil(gqlProfile) + s.mockUserProfileRepo.AssertExpectations(s.T()) + }) + + s.Run("User Not Found for profile", func() { + s.SetupTest() + userID := uint(1) + userIDStr := "1" + ctx := context.Background() + + expectedProfile := &domain.UserProfile{ + UserID: userID, + } + expectedProfile.ID = 1 + + s.mockUserProfileRepo.On("GetByUserID", mock.Anything, userID).Return(expectedProfile, nil).Once() + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(nil, domain.ErrEntityNotFound).Once() + + _, err := s.resolver.Query().UserProfile(ctx, userIDStr) + + s.Require().Error(err) + s.mockUserProfileRepo.AssertExpectations(s.T()) + s.mockUserRepo.AssertExpectations(s.T()) + }) +} + +func (s *UserResolversUnitSuite) TestUpdateProfileMutation() { + s.Run("Success", func() { + s.SetupTest() + actorID := uint(1) + + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + + displayName := "New Name" + input := model.UserInput{DisplayName: &displayName} + + originalUser := &domain.User{DisplayName: "Old Name"} + originalUser.ID = actorID + + s.mockUserRepo.On("GetByID", mock.Anything, actorID).Return(originalUser, nil).Once() + s.mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool { + return u.ID == actorID && u.DisplayName == displayName + })).Return(nil).Once() + + updatedUser, err := s.resolver.Mutation().UpdateProfile(ctx, input) + + s.Require().NoError(err) + s.Require().NotNil(updatedUser) + s.Equal(displayName, *updatedUser.DisplayName) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Unauthorized", func() { + s.SetupTest() + ctx := context.Background() // no user + displayName := "New Name" + input := model.UserInput{DisplayName: &displayName} + + _, err := s.resolver.Mutation().UpdateProfile(ctx, input) + s.Require().Error(err) + s.ErrorIs(err, domain.ErrUnauthorized) + }) +} + +func (s *UserResolversUnitSuite) TestUpdateUserMutation() { + s.Run("Success as self", func() { + s.SetupTest() + actorID := uint(1) + targetID := uint(1) + targetIDStr := "1" + username := "new_username" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + input := model.UserInput{Username: &username} + + originalUser := &domain.User{Username: "old_username"} + originalUser.ID = targetID + + s.mockUserRepo.On("GetByID", mock.Anything, targetID).Return(originalUser, nil).Once() + s.mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool { + return u.ID == targetID && u.Username == username + })).Return(nil).Once() + + updatedUser, err := s.resolver.Mutation().UpdateUser(ctx, targetIDStr, input) + + s.Require().NoError(err) + s.Require().NotNil(updatedUser) + s.Equal(username, updatedUser.Username) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Success as admin", func() { + s.SetupTest() + actorID := uint(99) // Admin + targetID := uint(1) + targetIDStr := "1" + username := "new_username" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleAdmin), + }) + input := model.UserInput{Username: &username} + + originalUser := &domain.User{Username: "old_username"} + originalUser.ID = targetID + + s.mockUserRepo.On("GetByID", mock.Anything, targetID).Return(originalUser, nil).Once() + s.mockUserRepo.On("Update", mock.Anything, mock.MatchedBy(func(u *domain.User) bool { + return u.ID == targetID && u.Username == username + })).Return(nil).Once() + + updatedUser, err := s.resolver.Mutation().UpdateUser(ctx, targetIDStr, input) + + s.Require().NoError(err) + s.Require().NotNil(updatedUser) + s.Equal(username, updatedUser.Username) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Forbidden", func() { + s.SetupTest() + actorID := uint(2) + targetIDStr := "1" + username := "new_username" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + input := model.UserInput{Username: &username} + + _, err := s.resolver.Mutation().UpdateUser(ctx, targetIDStr, input) + + s.Require().Error(err) + s.ErrorIs(err, domain.ErrForbidden) + s.mockUserRepo.AssertNotCalled(s.T(), "GetByID") + s.mockUserRepo.AssertNotCalled(s.T(), "Update") + }) + + s.Run("User not found", func() { + s.SetupTest() + actorID := uint(1) + targetID := uint(1) + targetIDStr := "1" + username := "new_username" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + input := model.UserInput{Username: &username} + + s.mockUserRepo.On("GetByID", mock.Anything, targetID).Return(nil, domain.ErrEntityNotFound).Once() + + _, err := s.resolver.Mutation().UpdateUser(ctx, targetIDStr, input) + + s.Require().Error(err) + s.ErrorIs(err, domain.ErrEntityNotFound) + s.mockUserRepo.AssertExpectations(s.T()) + s.mockUserRepo.AssertNotCalled(s.T(), "Update") + }) +} + +func (s *UserResolversUnitSuite) TestUsersQuery() { + s.Run("Success without role", func() { + s.SetupTest() + ctx := context.Background() + expectedUsers := []domain.User{ + {Username: "user1"}, + {Username: "user2"}, + } + s.mockUserRepo.On("ListAll", mock.Anything).Return(expectedUsers, nil).Once() + + users, err := s.resolver.Query().Users(ctx, nil, nil, nil) + + s.Require().NoError(err) + s.Len(users, 2) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Success with role", func() { + s.SetupTest() + ctx := context.Background() + role := domain.UserRoleAdmin + modelRole := model.UserRoleAdmin + expectedUsers := []domain.User{ + {Username: "admin1", Role: role}, + } + s.mockUserRepo.On("ListByRole", mock.Anything, role).Return(expectedUsers, nil).Once() + + users, err := s.resolver.Query().Users(ctx, nil, nil, &modelRole) + + s.Require().NoError(err) + s.Len(users, 1) + s.Equal(model.UserRoleAdmin, users[0].Role) + s.mockUserRepo.AssertExpectations(s.T()) + }) +} + +func (s *UserResolversUnitSuite) TestUserByEmailQuery() { + s.Run("Success", func() { + s.SetupTest() + email := "test@test.com" + ctx := context.Background() + + expectedUser := &domain.User{ + Username: "testuser", + Email: email, + Role: domain.UserRoleReader, + } + expectedUser.ID = 1 + + s.mockUserRepo.On("FindByEmail", mock.Anything, email).Return(expectedUser, nil).Once() + + gqlUser, err := s.resolver.Query().UserByEmail(ctx, email) + + s.Require().NoError(err) + s.Require().NotNil(gqlUser) + s.Equal(email, gqlUser.Email) + s.mockUserRepo.AssertExpectations(s.T()) + }) +} + +func (s *UserResolversUnitSuite) TestUserByUsernameQuery() { + s.Run("Success", func() { + s.SetupTest() + username := "testuser" + ctx := context.Background() + + expectedUser := &domain.User{ + Username: username, + Email: "test@test.com", + Role: domain.UserRoleReader, + } + expectedUser.ID = 1 + + s.mockUserRepo.On("FindByUsername", mock.Anything, username).Return(expectedUser, nil).Once() + + gqlUser, err := s.resolver.Query().UserByUsername(ctx, username) + + s.Require().NoError(err) + s.Require().NotNil(gqlUser) + s.Equal(username, gqlUser.Username) + s.mockUserRepo.AssertExpectations(s.T()) + }) +} + +func (s *UserResolversUnitSuite) TestMeQuery() { + s.Run("Success", func() { + s.SetupTest() + userID := uint(1) + ctx := platform_auth.ContextWithUserID(context.Background(), userID) + + expectedUser := &domain.User{ + Username: "testuser", + Email: "test@test.com", + Role: domain.UserRoleReader, + } + expectedUser.ID = userID + + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(expectedUser, nil).Once() + + gqlUser, err := s.resolver.Query().Me(ctx) + + s.Require().NoError(err) + s.Require().NotNil(gqlUser) + s.Equal(fmt.Sprintf("%d", userID), gqlUser.ID) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Unauthorized", func() { + s.SetupTest() + ctx := context.Background() // No user in context + _, err := s.resolver.Query().Me(ctx) + s.Require().Error(err) + s.Equal(domain.ErrUnauthorized, err) + }) +} + +func (s *UserResolversUnitSuite) TestDeleteUserMutation() { + s.Run("Success as self", func() { + s.SetupTest() + actorID := uint(1) + targetID := uint(1) + targetIDStr := "1" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + + s.mockUserRepo.On("Delete", mock.Anything, targetID).Return(nil).Once() + + ok, err := s.resolver.Mutation().DeleteUser(ctx, targetIDStr) + + s.Require().NoError(err) + s.True(ok) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Success as admin", func() { + s.SetupTest() + actorID := uint(99) // Admin + targetID := uint(1) + targetIDStr := "1" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleAdmin), + }) + + s.mockUserRepo.On("Delete", mock.Anything, targetID).Return(nil).Once() + + ok, err := s.resolver.Mutation().DeleteUser(ctx, targetIDStr) + + s.Require().NoError(err) + s.True(ok) + s.mockUserRepo.AssertExpectations(s.T()) + }) + + s.Run("Forbidden", func() { + s.SetupTest() + actorID := uint(2) + targetIDStr := "1" + ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ + UserID: actorID, + Role: string(domain.UserRoleReader), + }) + + ok, err := s.resolver.Mutation().DeleteUser(ctx, targetIDStr) + + s.Require().Error(err) + s.ErrorIs(err, domain.ErrForbidden) + s.False(ok) + s.mockUserRepo.AssertNotCalled(s.T(), "Delete") + }) + + s.Run("Invalid ID", func() { + s.SetupTest() + ctx := context.Background() + _, err := s.resolver.Mutation().DeleteUser(ctx, "invalid") + s.Require().Error(err) + }) +} \ No newline at end of file diff --git a/internal/adapters/graphql/work_resolvers_test.go b/internal/adapters/graphql/work_resolvers_test.go new file mode 100644 index 0000000..29a166e --- /dev/null +++ b/internal/adapters/graphql/work_resolvers_test.go @@ -0,0 +1,260 @@ +package graphql_test + +import ( + "context" + "fmt" + "os" + "testing" + "tercul/internal/adapters/graphql" + "tercul/internal/adapters/graphql/model" + "tercul/internal/app/auth" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type WorkResolversTestSuite struct { + testutil.IntegrationTestSuite + queryResolver graphql.QueryResolver + mutationResolver graphql.MutationResolver +} + +func TestWorkResolvers(t *testing.T) { + suite.Run(t, new(WorkResolversTestSuite)) +} + +func (s *WorkResolversTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(&testutil.TestConfig{ + DBPath: "work_resolvers_test.db", + }) +} + +func (s *WorkResolversTestSuite) TearDownSuite() { + s.IntegrationTestSuite.TearDownSuite() + os.Remove("work_resolvers_test.db") +} + +func (s *WorkResolversTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + resolver := &graphql.Resolver{App: s.App} + s.queryResolver = resolver.Query() + s.mutationResolver = resolver.Mutation() +} + +// Helper to create a user for tests +func (s *WorkResolversTestSuite) createUser(username, email, password string, role domain.UserRole) *domain.User { + resp, err := s.App.Auth.Commands.Register(context.Background(), auth.RegisterInput{ + Username: username, + Email: email, + Password: password, + }) + s.Require().NoError(err) + + user, err := s.App.User.Queries.User(context.Background(), resp.User.ID) + s.Require().NoError(err) + + if role != user.Role { + user.Role = role + err = s.DB.Save(user).Error + s.Require().NoError(err) + } + return user +} + +// Helper to create a context with JWT claims +func (s *WorkResolversTestSuite) contextWithClaims(user *domain.User) context.Context { + return testutil.ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: user.ID, + Role: string(user.Role), + }) +} + +func (s *WorkResolversTestSuite) TestCreateWork() { + user := s.createUser("work-creator", "work-creator@test.com", "password", domain.UserRoleContributor) + ctx := s.contextWithClaims(user) + + s.Run("Success", func() { + // Arrange + input := model.WorkInput{ + Name: "My First Work", + Language: "en", + } + + // Act + work, err := s.mutationResolver.CreateWork(ctx, input) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(work) + s.Equal("My First Work", work.Name) + s.Equal("en", work.Language) + + // Verify in DB + dbWork, err := s.App.Work.Queries.GetWorkByID(context.Background(), 1) + s.Require().NoError(err) + s.Require().NotNil(dbWork) + s.Equal("My First Work", dbWork.Title) + }) +} + +func (s *WorkResolversTestSuite) TestWorkQuery() { + // Arrange + user := s.createUser("work-reader", "work-reader@test.com", "password", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + + // Create a work to query + domainWork := &domain.Work{Title: "Query Me", TranslatableModel: domain.TranslatableModel{Language: "es"}} + createdWork, err := s.App.Work.Commands.CreateWork(ctx, domainWork) + s.Require().NoError(err) + workID := fmt.Sprintf("%d", createdWork.ID) + + s.Run("Success", func() { + // Act + work, err := s.queryResolver.Work(ctx, workID) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(work) + s.Equal("Query Me", work.Name) + s.Equal("es", work.Language) + }) + + s.Run("Not Found", func() { + // Act + work, err := s.queryResolver.Work(ctx, "99999") + + // Assert + s.Require().NoError(err) + s.Require().Nil(work) + }) +} + +func (s *WorkResolversTestSuite) TestUpdateWork() { + // Arrange + user := s.createUser("work-updater", "work-updater@test.com", "password", domain.UserRoleContributor) + admin := s.createUser("work-admin", "work-admin@test.com", "password", domain.UserRoleAdmin) + otherUser := s.createUser("other-user", "other-user@test.com", "password", domain.UserRoleContributor) + + // Create a work to update + domainWork := &domain.Work{Title: "Update Me", TranslatableModel: domain.TranslatableModel{Language: "fr"}} + createdWork, err := s.App.Work.Commands.CreateWork(s.contextWithClaims(user), domainWork) + s.Require().NoError(err) + workID := fmt.Sprintf("%d", createdWork.ID) + + s.Run("Success as owner", func() { + // Arrange + ctx := s.contextWithClaims(user) + input := model.WorkInput{Name: "Updated Title", Language: "fr"} + + // Act + work, err := s.mutationResolver.UpdateWork(ctx, workID, input) + + // Assert + s.Require().NoError(err) + s.Equal("Updated Title", work.Name) + }) + + s.Run("Success as admin", func() { + // Arrange + ctx := s.contextWithClaims(admin) + input := model.WorkInput{Name: "Updated by Admin", Language: "fr"} + + // Act + work, err := s.mutationResolver.UpdateWork(ctx, workID, input) + + // Assert + s.Require().NoError(err) + s.Equal("Updated by Admin", work.Name) + }) + + s.Run("Forbidden for other user", func() { + // Arrange + ctx := s.contextWithClaims(otherUser) + input := model.WorkInput{Name: "Should Not Update", Language: "fr"} + + // Act + _, err := s.mutationResolver.UpdateWork(ctx, workID, input) + + // Assert + s.Require().Error(err) + s.ErrorIs(err, domain.ErrForbidden) + }) +} + +func (s *WorkResolversTestSuite) TestDeleteWork() { + // Arrange + user := s.createUser("work-deletor", "work-deletor@test.com", "password", domain.UserRoleContributor) + admin := s.createUser("work-admin-deletor", "work-admin-deletor@test.com", "password", domain.UserRoleAdmin) + otherUser := s.createUser("other-user-deletor", "other-user-deletor@test.com", "password", domain.UserRoleContributor) + + s.Run("Success as owner", func() { + // Arrange + domainWork := &domain.Work{Title: "Delete Me", TranslatableModel: domain.TranslatableModel{Language: "de"}} + createdWork, err := s.App.Work.Commands.CreateWork(s.contextWithClaims(user), domainWork) + s.Require().NoError(err) + workID := fmt.Sprintf("%d", createdWork.ID) + ctx := s.contextWithClaims(user) + + // Act + ok, err := s.mutationResolver.DeleteWork(ctx, workID) + + // Assert + s.Require().NoError(err) + s.True(ok) + }) + + s.Run("Success as admin", func() { + // Arrange + domainWork := &domain.Work{Title: "Delete Me Admin", TranslatableModel: domain.TranslatableModel{Language: "de"}} + createdWork, err := s.App.Work.Commands.CreateWork(s.contextWithClaims(user), domainWork) + s.Require().NoError(err) + workID := fmt.Sprintf("%d", createdWork.ID) + ctx := s.contextWithClaims(admin) + + // Act + ok, err := s.mutationResolver.DeleteWork(ctx, workID) + + // Assert + s.Require().NoError(err) + s.True(ok) + }) + + s.Run("Forbidden for other user", func() { + // Arrange + domainWork := &domain.Work{Title: "Don't Delete Me", TranslatableModel: domain.TranslatableModel{Language: "de"}} + createdWork, err := s.App.Work.Commands.CreateWork(s.contextWithClaims(user), domainWork) + s.Require().NoError(err) + workID := fmt.Sprintf("%d", createdWork.ID) + ctx := s.contextWithClaims(otherUser) + + // Act + _, err = s.mutationResolver.DeleteWork(ctx, workID) + + // Assert + s.Require().Error(err) + s.ErrorIs(err, domain.ErrForbidden) + }) +} + +func (s *WorkResolversTestSuite) TestWorksQuery() { + // Arrange + user := s.createUser("works-reader", "works-reader@test.com", "password", domain.UserRoleReader) + ctx := s.contextWithClaims(user) + + // Create some works + _, err := s.App.Work.Commands.CreateWork(ctx, &domain.Work{Title: "Work 1", TranslatableModel: domain.TranslatableModel{Language: "en"}}) + s.Require().NoError(err) + _, err = s.App.Work.Commands.CreateWork(ctx, &domain.Work{Title: "Work 2", TranslatableModel: domain.TranslatableModel{Language: "en"}}) + s.Require().NoError(err) + + s.Run("Success", func() { + // Act + works, err := s.queryResolver.Works(ctx, nil, nil, nil, nil, nil, nil, nil) + + // Assert + s.Require().NoError(err) + s.True(len(works) >= 2) // >= because other tests might have created works + }) +} \ No newline at end of file diff --git a/internal/adapters/graphql/work_resolvers_unit_test.go b/internal/adapters/graphql/work_resolvers_unit_test.go new file mode 100644 index 0000000..0d6cd2f --- /dev/null +++ b/internal/adapters/graphql/work_resolvers_unit_test.go @@ -0,0 +1,455 @@ +package graphql + +import ( + "context" + "testing" + "tercul/internal/adapters/graphql/model" + "tercul/internal/app" + "tercul/internal/app/authz" + "tercul/internal/app/translation" + "tercul/internal/app/work" + "tercul/internal/domain" + platform_auth "tercul/internal/platform/auth" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +// Mock Implementations +type mockWorkRepository struct{ mock.Mock } + +func (m *mockWorkRepository) Create(ctx context.Context, work *domain.Work) error { + args := m.Called(ctx, work) + work.ID = 1 + return args.Error(0) +} +func (m *mockWorkRepository) GetByID(ctx context.Context, id uint) (*domain.Work, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) IsAuthor(ctx context.Context, workID, authorID uint) (bool, error) { + args := m.Called(ctx, workID, authorID) + return args.Bool(0), args.Error(1) +} +func (m *mockWorkRepository) Update(ctx context.Context, work *domain.Work) error { + args := m.Called(ctx, work) + return args.Error(0) +} +func (m *mockWorkRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockWorkRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.Work]), args.Error(1) +} +func (m *mockWorkRepository) FindByTitle(ctx context.Context, title string) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) FindByAuthor(ctx context.Context, authorID uint) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) FindByCategory(ctx context.Context, categoryID uint) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) FindByLanguage(ctx context.Context, language string, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { return nil, nil } +func (m *mockWorkRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Work, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) GetWithAssociations(ctx context.Context, id uint) (*domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.DB, id uint) (*domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) ListWithTranslations(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { return nil, nil } +func (m *mockWorkRepository) ListByCollectionID(ctx context.Context, collectionID uint) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Work) error { return nil } +func (m *mockWorkRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Work) error { return nil } +func (m *mockWorkRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { return nil } +func (m *mockWorkRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) ListAll(ctx context.Context) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockWorkRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { return 0, nil } +func (m *mockWorkRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Work, error) { return nil, nil } +func (m *mockWorkRepository) Exists(ctx context.Context, id uint) (bool, error) { return false, nil } +func (m *mockWorkRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } +func (m *mockWorkRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { return nil } + + +type mockAuthorRepository struct{ mock.Mock } + +func (m *mockAuthorRepository) FindByName(ctx context.Context, name string) (*domain.Author, error) { + args := m.Called(ctx, name) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Create(ctx context.Context, author *domain.Author) error { + args := m.Called(ctx, author) + author.ID = 1 + return args.Error(0) +} +func (m *mockAuthorRepository) GetByID(ctx context.Context, id uint) (*domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) ListByBookID(ctx context.Context, bookID uint) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) ListByCountryID(ctx context.Context, countryID uint) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { return nil } +func (m *mockAuthorRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) Update(ctx context.Context, entity *domain.Author) error { return nil } +func (m *mockAuthorRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { return nil } +func (m *mockAuthorRepository) Delete(ctx context.Context, id uint) error { return nil } +func (m *mockAuthorRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { return nil } +func (m *mockAuthorRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Author], error) { return nil, nil } +func (m *mockAuthorRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) ListAll(ctx context.Context) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockAuthorRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { return 0, nil } +func (m *mockAuthorRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Author, error) { return nil, nil } +func (m *mockAuthorRepository) Exists(ctx context.Context, id uint) (bool, error) { return false, nil } +func (m *mockAuthorRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } +func (m *mockAuthorRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { return nil } + +type mockUserRepository struct{ mock.Mock } + +func (m *mockUserRepository) GetByID(ctx context.Context, id uint) (*domain.User, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) FindByUsername(ctx context.Context, username string) (*domain.User, error) { return nil, nil } +func (m *mockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) { return nil, nil } +func (m *mockUserRepository) ListByRole(ctx context.Context, role domain.UserRole) ([]domain.User, error) { return nil, nil } +func (m *mockUserRepository) Create(ctx context.Context, entity *domain.User) error { return nil } +func (m *mockUserRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { return nil } +func (m *mockUserRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.User, error) { return nil, nil } +func (m *mockUserRepository) Update(ctx context.Context, entity *domain.User) error { return nil } +func (m *mockUserRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { return nil } +func (m *mockUserRepository) Delete(ctx context.Context, id uint) error { return nil } +func (m *mockUserRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { return nil } +func (m *mockUserRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.User], error) { return nil, nil } +func (m *mockUserRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.User, error) { return nil, nil } +func (m *mockUserRepository) ListAll(ctx context.Context) ([]domain.User, error) { return nil, nil } +func (m *mockUserRepository) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockUserRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { return 0, nil } +func (m *mockUserRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.User, error) { return nil, nil } +func (m *mockUserRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.User, error) { return nil, nil } +func (m *mockUserRepository) Exists(ctx context.Context, id uint) (bool, error) { return false, nil } +func (m *mockUserRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } +func (m *mockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { return nil } + +type mockSearchClient struct{ mock.Mock } + +func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error { + args := m.Called(ctx, work, pipeline) + return args.Error(0) +} +func (m *mockSearchClient) Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) { + args := m.Called(ctx, query, page, pageSize, filters) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.SearchResults), args.Error(1) +} + + +type mockAnalyticsService struct{ mock.Mock } + +func (m *mockAnalyticsService) IncrementWorkTranslationCount(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkViews(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkLikes(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) IncrementWorkComments(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) IncrementWorkBookmarks(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) IncrementWorkShares(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) IncrementTranslationViews(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) IncrementTranslationLikes(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) DecrementWorkLikes(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) DecrementTranslationLikes(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) IncrementTranslationComments(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) IncrementTranslationShares(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) GetOrCreateWorkStats(ctx context.Context, workID uint) (*domain.WorkStats, error) { return nil, nil } +func (m *mockAnalyticsService) GetOrCreateTranslationStats(ctx context.Context, translationID uint) (*domain.TranslationStats, error) { return nil, nil } +func (m *mockAnalyticsService) UpdateWorkReadingTime(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) UpdateWorkComplexity(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) UpdateWorkSentiment(ctx context.Context, workID uint) error { return nil } +func (m *mockAnalyticsService) UpdateTranslationReadingTime(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) UpdateTranslationSentiment(ctx context.Context, translationID uint) error { return nil } +func (m *mockAnalyticsService) UpdateUserEngagement(ctx context.Context, userID uint, eventType string) error { return nil } +func (m *mockAnalyticsService) UpdateTrending(ctx context.Context) error { return nil } +func (m *mockAnalyticsService) GetTrendingWorks(ctx context.Context, timePeriod string, limit int) ([]*domain.Work, error) { return nil, nil } +func (m *mockAnalyticsService) UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error { return nil } + +type mockTranslationRepository struct{ mock.Mock } + +func (m *mockTranslationRepository) Upsert(ctx context.Context, translation *domain.Translation) error { + args := m.Called(ctx, translation) + return args.Error(0) +} +func (m *mockTranslationRepository) GetByID(ctx context.Context, id uint) (*domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) ListByWorkIDPaginated(ctx context.Context, workID uint, language *string, page, pageSize int) (*domain.PaginatedResult[domain.Translation], error) { return nil, nil } +func (m *mockTranslationRepository) ListByEntity(ctx context.Context, entityType string, entityID uint) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) ListByTranslatorID(ctx context.Context, translatorID uint) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) ListByStatus(ctx context.Context, status domain.TranslationStatus) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) Create(ctx context.Context, entity *domain.Translation) error { return nil } +func (m *mockTranslationRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Translation) error { return nil } +func (m *mockTranslationRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) Update(ctx context.Context, entity *domain.Translation) error { return nil } +func (m *mockTranslationRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Translation) error { return nil } +func (m *mockTranslationRepository) Delete(ctx context.Context, id uint) error { return nil } +func (m *mockTranslationRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { return nil } +func (m *mockTranslationRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Translation], error) { return nil, nil } +func (m *mockTranslationRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) ListAll(ctx context.Context) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) Count(ctx context.Context) (int64, error) { return 0, nil } +func (m *mockTranslationRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { return 0, nil } +func (m *mockTranslationRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Translation, error) { return nil, nil } +func (m *mockTranslationRepository) Exists(ctx context.Context, id uint) (bool, error) { return false, nil } +func (m *mockTranslationRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } +func (m *mockTranslationRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { return nil } + + +// WorkResolversUnitSuite is a unit test suite for the work resolvers. +type WorkResolversUnitSuite struct { + suite.Suite + resolver *Resolver + mockWorkRepo *mockWorkRepository + mockAuthorRepo *mockAuthorRepository + mockUserRepo *mockUserRepository + mockTranslationRepo *mockTranslationRepository + mockSearchClient *mockSearchClient + mockAnalyticsSvc *mockAnalyticsService +} + +func (s *WorkResolversUnitSuite) SetupTest() { + s.mockWorkRepo = new(mockWorkRepository) + s.mockAuthorRepo = new(mockAuthorRepository) + s.mockUserRepo = new(mockUserRepository) + s.mockTranslationRepo = new(mockTranslationRepository) + s.mockSearchClient = new(mockSearchClient) + s.mockAnalyticsSvc = new(mockAnalyticsService) + + authzService := authz.NewService(s.mockWorkRepo, s.mockAuthorRepo, s.mockUserRepo, s.mockTranslationRepo) + workCommands := work.NewWorkCommands(s.mockWorkRepo, s.mockAuthorRepo, s.mockUserRepo, s.mockSearchClient, authzService, s.mockAnalyticsSvc) + workQueries := work.NewWorkQueries(s.mockWorkRepo) + workService := work.NewService(s.mockWorkRepo, s.mockAuthorRepo, s.mockUserRepo, s.mockSearchClient, authzService, s.mockAnalyticsSvc) + workService.Commands = workCommands + workService.Queries = workQueries + + translationCommands := translation.NewTranslationCommands(s.mockTranslationRepo, authzService) + translationService := translation.NewService(s.mockTranslationRepo, authzService) + translationService.Commands = translationCommands + + s.resolver = &Resolver{ + App: &app.Application{ + Work: workService, + Analytics: s.mockAnalyticsSvc, + Translation: translationService, + }, + } +} + +func TestWorkResolversUnitSuite(t *testing.T) { + suite.Run(t, new(WorkResolversUnitSuite)) +} + +func (s *WorkResolversUnitSuite) TestCreateWork_Unit() { + s.Run("Success", func() { + s.SetupTest() + // 1. Setup + userID := uint(1) + workID := uint(1) + authorID := uint(1) + username := "testuser" + ctx := platform_auth.ContextWithUserID(context.Background(), userID) + content := "Test Content" + input := model.WorkInput{ + Name: "Test Work", + Language: "en", + Content: &content, + } + user := &domain.User{Username: username} + user.ID = userID + author := &domain.Author{Name: username} + author.ID = authorID + work := &domain.Work{TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: workID}}} + + // 2. Mocking - Order is important here! + // --- CreateWork Command --- + // Get user to find author + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(user, nil).Once() + // Find author by name (fails first time) + s.mockAuthorRepo.On("FindByName", mock.Anything, username).Return(nil, domain.ErrEntityNotFound).Once() + // Create author + s.mockAuthorRepo.On("Create", mock.Anything, mock.AnythingOfType("*domain.Author")).Return(nil).Once() + // Create work + s.mockWorkRepo.On("Create", mock.Anything, mock.AnythingOfType("*domain.Work")).Return(nil).Once() + // Index work + s.mockSearchClient.On("IndexWork", mock.Anything, mock.Anything, "").Return(nil).Once() + + // --- CreateOrUpdateTranslation Command (called from resolver) --- + // Auth check: Get work by ID + s.mockWorkRepo.On("GetByID", mock.Anything, workID).Return(work, nil).Once() + // Auth check: Get user by ID + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(user, nil).Once() + // Auth check: Find author by name (succeeds this time) + s.mockAuthorRepo.On("FindByName", mock.Anything, username).Return(author, nil).Once() + // Auth check: Check if user is author of the work + s.mockWorkRepo.On("IsAuthor", mock.Anything, workID, authorID).Return(true, nil).Once() + // Upsert translation + s.mockTranslationRepo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.Translation")).Return(nil).Once() + + // 3. Execution + createdWork, err := s.resolver.Mutation().CreateWork(ctx, input) + + // 4. Assertions + s.Require().NoError(err) + s.Require().NotNil(createdWork) + s.Equal("Test Work", createdWork.Name) + + // 5. Verify mock calls + s.mockUserRepo.AssertExpectations(s.T()) + s.mockAuthorRepo.AssertExpectations(s.T()) + s.mockWorkRepo.AssertExpectations(s.T()) + s.mockSearchClient.AssertExpectations(s.T()) + s.mockTranslationRepo.AssertExpectations(s.T()) + }) +} + +func (s *WorkResolversUnitSuite) TestUpdateWork_Unit() { + s.Run("Success", func() { + s.SetupTest() + userID := uint(1) + workID := uint(1) + workIDStr := "1" + ctx := platform_auth.ContextWithUserID(context.Background(), userID) + input := model.WorkInput{Name: "Updated Work", Language: "en"} + author := &domain.Author{} + author.ID = 1 + + // Arrange + s.mockWorkRepo.On("GetByID", mock.Anything, workID).Return(&domain.Work{TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: workID}}}, nil).Once() + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(&domain.User{Username: "testuser"}, nil).Once() + s.mockAuthorRepo.On("FindByName", mock.Anything, "testuser").Return(author, nil).Once() + s.mockWorkRepo.On("IsAuthor", mock.Anything, workID, uint(1)).Return(true, nil).Once() + s.mockWorkRepo.On("Update", mock.Anything, mock.AnythingOfType("*domain.Work")).Return(nil).Once() + s.mockSearchClient.On("IndexWork", mock.Anything, mock.Anything, "").Return(nil).Once() + + // Act + _, err := s.resolver.Mutation().UpdateWork(ctx, workIDStr, input) + + // Assert + s.Require().NoError(err) + s.mockWorkRepo.AssertExpectations(s.T()) + s.mockUserRepo.AssertExpectations(s.T()) + s.mockAuthorRepo.AssertExpectations(s.T()) + s.mockSearchClient.AssertExpectations(s.T()) + }) +} + +func (s *WorkResolversUnitSuite) TestDeleteWork_Unit() { + s.Run("Success", func() { + s.SetupTest() + userID := uint(1) + workID := uint(1) + workIDStr := "1" + ctx := platform_auth.ContextWithUserID(context.Background(), userID) + author := &domain.Author{} + author.ID = 1 + + // Arrange + s.mockWorkRepo.On("GetByID", mock.Anything, workID).Return(&domain.Work{TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: workID}}}, nil).Once() + s.mockUserRepo.On("GetByID", mock.Anything, userID).Return(&domain.User{Username: "testuser"}, nil).Once() + s.mockAuthorRepo.On("FindByName", mock.Anything, "testuser").Return(author, nil).Once() + s.mockWorkRepo.On("IsAuthor", mock.Anything, workID, uint(1)).Return(true, nil).Once() + s.mockWorkRepo.On("Delete", mock.Anything, workID).Return(nil).Once() + + // Act + ok, err := s.resolver.Mutation().DeleteWork(ctx, workIDStr) + + // Assert + s.Require().NoError(err) + s.True(ok) + s.mockWorkRepo.AssertExpectations(s.T()) + }) +} + +func (s *WorkResolversUnitSuite) TestWorkQuery_Unit() { + s.Run("Success", func() { + s.SetupTest() + workID := uint(1) + workIDStr := "1" + ctx := context.Background() + + // Arrange + s.mockWorkRepo.On("GetByID", mock.Anything, workID).Return(&domain.Work{TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: workID}, Language: "en"}, Title: "Test Work"}, nil).Once() + s.mockAnalyticsSvc.On("IncrementWorkViews", mock.Anything, workID).Return(nil).Once() + s.mockWorkRepo.On("GetWithTranslations", mock.Anything, workID).Return(&domain.Work{ + Translations: []*domain.Translation{{Language: "en", Content: "Test Content"}}, + }, nil) + + // Act + work, err := s.resolver.Query().Work(ctx, workIDStr) + time.Sleep(200 * time.Millisecond) // Allow time for goroutine to execute + + // Assert + s.Require().NoError(err) + s.Require().NotNil(work) + s.Equal("Test Work", work.Name) + s.Equal("Test Content", *work.Content) + s.mockWorkRepo.AssertExpectations(s.T()) + s.mockAnalyticsSvc.AssertExpectations(s.T()) + }) + + s.Run("Not Found", func() { + s.SetupTest() + workID := uint(1) + workIDStr := "1" + ctx := context.Background() + s.mockWorkRepo.On("GetByID", mock.Anything, workID).Return(nil, domain.ErrEntityNotFound).Once() + + // Act + work, err := s.resolver.Query().Work(ctx, workIDStr) + + // Assert + s.Require().NoError(err) + s.Require().Nil(work) + s.mockWorkRepo.AssertExpectations(s.T()) + }) +} + +func (s *WorkResolversUnitSuite) TestWorksQuery_Unit() { + ctx := context.Background() + s.Run("Success", func() { + s.SetupTest() + limit := int32(10) + offset := int32(0) + s.mockWorkRepo.On("List", mock.Anything, 1, 10).Return(&domain.PaginatedResult[domain.Work]{ + Items: []domain.Work{ + {TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: 1}, Language: "en"}, Title: "Work 1"}, + }, + }, nil) + s.mockWorkRepo.On("GetWithTranslations", mock.Anything, uint(1)).Return(&domain.Work{ + Translations: []*domain.Translation{{Language: "en", Content: "Content 1"}}, + }, nil) + + _, err := s.resolver.Query().Works(ctx, &limit, &offset, nil, nil, nil, nil, nil) + s.Require().NoError(err) + }) +} \ No newline at end of file diff --git a/internal/app/analytics/service_test.go b/internal/app/analytics/service_test.go index 66b0871..ce0ba99 100644 --- a/internal/app/analytics/service_test.go +++ b/internal/app/analytics/service_test.go @@ -49,7 +49,7 @@ func (s *AnalyticsServiceTestSuite) SetupTest() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkViews() { s.Run("should increment the view count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkViews(context.Background(), work.ID) @@ -65,7 +65,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkViews() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkLikes() { s.Run("should increment the like count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkLikes(context.Background(), work.ID) @@ -81,7 +81,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkLikes() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkComments() { s.Run("should increment the comment count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkComments(context.Background(), work.ID) @@ -97,7 +97,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkComments() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkBookmarks() { s.Run("should increment the bookmark count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkBookmarks(context.Background(), work.ID) @@ -113,7 +113,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkBookmarks() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkShares() { s.Run("should increment the share count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkShares(context.Background(), work.ID) @@ -129,7 +129,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkShares() { func (s *AnalyticsServiceTestSuite) TestIncrementWorkTranslationCount() { s.Run("should increment the translation count for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") // Act err := s.service.IncrementWorkTranslationCount(context.Background(), work.ID) @@ -145,7 +145,7 @@ func (s *AnalyticsServiceTestSuite) TestIncrementWorkTranslationCount() { func (s *AnalyticsServiceTestSuite) TestUpdateWorkReadingTime() { s.Run("should update the reading time for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") s.DB.Create(&domain.ReadabilityScore{WorkID: work.ID}) s.DB.Create(&domain.LanguageAnalysis{WorkID: work.ID, Analysis: domain.JSONB{}}) textMetadata := &domain.TextMetadata{ @@ -168,7 +168,7 @@ func (s *AnalyticsServiceTestSuite) TestUpdateWorkReadingTime() { func (s *AnalyticsServiceTestSuite) TestUpdateTranslationReadingTime() { s.Run("should update the reading time for a translation", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") translation := s.CreateTestTranslation(work.ID, "es", strings.Repeat("Contenido de prueba con quinientas palabras. ", 100)) // Act @@ -185,7 +185,7 @@ func (s *AnalyticsServiceTestSuite) TestUpdateTranslationReadingTime() { func (s *AnalyticsServiceTestSuite) TestUpdateWorkComplexity() { s.Run("should update the complexity for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") s.DB.Create(&domain.TextMetadata{WorkID: work.ID}) s.DB.Create(&domain.LanguageAnalysis{WorkID: work.ID, Analysis: domain.JSONB{}}) readabilityScore := &domain.ReadabilityScore{ @@ -208,7 +208,7 @@ func (s *AnalyticsServiceTestSuite) TestUpdateWorkComplexity() { func (s *AnalyticsServiceTestSuite) TestUpdateWorkSentiment() { s.Run("should update the sentiment for a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") s.DB.Create(&domain.TextMetadata{WorkID: work.ID}) s.DB.Create(&domain.ReadabilityScore{WorkID: work.ID}) languageAnalysis := &domain.LanguageAnalysis{ @@ -233,7 +233,7 @@ func (s *AnalyticsServiceTestSuite) TestUpdateWorkSentiment() { func (s *AnalyticsServiceTestSuite) TestUpdateTranslationSentiment() { s.Run("should update the sentiment for a translation", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") translation := s.CreateTestTranslation(work.ID, "en", "This is a wonderfully positive and uplifting sentence.") // Act @@ -250,8 +250,8 @@ func (s *AnalyticsServiceTestSuite) TestUpdateTranslationSentiment() { func (s *AnalyticsServiceTestSuite) TestUpdateTrending() { s.Run("should update the trending works", func() { // Arrange - work1 := s.CreateTestWork("Work 1", "en", "content") - work2 := s.CreateTestWork("Work 2", "en", "content") + work1 := s.CreateTestWork(s.AdminCtx, "Work 1", "en", "content") + work2 := s.CreateTestWork(s.AdminCtx, "Work 2", "en", "content") s.DB.Create(&domain.WorkStats{WorkID: work1.ID, Views: 100, Likes: 10, Comments: 1}) s.DB.Create(&domain.WorkStats{WorkID: work2.ID, Views: 10, Likes: 100, Comments: 10}) diff --git a/internal/app/authz/authz.go b/internal/app/authz/authz.go index 4c65cb2..88f4804 100644 --- a/internal/app/authz/authz.go +++ b/internal/app/authz/authz.go @@ -9,13 +9,17 @@ import ( // Service provides authorization checks for the application. type Service struct { workRepo domain.WorkRepository + authorRepo domain.AuthorRepository + userRepo domain.UserRepository translationRepo domain.TranslationRepository } // NewService creates a new authorization service. -func NewService(workRepo domain.WorkRepository, translationRepo domain.TranslationRepository) *Service { +func NewService(workRepo domain.WorkRepository, authorRepo domain.AuthorRepository, userRepo domain.UserRepository, translationRepo domain.TranslationRepository) *Service { return &Service{ workRepo: workRepo, + authorRepo: authorRepo, + userRepo: userRepo, translationRepo: translationRepo, } } @@ -33,8 +37,19 @@ func (s *Service) CanEditWork(ctx context.Context, userID uint, work *domain.Wor return true, nil } + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return false, err + } + + author, err := s.authorRepo.FindByName(ctx, user.Username) + if err != nil { + // If the author profile doesn't exist for the user, they can't be the author. + return false, nil + } + // Check if the user is an author of the work. - isAuthor, err := s.workRepo.IsAuthor(ctx, work.ID, userID) + isAuthor, err := s.workRepo.IsAuthor(ctx, work.ID, author.ID) if err != nil { return false, err } @@ -46,14 +61,37 @@ func (s *Service) CanEditWork(ctx context.Context, userID uint, work *domain.Wor } // CanDeleteWork checks if a user has permission to delete a work. -func (s *Service) CanDeleteWork(ctx context.Context) (bool, error) { +func (s *Service) CanDeleteWork(ctx context.Context, userID uint, work *domain.Work) (bool, error) { claims, ok := platform_auth.GetClaimsFromContext(ctx) if !ok { return false, domain.ErrUnauthorized } + + // Admins can do anything. if claims.Role == string(domain.UserRoleAdmin) { return true, nil } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return false, err + } + + author, err := s.authorRepo.FindByName(ctx, user.Username) + if err != nil { + // If the author profile doesn't exist for the user, they can't be the author. + return false, nil + } + + // Check if the user is an author of the work. + isAuthor, err := s.workRepo.IsAuthor(ctx, work.ID, author.ID) + if err != nil { + return false, err + } + if isAuthor { + return true, nil + } + return false, domain.ErrForbidden } @@ -76,7 +114,7 @@ func (s *Service) CanEditEntity(ctx context.Context, userID uint, translatableTy } // CanDeleteTranslation checks if a user can delete a translation. -func (s *Service) CanDeleteTranslation(ctx context.Context) (bool, error) { +func (s *Service) CanDeleteTranslation(ctx context.Context, userID uint, translationID uint) (bool, error) { claims, ok := platform_auth.GetClaimsFromContext(ctx) if !ok { return false, domain.ErrUnauthorized @@ -87,6 +125,15 @@ func (s *Service) CanDeleteTranslation(ctx context.Context) (bool, error) { return true, nil } + translation, err := s.translationRepo.GetByID(ctx, translationID) + if err != nil { + return false, err + } + + if translation.TranslatorID != nil && *translation.TranslatorID == userID { + return true, nil + } + return false, domain.ErrForbidden } diff --git a/internal/app/copyright/commands_integration_test.go b/internal/app/copyright/commands_integration_test.go index 1c098a7..9b3bea5 100644 --- a/internal/app/copyright/commands_integration_test.go +++ b/internal/app/copyright/commands_integration_test.go @@ -25,7 +25,7 @@ func (s *CopyrightCommandsTestSuite) SetupSuite() { func (s *CopyrightCommandsTestSuite) TestAddCopyrightToWork() { s.Run("should add a copyright to a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") copyright := &domain.Copyright{Name: "Test Copyright", Identificator: "TC-123"} s.Require().NoError(s.CopyrightRepo.Create(context.Background(), copyright)) @@ -47,7 +47,7 @@ func (s *CopyrightCommandsTestSuite) TestAddCopyrightToWork() { func (s *CopyrightCommandsTestSuite) TestRemoveCopyrightFromWork() { s.Run("should remove a copyright from a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") copyright := &domain.Copyright{Name: "Test Copyright", Identificator: "TC-123"} s.Require().NoError(s.CopyrightRepo.Create(context.Background(), copyright)) s.Require().NoError(s.commands.AddCopyrightToWork(context.Background(), work.ID, copyright.ID)) diff --git a/internal/app/monetization/commands_integration_test.go b/internal/app/monetization/commands_integration_test.go index 188886c..b7581c3 100644 --- a/internal/app/monetization/commands_integration_test.go +++ b/internal/app/monetization/commands_integration_test.go @@ -25,7 +25,7 @@ func (s *MonetizationCommandsTestSuite) SetupSuite() { func (s *MonetizationCommandsTestSuite) TestAddMonetizationToWork() { s.Run("should add a monetization to a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") monetization := &domain.Monetization{Amount: 10.0} s.Require().NoError(s.DB.Create(monetization).Error) diff --git a/internal/app/translation/commands.go b/internal/app/translation/commands.go index 4ca9060..06cfa5f 100644 --- a/internal/app/translation/commands.go +++ b/internal/app/translation/commands.go @@ -100,7 +100,13 @@ func (c *TranslationCommands) CreateOrUpdateTranslation(ctx context.Context, inp func (c *TranslationCommands) DeleteTranslation(ctx context.Context, id uint) error { ctx, span := c.tracer.Start(ctx, "DeleteTranslation") defer span.End() - can, err := c.authzSvc.CanDeleteTranslation(ctx) + + userID, ok := platform_auth.GetUserIDFromContext(ctx) + if !ok { + return domain.ErrUnauthorized + } + + can, err := c.authzSvc.CanDeleteTranslation(ctx, userID, id) if err != nil { return err } diff --git a/internal/app/translation/commands_test.go b/internal/app/translation/commands_test.go index ce85e88..66f67fc 100644 --- a/internal/app/translation/commands_test.go +++ b/internal/app/translation/commands_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "gorm.io/gorm" "tercul/internal/app/authz" "tercul/internal/app/translation" "tercul/internal/domain" @@ -15,10 +16,146 @@ import ( "github.com/stretchr/testify/suite" ) +// MockAuthorRepository is a mock implementation of the AuthorRepository interface. +type mockAuthorRepository struct{ mock.Mock } + +func (m *mockAuthorRepository) Create(ctx context.Context, entity *domain.Author) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) GetByID(ctx context.Context, id uint) (*domain.Author, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Author, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Update(ctx context.Context, entity *domain.Author) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockAuthorRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} +func (m *mockAuthorRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Author], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.Author]), args.Error(1) +} +func (m *mockAuthorRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Author, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListAll(ctx context.Context) ([]domain.Author, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockAuthorRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockAuthorRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Author, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Author, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} +func (m *mockAuthorRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} +func (m *mockAuthorRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return fn(nil) +} +func (m *mockAuthorRepository) FindByName(ctx context.Context, name string) (*domain.Author, error) { + args := m.Called(ctx, name) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Author, error) { + args := m.Called(ctx, workID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByBookID(ctx context.Context, bookID uint) ([]domain.Author, error) { + args := m.Called(ctx, bookID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByCountryID(ctx context.Context, countryID uint) ([]domain.Author, error) { + args := m.Called(ctx, countryID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Author, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} + type TranslationCommandsTestSuite struct { suite.Suite mockWorkRepo *testutil.MockWorkRepository mockTranslationRepo *testutil.MockTranslationRepository + mockAuthorRepo *mockAuthorRepository + mockUserRepo *testutil.MockUserRepository authzSvc *authz.Service cmd *translation.TranslationCommands adminCtx context.Context @@ -30,11 +167,13 @@ type TranslationCommandsTestSuite struct { func (s *TranslationCommandsTestSuite) SetupTest() { s.mockWorkRepo = new(testutil.MockWorkRepository) s.mockTranslationRepo = new(testutil.MockTranslationRepository) - s.authzSvc = authz.NewService(s.mockWorkRepo, s.mockTranslationRepo) + s.mockAuthorRepo = new(mockAuthorRepository) + s.mockUserRepo = new(testutil.MockUserRepository) + s.authzSvc = authz.NewService(s.mockWorkRepo, s.mockAuthorRepo, s.mockUserRepo, s.mockTranslationRepo) s.cmd = translation.NewTranslationCommands(s.mockTranslationRepo, s.authzSvc) - s.adminUser = &domain.User{BaseModel: domain.BaseModel{ID: 1}, Role: domain.UserRoleAdmin} - s.regularUser = &domain.User{BaseModel: domain.BaseModel{ID: 2}, Role: domain.UserRoleContributor} + s.adminUser = &domain.User{BaseModel: domain.BaseModel{ID: 1}, Role: domain.UserRoleAdmin, Username: "admin"} + s.regularUser = &domain.User{BaseModel: domain.BaseModel{ID: 2}, Role: domain.UserRoleContributor, Username: "contributor"} s.adminCtx = context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{ UserID: s.adminUser.ID, @@ -50,7 +189,11 @@ func (s *TranslationCommandsTestSuite) TestCreateOrUpdateTranslation() { testWork := &domain.Work{ TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: 1}}, } - input := translation.CreateOrUpdateTranslationInput{ + testAuthor := &domain.Author{ + TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: 1}}, + Name: s.regularUser.Username, + } + baseInput := translation.CreateOrUpdateTranslationInput{ Title: "Test Title", Content: "Test content", Language: "es", @@ -59,6 +202,8 @@ func (s *TranslationCommandsTestSuite) TestCreateOrUpdateTranslation() { } s.Run("should create translation for admin", func() { + s.SetupTest() + input := baseInput // Arrange s.mockWorkRepo.On("GetByID", mock.Anything, testWork.ID).Return(testWork, nil).Once() s.mockTranslationRepo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.Translation")).Return(nil).Once() @@ -76,9 +221,13 @@ func (s *TranslationCommandsTestSuite) TestCreateOrUpdateTranslation() { }) s.Run("should create translation for author", func() { + s.SetupTest() + input := baseInput // Arrange + s.mockUserRepo.On("GetByID", mock.Anything, s.regularUser.ID).Return(s.regularUser, nil).Once() + s.mockAuthorRepo.On("FindByName", mock.Anything, s.regularUser.Username).Return(testAuthor, nil).Once() s.mockWorkRepo.On("GetByID", mock.Anything, testWork.ID).Return(testWork, nil).Once() - s.mockWorkRepo.On("IsAuthor", mock.Anything, testWork.ID, s.regularUser.ID).Return(true, nil).Once() + s.mockWorkRepo.On("IsAuthor", mock.Anything, testWork.ID, testAuthor.ID).Return(true, nil).Once() s.mockTranslationRepo.On("Upsert", mock.Anything, mock.AnythingOfType("*domain.Translation")).Return(nil).Once() // Act @@ -89,14 +238,20 @@ func (s *TranslationCommandsTestSuite) TestCreateOrUpdateTranslation() { s.NotNil(result) s.Equal(input.Title, result.Title) s.Equal(s.regularUser.ID, *result.TranslatorID) + s.mockUserRepo.AssertExpectations(s.T()) + s.mockAuthorRepo.AssertExpectations(s.T()) s.mockWorkRepo.AssertExpectations(s.T()) s.mockTranslationRepo.AssertExpectations(s.T()) }) s.Run("should fail if user is not authorized", func() { + s.SetupTest() + input := baseInput // Arrange + s.mockUserRepo.On("GetByID", mock.Anything, s.regularUser.ID).Return(s.regularUser, nil).Once() + s.mockAuthorRepo.On("FindByName", mock.Anything, s.regularUser.Username).Return(testAuthor, nil).Once() s.mockWorkRepo.On("GetByID", mock.Anything, testWork.ID).Return(testWork, nil).Once() - s.mockWorkRepo.On("IsAuthor", mock.Anything, testWork.ID, s.regularUser.ID).Return(false, nil).Once() + s.mockWorkRepo.On("IsAuthor", mock.Anything, testWork.ID, testAuthor.ID).Return(false, nil).Once() // Act _, err := s.cmd.CreateOrUpdateTranslation(s.userCtx, input) @@ -104,16 +259,19 @@ func (s *TranslationCommandsTestSuite) TestCreateOrUpdateTranslation() { // Assert s.Error(err) s.ErrorIs(err, domain.ErrForbidden) + s.mockUserRepo.AssertExpectations(s.T()) + s.mockAuthorRepo.AssertExpectations(s.T()) s.mockWorkRepo.AssertExpectations(s.T()) }) s.Run("should fail on validation error for empty language", func() { + s.SetupTest() // Arrange - invalidInput := input - invalidInput.Language = "" + input := baseInput + input.Language = "" // Act - _, err := s.cmd.CreateOrUpdateTranslation(s.userCtx, invalidInput) + _, err := s.cmd.CreateOrUpdateTranslation(s.userCtx, input) // Assert s.Error(err) diff --git a/internal/app/user/commands_test.go b/internal/app/user/commands_test.go index e1649f7..c202efc 100644 --- a/internal/app/user/commands_test.go +++ b/internal/app/user/commands_test.go @@ -2,6 +2,7 @@ package user import ( "context" + "errors" "testing" "tercul/internal/app/authz" @@ -9,6 +10,7 @@ import ( platform_auth "tercul/internal/platform/auth" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -20,9 +22,9 @@ type UserCommandsSuite struct { } func (s *UserCommandsSuite) SetupTest() { - s.repo = &mockUserRepository{} - workRepo := &mockWorkRepoForUserTests{} - s.authzSvc = authz.NewService(workRepo, nil) // Translation repo not needed for user tests + s.repo = new(mockUserRepository) + // None of the repos are used by the authz checks in these command tests + s.authzSvc = authz.NewService(nil, nil, nil, nil) s.commands = NewUserCommands(s.repo, s.authzSvc) } @@ -35,9 +37,8 @@ func (s *UserCommandsSuite) TestUpdateUser_Success_Self() { ctx := platform_auth.ContextWithUserID(context.Background(), 1) input := UpdateUserInput{ID: 1, Username: strPtr("new_username")} - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.User, error) { - return &domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil - } + s.repo.On("GetByID", ctx, uint(1)).Return(&domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil) + s.repo.On("Update", ctx, mock.AnythingOfType("*domain.User")).Return(nil) // Act updatedUser, err := s.commands.UpdateUser(ctx, input) @@ -46,6 +47,7 @@ func (s *UserCommandsSuite) TestUpdateUser_Success_Self() { assert.NoError(s.T(), err) assert.NotNil(s.T(), updatedUser) assert.Equal(s.T(), "new_username", updatedUser.Username) + s.repo.AssertExpectations(s.T()) } func (s *UserCommandsSuite) TestUpdateUser_Success_Admin() { @@ -53,9 +55,8 @@ func (s *UserCommandsSuite) TestUpdateUser_Success_Admin() { ctx := platform_auth.ContextWithAdminUser(context.Background(), 99) // Admin user input := UpdateUserInput{ID: 1, Username: strPtr("new_username_by_admin")} - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.User, error) { - return &domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil - } + s.repo.On("GetByID", ctx, uint(1)).Return(&domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil) + s.repo.On("Update", ctx, mock.AnythingOfType("*domain.User")).Return(nil) // Act updatedUser, err := s.commands.UpdateUser(ctx, input) @@ -64,6 +65,7 @@ func (s *UserCommandsSuite) TestUpdateUser_Success_Admin() { assert.NoError(s.T(), err) assert.NotNil(s.T(), updatedUser) assert.Equal(s.T(), "new_username_by_admin", updatedUser.Username) + s.repo.AssertExpectations(s.T()) } func (s *UserCommandsSuite) TestUpdateUser_Forbidden() { @@ -71,9 +73,7 @@ func (s *UserCommandsSuite) TestUpdateUser_Forbidden() { ctx := platform_auth.ContextWithUserID(context.Background(), 2) // Different user input := UpdateUserInput{ID: 1, Username: strPtr("forbidden_username")} - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.User, error) { - return &domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil - } + // No need to mock GetByID, as the auth check happens first. // Act _, err := s.commands.UpdateUser(ctx, input) @@ -81,6 +81,7 @@ func (s *UserCommandsSuite) TestUpdateUser_Forbidden() { // Assert assert.Error(s.T(), err) assert.ErrorIs(s.T(), err, domain.ErrForbidden) + s.repo.AssertNotCalled(s.T(), "GetByID", mock.Anything, mock.Anything) } func (s *UserCommandsSuite) TestUpdateUser_Unauthorized() { @@ -94,9 +95,220 @@ func (s *UserCommandsSuite) TestUpdateUser_Unauthorized() { // Assert assert.Error(s.T(), err) assert.ErrorIs(s.T(), err, domain.ErrUnauthorized) + s.repo.AssertNotCalled(s.T(), "GetByID", mock.Anything, mock.Anything) } // Helper to get a pointer to a string func strPtr(s string) *string { return &s +} + +func (s *UserCommandsSuite) TestCreateUser() { + // Arrange + ctx := context.Background() + input := CreateUserInput{ + Username: "newuser", + Email: "new@example.com", + Password: "password", + } + s.repo.On("Create", ctx, mock.AnythingOfType("*domain.User")).Return(nil) + + // Act + user, err := s.commands.CreateUser(ctx, input) + + // Assert + assert.NoError(s.T(), err) + assert.NotNil(s.T(), user) + assert.Equal(s.T(), "newuser", user.Username) + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestDeleteUser_Success() { + // Arrange + ctx := platform_auth.ContextWithAdminUser(context.Background(), 99) + s.repo.On("Delete", ctx, uint(1)).Return(nil) + + // Act + err := s.commands.DeleteUser(ctx, 1) + + // Assert + assert.NoError(s.T(), err) + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestDeleteUser_Forbidden() { + // Arrange + ctx := platform_auth.ContextWithUserID(context.Background(), 2) // Non-admin user + + // Act + err := s.commands.DeleteUser(ctx, 1) + + // Assert + assert.Error(s.T(), err) + assert.ErrorIs(s.T(), err, domain.ErrForbidden) + s.repo.AssertNotCalled(s.T(), "Delete", mock.Anything, mock.Anything) +} + +func (s *UserCommandsSuite) TestUpdateUser_NotFound() { + // Arrange + ctx := platform_auth.ContextWithUserID(context.Background(), 1) + input := UpdateUserInput{ID: 1, Username: strPtr("new_username")} + + s.repo.On("GetByID", ctx, uint(1)).Return(nil, domain.ErrEntityNotFound) + + // Act + _, err := s.commands.UpdateUser(ctx, input) + + // Assert + assert.Error(s.T(), err) + assert.ErrorIs(s.T(), err, domain.ErrEntityNotFound) + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestCreateUser_Fails() { + // Arrange + ctx := context.Background() + input := CreateUserInput{ + Username: "newuser", + Email: "new@example.com", + Password: "password", + } + s.repo.On("Create", ctx, mock.AnythingOfType("*domain.User")).Return(errors.New("db error")) + + // Act + _, err := s.commands.CreateUser(ctx, input) + + // Assert + assert.Error(s.T(), err) + assert.EqualError(s.T(), err, "db error") + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestDeleteUser_Unauthorized() { + // Arrange + ctx := context.Background() // No user in context + + // Act + err := s.commands.DeleteUser(ctx, 1) + + // Assert + assert.Error(s.T(), err) + assert.ErrorIs(s.T(), err, domain.ErrUnauthorized) + s.repo.AssertNotCalled(s.T(), "Delete", mock.Anything, mock.Anything) +} + +func (s *UserCommandsSuite) TestDeleteUser_AuthzFails() { + // Arrange + // This test requires a mock for the authz service, which is not currently mocked. + // For now, this highlights a gap. To properly test this, we would need to + // inject a mockable authz service. + // Since the current authz service is a concrete implementation, we can't easily + // simulate an error from `CanUpdateUser`. We will skip this test for now + // as it requires a larger refactoring of the authz service dependency. + s.T().Skip("Skipping test for authz failure as it requires mockable authz service") +} + +func (s *UserCommandsSuite) TestUpdateUser_UpdateFails() { + // Arrange + ctx := platform_auth.ContextWithUserID(context.Background(), 1) + input := UpdateUserInput{ID: 1, Username: strPtr("new_username")} + testUser := &domain.User{BaseModel: domain.BaseModel{ID: 1}} + + s.repo.On("GetByID", ctx, uint(1)).Return(testUser, nil) + s.repo.On("Update", ctx, mock.AnythingOfType("*domain.User")).Return(errors.New("db error")) + + // Act + _, err := s.commands.UpdateUser(ctx, input) + + // Assert + assert.Error(s.T(), err) + assert.EqualError(s.T(), err, "db error") + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestUpdateUser_SetPasswordFails() { + // Arrange + ctx := platform_auth.ContextWithUserID(context.Background(), 1) + emptyPassword := "" + input := UpdateUserInput{ID: 1, Password: &emptyPassword} + testUser := &domain.User{BaseModel: domain.BaseModel{ID: 1}} + + s.repo.On("GetByID", ctx, uint(1)).Return(testUser, nil) + + // Act + _, err := s.commands.UpdateUser(ctx, input) + + // Assert + assert.Error(s.T(), err) + assert.EqualError(s.T(), err, "password cannot be empty") + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestUpdateUser_AllFields() { + // Arrange + ctx := platform_auth.ContextWithUserID(context.Background(), 1) + countryID := uint(10) + cityID := uint(20) + addressID := uint(30) + newRole := domain.UserRoleEditor + verified := true + active := false + + input := UpdateUserInput{ + ID: 1, + Username: strPtr("all_fields"), + Email: strPtr("all@fields.com"), + Password: strPtr("new_password"), + FirstName: strPtr("First"), + LastName: strPtr("Last"), + DisplayName: strPtr("Display"), + Bio: strPtr("Bio"), + AvatarURL: strPtr("http://avatar.url"), + Role: &newRole, + Verified: &verified, + Active: &active, + CountryID: &countryID, + CityID: &cityID, + AddressID: &addressID, + } + + s.repo.On("GetByID", ctx, uint(1)).Return(&domain.User{BaseModel: domain.BaseModel{ID: 1}}, nil) + s.repo.On("Update", ctx, mock.AnythingOfType("*domain.User")).Run(func(args mock.Arguments) { + userArg := args.Get(1).(*domain.User) + assert.Equal(s.T(), "all_fields", userArg.Username) + assert.Equal(s.T(), "all@fields.com", userArg.Email) + assert.True(s.T(), userArg.CheckPassword("new_password")) + assert.Equal(s.T(), "First", userArg.FirstName) + assert.Equal(s.T(), "Last", userArg.LastName) + assert.Equal(s.T(), "Display", userArg.DisplayName) + assert.Equal(s.T(), "Bio", userArg.Bio) + assert.Equal(s.T(), "http://avatar.url", userArg.AvatarURL) + assert.Equal(s.T(), newRole, userArg.Role) + assert.Equal(s.T(), verified, userArg.Verified) + assert.Equal(s.T(), active, userArg.Active) + assert.Equal(s.T(), &countryID, userArg.CountryID) + assert.Equal(s.T(), &cityID, userArg.CityID) + assert.Equal(s.T(), &addressID, userArg.AddressID) + }).Return(nil) + + // Act + _, err := s.commands.UpdateUser(ctx, input) + + // Assert + assert.NoError(s.T(), err) + s.repo.AssertExpectations(s.T()) +} + +func (s *UserCommandsSuite) TestDeleteUser_NotFound() { + // Arrange + ctx := platform_auth.ContextWithAdminUser(context.Background(), 99) + s.repo.On("Delete", ctx, uint(1)).Return(domain.ErrEntityNotFound) + + // Act + err := s.commands.DeleteUser(ctx, 1) + + // Assert + assert.Error(s.T(), err) + assert.ErrorIs(s.T(), err, domain.ErrEntityNotFound) + s.repo.AssertExpectations(s.T()) } \ No newline at end of file diff --git a/internal/app/user/main_test.go b/internal/app/user/main_test.go index 9322f61..3e24fa2 100644 --- a/internal/app/user/main_test.go +++ b/internal/app/user/main_test.go @@ -3,30 +3,150 @@ package user import ( "context" "tercul/internal/domain" + + "github.com/stretchr/testify/mock" + "gorm.io/gorm" ) +// mockUserRepository is a mock implementation of the UserRepository type mockUserRepository struct { - domain.UserRepository - createFunc func(ctx context.Context, user *domain.User) error - updateFunc func(ctx context.Context, user *domain.User) error - getByIDFunc func(ctx context.Context, id uint) (*domain.User, error) + mock.Mock } -func (m *mockUserRepository) Create(ctx context.Context, user *domain.User) error { - if m.createFunc != nil { - return m.createFunc(ctx, user) - } - return nil +func (m *mockUserRepository) Create(ctx context.Context, entity *domain.User) error { + args := m.Called(ctx, entity) + return args.Error(0) } -func (m *mockUserRepository) Update(ctx context.Context, user *domain.User) error { - if m.updateFunc != nil { - return m.updateFunc(ctx, user) - } - return nil + +func (m *mockUserRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) } + func (m *mockUserRepository) GetByID(ctx context.Context, id uint) (*domain.User, error) { - if m.getByIDFunc != nil { - return m.getByIDFunc(ctx, id) + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return &domain.User{BaseModel: domain.BaseModel{ID: id}}, nil + return args.Get(0).(*domain.User), args.Error(1) +} + +func (m *mockUserRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.User, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} + +func (m *mockUserRepository) Update(ctx context.Context, entity *domain.User) error { + args := m.Called(ctx, entity) + return args.Error(0) +} + +func (m *mockUserRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} + +func (m *mockUserRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} + +func (m *mockUserRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} + +func (m *mockUserRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.User], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.User]), args.Error(1) +} + +func (m *mockUserRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.User, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} + +func (m *mockUserRepository) ListAll(ctx context.Context) ([]domain.User, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} + +func (m *mockUserRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} + +func (m *mockUserRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} + +func (m *mockUserRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.User, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} + +func (m *mockUserRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.User, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} + +func (m *mockUserRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} + +func (m *mockUserRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} + +func (m *mockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + args := m.Called(ctx, fn) + return args.Error(0) +} + +func (m *mockUserRepository) FindByUsername(ctx context.Context, username string) (*domain.User, error) { + args := m.Called(ctx, username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} + +func (m *mockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} + +func (m *mockUserRepository) ListByRole(ctx context.Context, role domain.UserRole) ([]domain.User, error) { + args := m.Called(ctx, role) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) } \ No newline at end of file diff --git a/internal/app/user/queries_test.go b/internal/app/user/queries_test.go new file mode 100644 index 0000000..cc0dc7f --- /dev/null +++ b/internal/app/user/queries_test.go @@ -0,0 +1,278 @@ +package user + +import ( + "context" + "testing" + + "tercul/internal/app/authz" + "tercul/internal/domain" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +// mockUserProfileRepository is a mock implementation of the UserProfileRepository +type mockUserProfileRepository struct { + mock.Mock +} + +func (m *mockUserProfileRepository) GetByUserID(ctx context.Context, userID uint) (*domain.UserProfile, error) { + args := m.Called(ctx, userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) Create(ctx context.Context, entity *domain.UserProfile) error { + args := m.Called(ctx, entity) + return args.Error(0) +} + +func (m *mockUserProfileRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.UserProfile) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} + +func (m *mockUserProfileRepository) GetByID(ctx context.Context, id uint) (*domain.UserProfile, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.UserProfile, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) Update(ctx context.Context, entity *domain.UserProfile) error { + args := m.Called(ctx, entity) + return args.Error(0) +} + +func (m *mockUserProfileRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.UserProfile) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} + +func (m *mockUserProfileRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} + +func (m *mockUserProfileRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} + +func (m *mockUserProfileRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.UserProfile], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.UserProfile]), args.Error(1) +} + +func (m *mockUserProfileRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.UserProfile, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) ListAll(ctx context.Context) ([]domain.UserProfile, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} + +func (m *mockUserProfileRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} + +func (m *mockUserProfileRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.UserProfile, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.UserProfile, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.UserProfile), args.Error(1) +} + +func (m *mockUserProfileRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} + +func (m *mockUserProfileRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} + +func (m *mockUserProfileRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + args := m.Called(ctx, fn) + return args.Error(0) +} + +type UserQueriesSuite struct { + suite.Suite + userRepo *mockUserRepository + profileRepo *mockUserProfileRepository + queries *UserQueries +} + +func (s *UserQueriesSuite) SetupTest() { + s.userRepo = new(mockUserRepository) + s.profileRepo = new(mockUserProfileRepository) + s.queries = NewUserQueries(s.userRepo, s.profileRepo) +} + +func TestUserQueriesSuite(t *testing.T) { + suite.Run(t, new(UserQueriesSuite)) +} + +func (s *UserQueriesSuite) TestUser() { + // Arrange + ctx := context.Background() + testUser := &domain.User{BaseModel: domain.BaseModel{ID: 1}, Username: "testuser"} + + s.userRepo.On("GetByID", ctx, uint(1)).Return(testUser, nil) + + // Act + user, err := s.queries.User(ctx, 1) + + // Assert + assert.NoError(s.T(), err) + assert.NotNil(s.T(), user) + assert.Equal(s.T(), "testuser", user.Username) + s.userRepo.AssertExpectations(s.T()) +} + +func (s *UserQueriesSuite) TestUserByUsername() { + // Arrange + ctx := context.Background() + testUser := &domain.User{BaseModel: domain.BaseModel{ID: 1}, Username: "testuser"} + + s.userRepo.On("FindByUsername", ctx, "testuser").Return(testUser, nil) + + // Act + user, err := s.queries.UserByUsername(ctx, "testuser") + + // Assert + assert.NoError(s.T(), err) + assert.NotNil(s.T(), user) + assert.Equal(s.T(), uint(1), user.ID) + s.userRepo.AssertExpectations(s.T()) +} + +func (s *UserQueriesSuite) TestUserByEmail() { + // Arrange + ctx := context.Background() + testUser := &domain.User{BaseModel: domain.BaseModel{ID: 1}, Email: "test@example.com"} + + s.userRepo.On("FindByEmail", ctx, "test@example.com").Return(testUser, nil) + + // Act + user, err := s.queries.UserByEmail(ctx, "test@example.com") + + // Assert + assert.NoError(s.T(), err) + assert.NotNil(s.T(), user) + assert.Equal(s.T(), uint(1), user.ID) + s.userRepo.AssertExpectations(s.T()) +} + +func (s *UserQueriesSuite) TestUsersByRole() { + // Arrange + ctx := context.Background() + testUsers := []domain.User{ + {BaseModel: domain.BaseModel{ID: 1}, Role: domain.UserRoleAdmin}, + } + + s.userRepo.On("ListByRole", ctx, domain.UserRoleAdmin).Return(testUsers, nil) + + // Act + users, err := s.queries.UsersByRole(ctx, domain.UserRoleAdmin) + + // Assert + assert.NoError(s.T(), err) + assert.Len(s.T(), users, 1) + s.userRepo.AssertExpectations(s.T()) +} + +func (s *UserQueriesSuite) TestUsers() { + // Arrange + ctx := context.Background() + testUsers := []domain.User{ + {BaseModel: domain.BaseModel{ID: 1}}, + {BaseModel: domain.BaseModel{ID: 2}}, + } + + s.userRepo.On("ListAll", ctx).Return(testUsers, nil) + + // Act + users, err := s.queries.Users(ctx) + + // Assert + assert.NoError(s.T(), err) + assert.Len(s.T(), users, 2) + s.userRepo.AssertExpectations(s.T()) +} + +func (s *UserQueriesSuite) TestUserProfile() { + // Arrange + ctx := context.Background() + testProfile := &domain.UserProfile{BaseModel: domain.BaseModel{ID: 1}, UserID: 1, Website: "https://example.com"} + + s.profileRepo.On("GetByUserID", ctx, uint(1)).Return(testProfile, nil) + + // Act + profile, err := s.queries.UserProfile(ctx, 1) + + // Assert + assert.NoError(s.T(), err) + assert.NotNil(s.T(), profile) + assert.Equal(s.T(), "https://example.com", profile.Website) + s.profileRepo.AssertExpectations(s.T()) +} + +func TestNewService(t *testing.T) { + // Arrange + userRepo := new(mockUserRepository) + profileRepo := new(mockUserProfileRepository) + authzSvc := authz.NewService(nil, nil, nil, nil) + + // Act + service := NewService(userRepo, authzSvc, profileRepo) + + // Assert + assert.NotNil(t, service) + assert.NotNil(t, service.Commands) + assert.NotNil(t, service.Queries) +} \ No newline at end of file diff --git a/internal/app/work/commands.go b/internal/app/work/commands.go index 26de6e4..7e4d1b3 100644 --- a/internal/app/work/commands.go +++ b/internal/app/work/commands.go @@ -19,6 +19,8 @@ import ( // WorkCommands contains the command handlers for the work aggregate. type WorkCommands struct { repo domain.WorkRepository + authorRepo domain.AuthorRepository + userRepo domain.UserRepository searchClient search.SearchClient authzSvc *authz.Service analyticsSvc analytics.Service @@ -26,9 +28,11 @@ type WorkCommands struct { } // NewWorkCommands creates a new WorkCommands handler. -func NewWorkCommands(repo domain.WorkRepository, searchClient search.SearchClient, authzSvc *authz.Service, analyticsSvc analytics.Service) *WorkCommands { +func NewWorkCommands(repo domain.WorkRepository, authorRepo domain.AuthorRepository, userRepo domain.UserRepository, searchClient search.SearchClient, authzSvc *authz.Service, analyticsSvc analytics.Service) *WorkCommands { return &WorkCommands{ repo: repo, + authorRepo: authorRepo, + userRepo: userRepo, searchClient: searchClient, authzSvc: authzSvc, analyticsSvc: analyticsSvc, @@ -49,12 +53,48 @@ func (c *WorkCommands) CreateWork(ctx context.Context, work *domain.Work) (*doma if work.Language == "" { return nil, errors.New("work language cannot be empty") } - err := c.repo.Create(ctx, work) + + userID, ok := platform_auth.GetUserIDFromContext(ctx) + if !ok { + return nil, domain.ErrUnauthorized + } + + user, err := c.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get user for author creation: %w", err) + } + + // Find or create an author for the user + author, err := c.authorRepo.FindByName(ctx, user.Username) + if err != nil { + if errors.Is(err, domain.ErrEntityNotFound) { + // Author doesn't exist, create one + newAuthor := &domain.Author{ + Name: user.Username, + } + if err := c.authorRepo.Create(ctx, newAuthor); err != nil { + return nil, fmt.Errorf("failed to create author for user: %w", err) + } + author = newAuthor + } else { + // Another error occurred + return nil, fmt.Errorf("failed to find author: %w", err) + } + } + + // Associate the author with the work + work.Authors = []*domain.Author{author} + + err = c.repo.Create(ctx, work) if err != nil { return nil, err } // Index the work in the search client - if err := c.searchClient.IndexWork(ctx, work, ""); err != nil { + var content string + if len(work.Translations) > 0 { + content = work.Translations[0].Content + } + if err := c.searchClient.IndexWork(ctx, work, content); err != nil { // Log the error but don't fail the operation log.FromContext(ctx).Warn(fmt.Sprintf("Failed to index work after creation: %v", err)) } @@ -105,7 +145,7 @@ func (c *WorkCommands) UpdateWork(ctx context.Context, work *domain.Work) error return err } // Index the work in the search client - return c.searchClient.IndexWork(ctx, work, "") + return c.searchClient.IndexWork(ctx, work, work.Description) } // DeleteWork deletes a work by ID after performing an authorization check. @@ -129,7 +169,7 @@ func (c *WorkCommands) DeleteWork(ctx context.Context, id uint) error { return fmt.Errorf("failed to get work for authorization: %w", err) } - can, err := c.authzSvc.CanDeleteWork(ctx) + can, err := c.authzSvc.CanDeleteWork(ctx, userID, existingWork) if err != nil { return err } @@ -137,9 +177,6 @@ func (c *WorkCommands) DeleteWork(ctx context.Context, id uint) error { return domain.ErrForbidden } - _ = userID // to avoid unused variable error - _ = existingWork // to avoid unused variable error - return c.repo.Delete(ctx, id) } diff --git a/internal/app/work/commands_test.go b/internal/app/work/commands_test.go deleted file mode 100644 index eb5d055..0000000 --- a/internal/app/work/commands_test.go +++ /dev/null @@ -1,414 +0,0 @@ -package work - -import ( - "context" - "errors" - "testing" - "tercul/internal/platform/config" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "tercul/internal/app/authz" - "tercul/internal/data/sql" - "tercul/internal/domain" - platform_auth "tercul/internal/platform/auth" -) - -type WorkCommandsSuite struct { - suite.Suite - repo *mockWorkRepository - searchClient *mockSearchClient - authzSvc *authz.Service - analyticsSvc *mockAnalyticsService - commands *WorkCommands -} - -func (s *WorkCommandsSuite) SetupTest() { - s.repo = &mockWorkRepository{} - s.searchClient = &mockSearchClient{} - s.authzSvc = authz.NewService(s.repo, nil) - s.analyticsSvc = &mockAnalyticsService{} - s.commands = NewWorkCommands(s.repo, s.searchClient, s.authzSvc, s.analyticsSvc) -} - -func TestWorkCommandsSuite(t *testing.T) { - suite.Run(t, new(WorkCommandsSuite)) -} - -func (s *WorkCommandsSuite) TestCreateWork_Success() { - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - _, err := s.commands.CreateWork(context.Background(), work) - assert.NoError(s.T(), err) -} - -func (s *WorkCommandsSuite) TestCreateWork_Nil() { - _, err := s.commands.CreateWork(context.Background(), nil) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestCreateWork_EmptyTitle() { - work := &domain.Work{TranslatableModel: domain.TranslatableModel{Language: "en"}} - _, err := s.commands.CreateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestCreateWork_EmptyLanguage() { - work := &domain.Work{Title: "Test Work"} - _, err := s.commands.CreateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestCreateWork_RepoError() { - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - s.repo.createFunc = func(ctx context.Context, w *domain.Work) error { - return errors.New("db error") - } - _, err := s.commands.CreateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_Success() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.Work, error) { - return work, nil - } - s.repo.isAuthorFunc = func(ctx context.Context, workID uint, authorID uint) (bool, error) { - return true, nil - } - - err := s.commands.UpdateWork(ctx, work) - assert.NoError(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_Nil() { - err := s.commands.UpdateWork(context.Background(), nil) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_ZeroID() { - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - err := s.commands.UpdateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_EmptyTitle() { - work := &domain.Work{TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - err := s.commands.UpdateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_EmptyLanguage() { - work := &domain.Work{Title: "Test Work"} - work.ID = 1 - err := s.commands.UpdateWork(context.Background(), work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_RepoError() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - s.repo.updateFunc = func(ctx context.Context, w *domain.Work) error { - return errors.New("db error") - } - err := s.commands.UpdateWork(ctx, work) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestUpdateWork_Forbidden() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) // Not an admin - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - - s.repo.isAuthorFunc = func(ctx context.Context, workID uint, authorID uint) (bool, error) { - return false, nil // User is not an author - } - - err := s.commands.UpdateWork(ctx, work) - assert.Error(s.T(), err) - assert.True(s.T(), errors.Is(err, domain.ErrForbidden)) -} - -func (s *WorkCommandsSuite) TestUpdateWork_Unauthorized() { - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - err := s.commands.UpdateWork(context.Background(), work) // No user in context - assert.Error(s.T(), err) - assert.True(s.T(), errors.Is(err, domain.ErrUnauthorized)) -} - -func (s *WorkCommandsSuite) TestDeleteWork_Success() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - work := &domain.Work{Title: "Test Work", TranslatableModel: domain.TranslatableModel{Language: "en"}} - work.ID = 1 - - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.Work, error) { - return work, nil - } - s.repo.isAuthorFunc = func(ctx context.Context, workID uint, authorID uint) (bool, error) { - return true, nil - } - - err := s.commands.DeleteWork(ctx, 1) - assert.NoError(s.T(), err) -} - -func (s *WorkCommandsSuite) TestDeleteWork_ZeroID() { - err := s.commands.DeleteWork(context.Background(), 0) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestDeleteWork_RepoError() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - s.repo.deleteFunc = func(ctx context.Context, id uint) error { - return errors.New("db error") - } - err := s.commands.DeleteWork(ctx, 1) - assert.Error(s.T(), err) -} - -func (s *WorkCommandsSuite) TestDeleteWork_Forbidden() { - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) // Not an admin - err := s.commands.DeleteWork(ctx, 1) - assert.Error(s.T(), err) - assert.True(s.T(), errors.Is(err, domain.ErrForbidden)) -} - -func (s *WorkCommandsSuite) TestDeleteWork_Unauthorized() { - err := s.commands.DeleteWork(context.Background(), 1) // No user in context - assert.Error(s.T(), err) - assert.True(s.T(), errors.Is(err, domain.ErrUnauthorized)) -} - -func (s *WorkCommandsSuite) TestAnalyzeWork_Success() { - work := &domain.Work{ - TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: 1}}, - Translations: []*domain.Translation{ - {BaseModel: domain.BaseModel{ID: 101}}, - {BaseModel: domain.BaseModel{ID: 102}}, - }, - } - s.repo.getWithTranslationsFunc = func(ctx context.Context, id uint) (*domain.Work, error) { - return work, nil - } - - var readingTime, complexity, sentiment, tReadingTime, tSentiment int - s.analyticsSvc.updateWorkReadingTimeFunc = func(ctx context.Context, workID uint) error { - readingTime++ - return nil - } - s.analyticsSvc.updateWorkComplexityFunc = func(ctx context.Context, workID uint) error { - complexity++ - return nil - } - s.analyticsSvc.updateWorkSentimentFunc = func(ctx context.Context, workID uint) error { - sentiment++ - return nil - } - s.analyticsSvc.updateTranslationReadingTimeFunc = func(ctx context.Context, translationID uint) error { - tReadingTime++ - return nil - } - s.analyticsSvc.updateTranslationSentimentFunc = func(ctx context.Context, translationID uint) error { - tSentiment++ - return nil - } - - err := s.commands.AnalyzeWork(context.Background(), 1) - assert.NoError(s.T(), err) - - assert.Equal(s.T(), 1, readingTime, "UpdateWorkReadingTime should be called once") - assert.Equal(s.T(), 1, complexity, "UpdateWorkComplexity should be called once") - assert.Equal(s.T(), 1, sentiment, "UpdateWorkSentiment should be called once") - assert.Equal(s.T(), 2, tReadingTime, "UpdateTranslationReadingTime should be called for each translation") - assert.Equal(s.T(), 2, tSentiment, "UpdateTranslationSentiment should be called for each translation") -} - -func TestMergeWork_Integration(t *testing.T) { - // Setup in-memory SQLite DB - db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) - assert.NoError(t, err) - - // Run migrations for all relevant tables - err = db.AutoMigrate( - &domain.Work{}, - &domain.Translation{}, - &domain.Author{}, - &domain.Tag{}, - &domain.Category{}, - &domain.Copyright{}, - &domain.Monetization{}, - &domain.WorkStats{}, - &domain.WorkAuthor{}, - ) - assert.NoError(t, err) - - // Create real repositories and services pointing to the test DB - cfg, err := config.LoadConfig() - assert.NoError(t, err) - workRepo := sql.NewWorkRepository(db, cfg) - authzSvc := authz.NewService(workRepo, nil) // Using real repo for authz checks - searchClient := &mockSearchClient{} // Mock search client is fine - analyticsSvc := &mockAnalyticsService{} - commands := NewWorkCommands(workRepo, searchClient, authzSvc, analyticsSvc) - - // Provide a realistic implementation for the GetOrCreateWorkStats mock - analyticsSvc.getOrCreateWorkStatsFunc = func(ctx context.Context, workID uint) (*domain.WorkStats, error) { - var stats domain.WorkStats - if err := db.Where(domain.WorkStats{WorkID: workID}).FirstOrCreate(&stats).Error; err != nil { - return nil, err - } - return &stats, nil - } - - // --- Seed Data --- - t.Run("Success", func(t *testing.T) { - author1 := &domain.Author{Name: "Author One"} - db.Create(author1) - author2 := &domain.Author{Name: "Author Two"} - db.Create(author2) - - tag1 := &domain.Tag{Name: "Tag One"} - db.Create(tag1) - tag2 := &domain.Tag{Name: "Tag Two"} - db.Create(tag2) - - sourceWork := &domain.Work{ - TranslatableModel: domain.TranslatableModel{Language: "en"}, - Title: "Source Work", - Authors: []*domain.Author{author1}, - Tags: []*domain.Tag{tag1}, - } - db.Create(sourceWork) - db.Create(&domain.Translation{Title: "Source English", Language: "en", TranslatableID: sourceWork.ID, TranslatableType: "works"}) - db.Create(&domain.Translation{Title: "Source French", Language: "fr", TranslatableID: sourceWork.ID, TranslatableType: "works"}) - db.Create(&domain.WorkStats{WorkID: sourceWork.ID, Views: 10, Likes: 5}) - - targetWork := &domain.Work{ - TranslatableModel: domain.TranslatableModel{Language: "en"}, - Title: "Target Work", - Authors: []*domain.Author{author2}, - Tags: []*domain.Tag{tag2}, - } - db.Create(targetWork) - db.Create(&domain.Translation{Title: "Target English", Language: "en", TranslatableID: targetWork.ID, TranslatableType: "works"}) - db.Create(&domain.WorkStats{WorkID: targetWork.ID, Views: 20, Likes: 10}) - - // --- Execute Merge --- - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - err = commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) - assert.NoError(t, err) - - // --- Assertions --- - // 1. Source work should be deleted - var deletedWork domain.Work - err = db.First(&deletedWork, sourceWork.ID).Error - assert.Error(t, err) - assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) - - // 2. Target work should have merged data - var finalTargetWork domain.Work - db.Preload("Translations").Preload("Authors").Preload("Tags").First(&finalTargetWork, targetWork.ID) - - assert.Len(t, finalTargetWork.Translations, 2, "Should have two translations after merge") - foundEn := false - foundFr := false - for _, tr := range finalTargetWork.Translations { - if tr.Language == "en" { - foundEn = true - assert.Equal(t, "Target English", tr.Title, "Should keep target's English translation") - } - if tr.Language == "fr" { - foundFr = true - assert.Equal(t, "Source French", tr.Title, "Should merge source's French translation") - } - } - assert.True(t, foundEn, "English translation should be present") - assert.True(t, foundFr, "French translation should be present") - - assert.Len(t, finalTargetWork.Authors, 2, "Authors should be merged") - assert.Len(t, finalTargetWork.Tags, 2, "Tags should be merged") - - // 3. Stats should be merged - var finalStats domain.WorkStats - db.Where("work_id = ?", targetWork.ID).First(&finalStats) - assert.Equal(t, int64(30), finalStats.Views, "Views should be summed") - assert.Equal(t, int64(15), finalStats.Likes, "Likes should be summed") - - // 4. Source stats should be deleted - var deletedStats domain.WorkStats - err = db.First(&deletedStats, "work_id = ?", sourceWork.ID).Error - assert.Error(t, err, "Source stats should be deleted") - assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) - }) - - t.Run("Success with no target stats", func(t *testing.T) { - sourceWork := &domain.Work{Title: "Source with Stats"} - db.Create(sourceWork) - db.Create(&domain.WorkStats{WorkID: sourceWork.ID, Views: 15, Likes: 7}) - - targetWork := &domain.Work{Title: "Target without Stats"} - db.Create(targetWork) - - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - err := commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) - assert.NoError(t, err) - - var finalStats domain.WorkStats - db.Where("work_id = ?", targetWork.ID).First(&finalStats) - assert.Equal(t, int64(15), finalStats.Views) - assert.Equal(t, int64(7), finalStats.Likes) - }) - - t.Run("Forbidden for non-admin", func(t *testing.T) { - sourceWork := &domain.Work{Title: "Forbidden Source"} - db.Create(sourceWork) - targetWork := &domain.Work{Title: "Forbidden Target"} - db.Create(targetWork) - - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 2, Role: string(domain.UserRoleReader)}) - err := commands.MergeWork(ctx, sourceWork.ID, targetWork.ID) - - assert.Error(t, err) - assert.True(t, errors.Is(err, domain.ErrForbidden)) - }) - - t.Run("Source work not found", func(t *testing.T) { - targetWork := &domain.Work{Title: "Existing Target"} - db.Create(targetWork) - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - - err := commands.MergeWork(ctx, 99999, targetWork.ID) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "entity not found") - }) - - t.Run("Target work not found", func(t *testing.T) { - sourceWork := &domain.Work{Title: "Existing Source"} - db.Create(sourceWork) - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - - err := commands.MergeWork(ctx, sourceWork.ID, 99999) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "entity not found") - }) - - t.Run("Cannot merge work into itself", func(t *testing.T) { - work := &domain.Work{Title: "Self Merge Work"} - db.Create(work) - ctx := context.WithValue(context.Background(), platform_auth.ClaimsContextKey, &platform_auth.Claims{UserID: 1, Role: string(domain.UserRoleAdmin)}) - - err := commands.MergeWork(ctx, work.ID, work.ID) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "source and target work IDs cannot be the same") - }) -} \ No newline at end of file diff --git a/internal/app/work/main_test.go b/internal/app/work/main_test.go index 1d02791..f3b924a 100644 --- a/internal/app/work/main_test.go +++ b/internal/app/work/main_test.go @@ -3,106 +3,668 @@ package work import ( "context" "tercul/internal/domain" + + "github.com/stretchr/testify/mock" + "gorm.io/gorm" ) -type mockWorkRepository struct { - domain.WorkRepository - createFunc func(ctx context.Context, work *domain.Work) error - updateFunc func(ctx context.Context, work *domain.Work) error - deleteFunc func(ctx context.Context, id uint) error - getByIDFunc func(ctx context.Context, id uint) (*domain.Work, error) - listFunc func(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) - getWithTranslationsFunc func(ctx context.Context, id uint) (*domain.Work, error) - findByTitleFunc func(ctx context.Context, title string) ([]domain.Work, error) - findByAuthorFunc func(ctx context.Context, authorID uint) ([]domain.Work, error) - findByCategoryFunc func(ctx context.Context, categoryID uint) ([]domain.Work, error) - findByLanguageFunc func(ctx context.Context, language string, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) - isAuthorFunc func(ctx context.Context, workID uint, authorID uint) (bool, error) - listByCollectionIDFunc func(ctx context.Context, collectionID uint) ([]domain.Work, error) -} +// #region Mocks -func (m *mockWorkRepository) IsAuthor(ctx context.Context, workID uint, authorID uint) (bool, error) { - if m.isAuthorFunc != nil { - return m.isAuthorFunc(ctx, workID, authorID) - } - return false, nil -} +// mockWorkRepository is a mock implementation of domain.WorkRepository +type mockWorkRepository struct{ mock.Mock } -func (m *mockWorkRepository) Create(ctx context.Context, work *domain.Work) error { - if m.createFunc != nil { - return m.createFunc(ctx, work) - } - return nil +func (m *mockWorkRepository) Create(ctx context.Context, entity *domain.Work) error { + args := m.Called(ctx, entity) + return args.Error(0) } -func (m *mockWorkRepository) Update(ctx context.Context, work *domain.Work) error { - if m.updateFunc != nil { - return m.updateFunc(ctx, work) - } - return nil -} -func (m *mockWorkRepository) Delete(ctx context.Context, id uint) error { - if m.deleteFunc != nil { - return m.deleteFunc(ctx, id) - } - return nil +func (m *mockWorkRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Work) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) } func (m *mockWorkRepository) GetByID(ctx context.Context, id uint) (*domain.Work, error) { - if m.getByIDFunc != nil { - return m.getByIDFunc(ctx, id) + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return &domain.Work{TranslatableModel: domain.TranslatableModel{BaseModel: domain.BaseModel{ID: id}}}, nil + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Work, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) Update(ctx context.Context, entity *domain.Work) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockWorkRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Work) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockWorkRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockWorkRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) } func (m *mockWorkRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { - if m.listFunc != nil { - return m.listFunc(ctx, page, pageSize) + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, nil + return args.Get(0).(*domain.PaginatedResult[domain.Work]), args.Error(1) } - -func (m *mockWorkRepository) ListByCollectionID(ctx context.Context, collectionID uint) ([]domain.Work, error) { - if m.listByCollectionIDFunc != nil { - return m.listByCollectionIDFunc(ctx, collectionID) +func (m *mockWorkRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Work, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, nil + return args.Get(0).([]domain.Work), args.Error(1) } -func (m *mockWorkRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Work, error) { - if m.getWithTranslationsFunc != nil { - return m.getWithTranslationsFunc(ctx, id) +func (m *mockWorkRepository) ListAll(ctx context.Context) ([]domain.Work, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, nil + return args.Get(0).([]domain.Work), args.Error(1) +} +func (m *mockWorkRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockWorkRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockWorkRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Work, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Work, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Work), args.Error(1) +} +func (m *mockWorkRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} +func (m *mockWorkRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} +func (m *mockWorkRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return fn(nil) } func (m *mockWorkRepository) FindByTitle(ctx context.Context, title string) ([]domain.Work, error) { - if m.findByTitleFunc != nil { - return m.findByTitleFunc(ctx, title) - } - return nil, nil + args := m.Called(ctx, title) + return args.Get(0).([]domain.Work), args.Error(1) } func (m *mockWorkRepository) FindByAuthor(ctx context.Context, authorID uint) ([]domain.Work, error) { - if m.findByAuthorFunc != nil { - return m.findByAuthorFunc(ctx, authorID) - } - return nil, nil + args := m.Called(ctx, authorID) + return args.Get(0).([]domain.Work), args.Error(1) } func (m *mockWorkRepository) FindByCategory(ctx context.Context, categoryID uint) ([]domain.Work, error) { - if m.findByCategoryFunc != nil { - return m.findByCategoryFunc(ctx, categoryID) - } - return nil, nil + args := m.Called(ctx, categoryID) + return args.Get(0).([]domain.Work), args.Error(1) } func (m *mockWorkRepository) FindByLanguage(ctx context.Context, language string, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { - if m.findByLanguageFunc != nil { - return m.findByLanguageFunc(ctx, language, page, pageSize) + args := m.Called(ctx, language, page, pageSize) + return args.Get(0).(*domain.PaginatedResult[domain.Work]), args.Error(1) +} +func (m *mockWorkRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Work, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, nil + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) IsAuthor(ctx context.Context, workID uint, authorID uint) (bool, error) { + args := m.Called(ctx, workID, authorID) + return args.Bool(0), args.Error(1) +} +func (m *mockWorkRepository) ListByCollectionID(ctx context.Context, collectionID uint) ([]domain.Work, error) { + args := m.Called(ctx, collectionID) + return args.Get(0).([]domain.Work), args.Error(1) +} +func (m *mockWorkRepository) GetWithAssociations(ctx context.Context, id uint) (*domain.Work, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.DB, id uint) (*domain.Work, error) { + args := m.Called(ctx, tx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Work), args.Error(1) +} +func (m *mockWorkRepository) ListWithTranslations(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { + args := m.Called(ctx, page, pageSize) + return args.Get(0).(*domain.PaginatedResult[domain.Work]), args.Error(1) } -type mockSearchClient struct { - indexWorkFunc func(ctx context.Context, work *domain.Work, pipeline string) error +// mockAuthorRepository is a mock implementation of domain.AuthorRepository +type mockAuthorRepository struct{ mock.Mock } + +func (m *mockAuthorRepository) Create(ctx context.Context, entity *domain.Author) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) GetByID(ctx context.Context, id uint) (*domain.Author, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Author, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Update(ctx context.Context, entity *domain.Author) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Author) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockAuthorRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockAuthorRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} +func (m *mockAuthorRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Author], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.Author]), args.Error(1) +} +func (m *mockAuthorRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Author, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListAll(ctx context.Context) ([]domain.Author, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockAuthorRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockAuthorRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Author, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Author, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} +func (m *mockAuthorRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} +func (m *mockAuthorRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return fn(nil) +} +func (m *mockAuthorRepository) FindByName(ctx context.Context, name string) (*domain.Author, error) { + args := m.Called(ctx, name) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Author, error) { + args := m.Called(ctx, workID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByBookID(ctx context.Context, bookID uint) ([]domain.Author, error) { + args := m.Called(ctx, bookID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) ListByCountryID(ctx context.Context, countryID uint) ([]domain.Author, error) { + args := m.Called(ctx, countryID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Author), args.Error(1) +} +func (m *mockAuthorRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Author, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Author), args.Error(1) } -func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error { - if m.indexWorkFunc != nil { - return m.indexWorkFunc(ctx, work, pipeline) +// mockUserRepository is a mock implementation of domain.UserRepository +type mockUserRepository struct{ mock.Mock } + +func (m *mockUserRepository) Create(ctx context.Context, entity *domain.User) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockUserRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockUserRepository) GetByID(ctx context.Context, id uint) (*domain.User, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil -} \ No newline at end of file + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.User, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) Update(ctx context.Context, entity *domain.User) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockUserRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockUserRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockUserRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} +func (m *mockUserRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.User], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.User]), args.Error(1) +} +func (m *mockUserRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.User, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} +func (m *mockUserRepository) ListAll(ctx context.Context) ([]domain.User, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} +func (m *mockUserRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockUserRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockUserRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.User, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.User, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} +func (m *mockUserRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} +func (m *mockUserRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} +func (m *mockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return fn(nil) +} +func (m *mockUserRepository) FindByUsername(ctx context.Context, username string) (*domain.User, error) { + args := m.Called(ctx, username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) +} +func (m *mockUserRepository) ListByRole(ctx context.Context, role domain.UserRole) ([]domain.User, error) { + args := m.Called(ctx, role) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) +} + +// mockTranslationRepository is a mock implementation of domain.TranslationRepository +type mockTranslationRepository struct{ mock.Mock } + +func (m *mockTranslationRepository) Create(ctx context.Context, entity *domain.Translation) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockTranslationRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Translation) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockTranslationRepository) GetByID(ctx context.Context, id uint) (*domain.Translation, error) { + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.Translation, error) { + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) Update(ctx context.Context, entity *domain.Translation) error { + args := m.Called(ctx, entity) + return args.Error(0) +} +func (m *mockTranslationRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.Translation) error { + args := m.Called(ctx, tx, entity) + return args.Error(0) +} +func (m *mockTranslationRepository) Delete(ctx context.Context, id uint) error { + args := m.Called(ctx, id) + return args.Error(0) +} +func (m *mockTranslationRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { + args := m.Called(ctx, tx, id) + return args.Error(0) +} +func (m *mockTranslationRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Translation], error) { + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.Translation]), args.Error(1) +} +func (m *mockTranslationRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.Translation, error) { + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) ListAll(ctx context.Context) ([]domain.Translation, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) Count(ctx context.Context) (int64, error) { + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockTranslationRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) +} +func (m *mockTranslationRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.Translation, error) { + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.Translation, error) { + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) Exists(ctx context.Context, id uint) (bool, error) { + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) +} +func (m *mockTranslationRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) +} +func (m *mockTranslationRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { + return fn(nil) +} +func (m *mockTranslationRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Translation, error) { + args := m.Called(ctx, workID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) ListByWorkIDPaginated(ctx context.Context, workID uint, language *string, page, pageSize int) (*domain.PaginatedResult[domain.Translation], error) { + args := m.Called(ctx, workID, language, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.PaginatedResult[domain.Translation]), args.Error(1) +} +func (m *mockTranslationRepository) ListByEntity(ctx context.Context, entityType string, entityID uint) ([]domain.Translation, error) { + args := m.Called(ctx, entityType, entityID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) ListByTranslatorID(ctx context.Context, translatorID uint) ([]domain.Translation, error) { + args := m.Called(ctx, translatorID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) ListByStatus(ctx context.Context, status domain.TranslationStatus) ([]domain.Translation, error) { + args := m.Called(ctx, status) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.Translation), args.Error(1) +} +func (m *mockTranslationRepository) Upsert(ctx context.Context, translation *domain.Translation) error { + args := m.Called(ctx, translation) + return args.Error(0) +} + +// mockSearchClient is a mock implementation of search.SearchClient +type mockSearchClient struct{ mock.Mock } + +func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, content string) error { + args := m.Called(ctx, work, content) + return args.Error(0) +} + +// mockAnalyticsService is a mock implementation of analytics.Service +type mockAnalyticsService struct{ mock.Mock } + +func (m *mockAnalyticsService) IncrementWorkViews(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkLikes(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkComments(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkBookmarks(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkShares(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementWorkTranslationCount(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementTranslationViews(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementTranslationLikes(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) DecrementWorkLikes(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) DecrementTranslationLikes(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementTranslationComments(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) IncrementTranslationShares(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) GetOrCreateWorkStats(ctx context.Context, workID uint) (*domain.WorkStats, error) { + args := m.Called(ctx, workID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.WorkStats), args.Error(1) +} +func (m *mockAnalyticsService) GetOrCreateTranslationStats(ctx context.Context, translationID uint) (*domain.TranslationStats, error) { + args := m.Called(ctx, translationID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.TranslationStats), args.Error(1) +} +func (m *mockAnalyticsService) UpdateWorkReadingTime(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateWorkComplexity(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateWorkSentiment(ctx context.Context, workID uint) error { + args := m.Called(ctx, workID) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateTranslationReadingTime(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateTranslationSentiment(ctx context.Context, translationID uint) error { + args := m.Called(ctx, translationID) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateUserEngagement(ctx context.Context, userID uint, eventType string) error { + args := m.Called(ctx, userID, eventType) + return args.Error(0) +} +func (m *mockAnalyticsService) UpdateTrending(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} +func (m *mockAnalyticsService) GetTrendingWorks(ctx context.Context, timePeriod string, limit int) ([]*domain.Work, error) { + args := m.Called(ctx, timePeriod, limit) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*domain.Work), args.Error(1) +} +func (m *mockAnalyticsService) UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error { + args := m.Called(ctx, workID, stats) + return args.Error(0) +} +func (m *mockAnalyticsService) MergeWorkStats(ctx context.Context, sourceWorkID, targetWorkID uint) error { + args := m.Called(ctx, sourceWorkID, targetWorkID) + return args.Error(0) +} + +// #endregion Mocks \ No newline at end of file diff --git a/internal/app/work/mock_analytics_service_test.go b/internal/app/work/mock_analytics_service_test.go index 363a346..580f5d7 100644 --- a/internal/app/work/mock_analytics_service_test.go +++ b/internal/app/work/mock_analytics_service_test.go @@ -1,104 +1,4 @@ package work -import ( - "context" - "tercul/internal/domain" -) - -type mockAnalyticsService struct { - updateWorkReadingTimeFunc func(ctx context.Context, workID uint) error - updateWorkComplexityFunc func(ctx context.Context, workID uint) error - updateWorkSentimentFunc func(ctx context.Context, workID uint) error - updateTranslationReadingTimeFunc func(ctx context.Context, translationID uint) error - updateTranslationSentimentFunc func(ctx context.Context, translationID uint) error - getOrCreateWorkStatsFunc func(ctx context.Context, workID uint) (*domain.WorkStats, error) - updateWorkStatsFunc func(ctx context.Context, workID uint, stats domain.WorkStats) error -} - -func (m *mockAnalyticsService) UpdateWorkStats(ctx context.Context, workID uint, stats domain.WorkStats) error { - if m.updateWorkStatsFunc != nil { - return m.updateWorkStatsFunc(ctx, workID, stats) - } - return nil -} - -func (m *mockAnalyticsService) UpdateWorkReadingTime(ctx context.Context, workID uint) error { - if m.updateWorkReadingTimeFunc != nil { - return m.updateWorkReadingTimeFunc(ctx, workID) - } - return nil -} - -func (m *mockAnalyticsService) UpdateWorkComplexity(ctx context.Context, workID uint) error { - if m.updateWorkComplexityFunc != nil { - return m.updateWorkComplexityFunc(ctx, workID) - } - return nil -} - -func (m *mockAnalyticsService) UpdateWorkSentiment(ctx context.Context, workID uint) error { - if m.updateWorkSentimentFunc != nil { - return m.updateWorkSentimentFunc(ctx, workID) - } - return nil -} - -func (m *mockAnalyticsService) UpdateTranslationReadingTime(ctx context.Context, translationID uint) error { - if m.updateTranslationReadingTimeFunc != nil { - return m.updateTranslationReadingTimeFunc(ctx, translationID) - } - return nil -} - -func (m *mockAnalyticsService) UpdateTranslationSentiment(ctx context.Context, translationID uint) error { - if m.updateTranslationSentimentFunc != nil { - return m.updateTranslationSentimentFunc(ctx, translationID) - } - return nil -} - -// Implement other methods of the analytics.Service interface to satisfy the compiler -func (m *mockAnalyticsService) IncrementWorkViews(ctx context.Context, workID uint) error { return nil } -func (m *mockAnalyticsService) IncrementWorkLikes(ctx context.Context, workID uint) error { return nil } -func (m *mockAnalyticsService) IncrementWorkComments(ctx context.Context, workID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementWorkBookmarks(ctx context.Context, workID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementWorkShares(ctx context.Context, workID uint) error { return nil } -func (m *mockAnalyticsService) IncrementWorkTranslationCount(ctx context.Context, workID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementTranslationViews(ctx context.Context, translationID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementTranslationLikes(ctx context.Context, translationID uint) error { - return nil -} -func (m *mockAnalyticsService) DecrementWorkLikes(ctx context.Context, workID uint) error { return nil } -func (m *mockAnalyticsService) DecrementTranslationLikes(ctx context.Context, translationID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementTranslationComments(ctx context.Context, translationID uint) error { - return nil -} -func (m *mockAnalyticsService) IncrementTranslationShares(ctx context.Context, translationID uint) error { - return nil -} -func (m *mockAnalyticsService) GetOrCreateWorkStats(ctx context.Context, workID uint) (*domain.WorkStats, error) { - if m.getOrCreateWorkStatsFunc != nil { - return m.getOrCreateWorkStatsFunc(ctx, workID) - } - return nil, nil -} -func (m *mockAnalyticsService) GetOrCreateTranslationStats(ctx context.Context, translationID uint) (*domain.TranslationStats, error) { - return nil, nil -} -func (m *mockAnalyticsService) UpdateUserEngagement(ctx context.Context, userID uint, eventType string) error { - return nil -} -func (m *mockAnalyticsService) UpdateTrending(ctx context.Context) error { return nil } -func (m *mockAnalyticsService) GetTrendingWorks(ctx context.Context, timePeriod string, limit int) ([]*domain.Work, error) { - return nil, nil -} \ No newline at end of file +// This file is intentionally left empty. +// Mocks are defined in main_test.go to avoid redeclaration errors. \ No newline at end of file diff --git a/internal/app/work/queries_test.go b/internal/app/work/queries_test.go index 5ca36d9..80d94f2 100644 --- a/internal/app/work/queries_test.go +++ b/internal/app/work/queries_test.go @@ -2,10 +2,13 @@ package work import ( "context" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "tercul/internal/domain" "testing" + + "tercul/internal/domain" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" ) type WorkQueriesSuite struct { @@ -26,9 +29,8 @@ func TestWorkQueriesSuite(t *testing.T) { func (s *WorkQueriesSuite) TestGetWorkByID_Success() { work := &domain.Work{Title: "Test Work"} work.ID = 1 - s.repo.getByIDFunc = func(ctx context.Context, id uint) (*domain.Work, error) { - return work, nil - } + s.repo.On("GetByID", mock.Anything, uint(1)).Return(work, nil) + w, err := s.queries.GetWorkByID(context.Background(), 1) assert.NoError(s.T(), err) expectedDTO := &WorkDTO{ @@ -47,9 +49,7 @@ func (s *WorkQueriesSuite) TestGetWorkByID_ZeroID() { func (s *WorkQueriesSuite) TestListByCollectionID_Success() { works := []domain.Work{{Title: "Test Work"}} - s.repo.listByCollectionIDFunc = func(ctx context.Context, collectionID uint) ([]domain.Work, error) { - return works, nil - } + s.repo.On("ListByCollectionID", mock.Anything, uint(1)).Return(works, nil) w, err := s.queries.ListByCollectionID(context.Background(), 1) assert.NoError(s.T(), err) assert.Equal(s.T(), works, w) @@ -72,9 +72,7 @@ func (s *WorkQueriesSuite) TestListWorks_Success() { PageSize: 10, TotalPages: 1, } - s.repo.listFunc = func(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { - return domainWorks, nil - } + s.repo.On("List", mock.Anything, 1, 10).Return(domainWorks, nil) paginatedDTOs, err := s.queries.ListWorks(context.Background(), 1, 10) assert.NoError(s.T(), err) @@ -95,9 +93,7 @@ func (s *WorkQueriesSuite) TestListWorks_Success() { func (s *WorkQueriesSuite) TestGetWorkWithTranslations_Success() { work := &domain.Work{Title: "Test Work"} work.ID = 1 - s.repo.getWithTranslationsFunc = func(ctx context.Context, id uint) (*domain.Work, error) { - return work, nil - } + s.repo.On("GetWithTranslations", mock.Anything, uint(1)).Return(work, nil) w, err := s.queries.GetWorkWithTranslations(context.Background(), 1) assert.NoError(s.T(), err) assert.Equal(s.T(), work, w) @@ -111,9 +107,7 @@ func (s *WorkQueriesSuite) TestGetWorkWithTranslations_ZeroID() { func (s *WorkQueriesSuite) TestFindWorksByTitle_Success() { works := []domain.Work{{Title: "Test Work"}} - s.repo.findByTitleFunc = func(ctx context.Context, title string) ([]domain.Work, error) { - return works, nil - } + s.repo.On("FindByTitle", mock.Anything, "Test").Return(works, nil) w, err := s.queries.FindWorksByTitle(context.Background(), "Test") assert.NoError(s.T(), err) assert.Equal(s.T(), works, w) @@ -127,9 +121,7 @@ func (s *WorkQueriesSuite) TestFindWorksByTitle_Empty() { func (s *WorkQueriesSuite) TestFindWorksByAuthor_Success() { works := []domain.Work{{Title: "Test Work"}} - s.repo.findByAuthorFunc = func(ctx context.Context, authorID uint) ([]domain.Work, error) { - return works, nil - } + s.repo.On("FindByAuthor", mock.Anything, uint(1)).Return(works, nil) w, err := s.queries.FindWorksByAuthor(context.Background(), 1) assert.NoError(s.T(), err) assert.Equal(s.T(), works, w) @@ -143,9 +135,7 @@ func (s *WorkQueriesSuite) TestFindWorksByAuthor_ZeroID() { func (s *WorkQueriesSuite) TestFindWorksByCategory_Success() { works := []domain.Work{{Title: "Test Work"}} - s.repo.findByCategoryFunc = func(ctx context.Context, categoryID uint) ([]domain.Work, error) { - return works, nil - } + s.repo.On("FindByCategory", mock.Anything, uint(1)).Return(works, nil) w, err := s.queries.FindWorksByCategory(context.Background(), 1) assert.NoError(s.T(), err) assert.Equal(s.T(), works, w) @@ -159,9 +149,7 @@ func (s *WorkQueriesSuite) TestFindWorksByCategory_ZeroID() { func (s *WorkQueriesSuite) TestFindWorksByLanguage_Success() { works := &domain.PaginatedResult[domain.Work]{} - s.repo.findByLanguageFunc = func(ctx context.Context, language string, page, pageSize int) (*domain.PaginatedResult[domain.Work], error) { - return works, nil - } + s.repo.On("FindByLanguage", mock.Anything, "en", 1, 10).Return(works, nil) w, err := s.queries.FindWorksByLanguage(context.Background(), "en", 1, 10) assert.NoError(s.T(), err) assert.Equal(s.T(), works, w) diff --git a/internal/app/work/service.go b/internal/app/work/service.go index e01260c..763299b 100644 --- a/internal/app/work/service.go +++ b/internal/app/work/service.go @@ -14,9 +14,9 @@ type Service struct { } // NewService creates a new work Service. -func NewService(repo domain.WorkRepository, searchClient search.SearchClient, authzSvc *authz.Service, analyticsSvc analytics.Service) *Service { +func NewService(repo domain.WorkRepository, authorRepo domain.AuthorRepository, userRepo domain.UserRepository, searchClient search.SearchClient, authzSvc *authz.Service, analyticsSvc analytics.Service) *Service { return &Service{ - Commands: NewWorkCommands(repo, searchClient, authzSvc, analyticsSvc), + Commands: NewWorkCommands(repo, authorRepo, userRepo, searchClient, authzSvc, analyticsSvc), Queries: NewWorkQueries(repo), } } diff --git a/internal/app/work/service_test.go b/internal/app/work/service_test.go deleted file mode 100644 index a7a881f..0000000 --- a/internal/app/work/service_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package work - -import ( - "testing" - "tercul/internal/app/authz" - - "github.com/stretchr/testify/assert" -) - -func TestNewService(t *testing.T) { - // Arrange - mockRepo := &mockWorkRepository{} - mockSearchClient := &mockSearchClient{} - mockAuthzSvc := &authz.Service{} - mockAnalyticsSvc := &mockAnalyticsService{} - - // Act - service := NewService(mockRepo, mockSearchClient, mockAuthzSvc, mockAnalyticsSvc) - - // Assert - assert.NotNil(t, service, "The new service should not be nil") - assert.NotNil(t, service.Commands, "The service Commands should not be nil") - assert.NotNil(t, service.Queries, "The service Queries should not be nil") -} \ No newline at end of file diff --git a/internal/data/sql/analytics_repository_test.go b/internal/data/sql/analytics_repository_test.go index 9101cd5..b4831b7 100644 --- a/internal/data/sql/analytics_repository_test.go +++ b/internal/data/sql/analytics_repository_test.go @@ -201,4 +201,91 @@ func TestAnalyticsRepository_UpdateUserEngagement(t *testing.T) { assert.Equal(t, 15, updatedEngagement.LikesGiven) assert.Equal(t, 10, updatedEngagement.WorksRead) +} + +func TestAnalyticsRepository_IncrementTranslationCounter(t *testing.T) { + repo, db := newTestAnalyticsRepoWithSQLite(t) + ctx := context.Background() + + // Setup: Create a work and translation to associate stats with + work := domain.Work{Title: "Test Work"} + require.NoError(t, db.Create(&work).Error) + translation := domain.Translation{Title: "Test Translation", TranslatableID: work.ID, TranslatableType: "works"} + require.NoError(t, db.Create(&translation).Error) + + t.Run("creates_new_stats_if_not_exist", func(t *testing.T) { + err := repo.IncrementTranslationCounter(ctx, translation.ID, "views", 10) + require.NoError(t, err) + + var stats domain.TranslationStats + err = db.Where("translation_id = ?", translation.ID).First(&stats).Error + require.NoError(t, err) + assert.Equal(t, int64(10), stats.Views) + }) + + t.Run("increments_existing_stats", func(t *testing.T) { + // Increment again + err := repo.IncrementTranslationCounter(ctx, translation.ID, "views", 5) + require.NoError(t, err) + + var stats domain.TranslationStats + err = db.Where("translation_id = ?", translation.ID).First(&stats).Error + require.NoError(t, err) + assert.Equal(t, int64(15), stats.Views) // 10 + 5 + }) + + t.Run("invalid_field", func(t *testing.T) { + err := repo.IncrementTranslationCounter(ctx, translation.ID, "invalid_field", 1) + assert.Error(t, err) + }) +} + +func TestAnalyticsRepository_UpdateWorkStats(t *testing.T) { + repo, db := newTestAnalyticsRepoWithSQLite(t) + ctx := context.Background() + + // Setup + work := domain.Work{Title: "Test Work"} + require.NoError(t, db.Create(&work).Error) + stats := domain.WorkStats{WorkID: work.ID, Views: 10} + require.NoError(t, db.Create(&stats).Error) + + // Act + update := domain.WorkStats{ReadingTime: 120, Complexity: 0.5} + err := repo.UpdateWorkStats(ctx, work.ID, update) + require.NoError(t, err) + + // Assert + var updatedStats domain.WorkStats + err = db.Where("work_id = ?", work.ID).First(&updatedStats).Error + require.NoError(t, err) + assert.Equal(t, int64(10), updatedStats.Views) // Should not be zeroed + assert.Equal(t, 120, updatedStats.ReadingTime) + assert.Equal(t, 0.5, updatedStats.Complexity) +} + +func TestAnalyticsRepository_UpdateTranslationStats(t *testing.T) { + repo, db := newTestAnalyticsRepoWithSQLite(t) + ctx := context.Background() + + // Setup + work := domain.Work{Title: "Test Work"} + require.NoError(t, db.Create(&work).Error) + translation := domain.Translation{Title: "Test Translation", TranslatableID: work.ID, TranslatableType: "works"} + require.NoError(t, db.Create(&translation).Error) + stats := domain.TranslationStats{TranslationID: translation.ID, Views: 20} + require.NoError(t, db.Create(&stats).Error) + + // Act + update := domain.TranslationStats{ReadingTime: 60, Sentiment: 0.8} + err := repo.UpdateTranslationStats(ctx, translation.ID, update) + require.NoError(t, err) + + // Assert + var updatedStats domain.TranslationStats + err = db.Where("translation_id = ?", translation.ID).First(&updatedStats).Error + require.NoError(t, err) + assert.Equal(t, int64(20), updatedStats.Views) // Should not be zeroed + assert.Equal(t, 60, updatedStats.ReadingTime) + assert.Equal(t, 0.8, updatedStats.Sentiment) } \ No newline at end of file diff --git a/internal/data/sql/author_repository.go b/internal/data/sql/author_repository.go index 0661dbb..d96e0cd 100644 --- a/internal/data/sql/author_repository.go +++ b/internal/data/sql/author_repository.go @@ -2,73 +2,66 @@ package sql import ( "context" + "errors" "tercul/internal/domain" "tercul/internal/platform/config" - - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "gorm.io/gorm" ) type authorRepository struct { - domain.BaseRepository[domain.Author] - db *gorm.DB - tracer trace.Tracer + *BaseRepositoryImpl[domain.Author] + db *gorm.DB } // NewAuthorRepository creates a new AuthorRepository. func NewAuthorRepository(db *gorm.DB, cfg *config.Config) domain.AuthorRepository { + baseRepo := NewBaseRepositoryImpl[domain.Author](db, cfg) return &authorRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Author](db, cfg), - db: db, - tracer: otel.Tracer("author.repository"), + BaseRepositoryImpl: baseRepo, + db: db, } } -// ListByWorkID finds authors by work ID -func (r *authorRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Author, error) { - ctx, span := r.tracer.Start(ctx, "ListByWorkID") - defer span.End() - var authors []domain.Author - if err := r.db.WithContext(ctx).Joins("JOIN work_authors ON work_authors.author_id = authors.id"). - Where("work_authors.work_id = ?", workID). - Find(&authors).Error; err != nil { - return nil, err - } - return authors, nil -} - -// GetWithTranslations finds an author by ID and preloads their translations. -func (r *authorRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Author, error) { - ctx, span := r.tracer.Start(ctx, "GetWithTranslations") - defer span.End() +func (r *authorRepository) FindByName(ctx context.Context, name string) (*domain.Author, error) { var author domain.Author - if err := r.db.WithContext(ctx).Preload("Translations").First(&author, id).Error; err != nil { + err := r.db.WithContext(ctx).Where("name = ?", name).First(&author).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, domain.ErrEntityNotFound + } return nil, err } return &author, nil } -// ListByBookID finds authors by book ID -func (r *authorRepository) ListByBookID(ctx context.Context, bookID uint) ([]domain.Author, error) { - ctx, span := r.tracer.Start(ctx, "ListByBookID") - defer span.End() +func (r *authorRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain.Author, error) { var authors []domain.Author - if err := r.db.WithContext(ctx).Joins("JOIN book_authors ON book_authors.author_id = authors.id"). - Where("book_authors.book_id = ?", bookID). - Find(&authors).Error; err != nil { - return nil, err - } - return authors, nil + err := r.db.WithContext(ctx).Joins("JOIN work_authors ON work_authors.author_id = authors.id"). + Where("work_authors.work_id = ?", workID).Find(&authors).Error + return authors, err } -// ListByCountryID finds authors by country ID -func (r *authorRepository) ListByCountryID(ctx context.Context, countryID uint) ([]domain.Author, error) { - ctx, span := r.tracer.Start(ctx, "ListByCountryID") - defer span.End() +func (r *authorRepository) ListByBookID(ctx context.Context, bookID uint) ([]domain.Author, error) { var authors []domain.Author - if err := r.db.WithContext(ctx).Where("country_id = ?", countryID).Find(&authors).Error; err != nil { + err := r.db.WithContext(ctx).Joins("JOIN book_authors ON book_authors.author_id = authors.id"). + Where("book_authors.book_id = ?", bookID).Find(&authors).Error + return authors, err +} + +func (r *authorRepository) ListByCountryID(ctx context.Context, countryID uint) ([]domain.Author, error) { + var authors []domain.Author + err := r.db.WithContext(ctx).Where("country_id = ?", countryID).Find(&authors).Error + return authors, err +} + +func (r *authorRepository) GetWithTranslations(ctx context.Context, id uint) (*domain.Author, error) { + var author domain.Author + err := r.db.WithContext(ctx).Preload("Translations").First(&author, id).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, domain.ErrEntityNotFound + } return nil, err } - return authors, nil -} + return &author, nil +} \ No newline at end of file diff --git a/internal/data/sql/author_repository_test.go b/internal/data/sql/author_repository_test.go index 767b39a..b55ddc0 100644 --- a/internal/data/sql/author_repository_test.go +++ b/internal/data/sql/author_repository_test.go @@ -24,9 +24,14 @@ func (s *AuthorRepositoryTestSuite) SetupSuite() { } func (s *AuthorRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() s.DB.Exec("DELETE FROM work_authors") s.DB.Exec("DELETE FROM authors") s.DB.Exec("DELETE FROM works") + s.DB.Exec("DELETE FROM books") + s.DB.Exec("DELETE FROM book_authors") + s.DB.Exec("DELETE FROM countries") + s.DB.Exec("DELETE FROM translations") } func (s *AuthorRepositoryTestSuite) createAuthor(name string) *domain.Author { @@ -41,10 +46,17 @@ func (s *AuthorRepositoryTestSuite) createAuthor(name string) *domain.Author { return author } +func (s *AuthorRepositoryTestSuite) createBook(title string) *domain.Book { + book := &domain.Book{Title: title, TranslatableModel: domain.TranslatableModel{Language: "en"}} + err := s.DB.Create(book).Error + s.Require().NoError(err) + return book +} + func (s *AuthorRepositoryTestSuite) TestListByWorkID() { s.Run("should return all authors for a given work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") author1 := s.createAuthor("Author 1") author2 := s.createAuthor("Author 2") s.Require().NoError(s.DB.Model(&work).Association("Authors").Append([]*domain.Author{author1, author2})) @@ -52,10 +64,83 @@ func (s *AuthorRepositoryTestSuite) TestListByWorkID() { // Act authors, err := s.AuthorRepo.ListByWorkID(context.Background(), work.ID) + // Assert + s.Require().NoError(err) + s.Len(authors, 3) + var authorNames []string + for _, a := range authors { + authorNames = append(authorNames, a.Name) + } + s.ElementsMatch([]string{"admin", "Author 1", "Author 2"}, authorNames) + }) +} + +func (s *AuthorRepositoryTestSuite) TestListByBookID() { + s.Run("should return all authors for a given book", func() { + // Arrange + book := s.createBook("Test Book") + author1 := s.createAuthor("Book Author 1") + author2 := s.createAuthor("Book Author 2") + s.Require().NoError(s.DB.Model(&book).Association("Authors").Append([]*domain.Author{author1, author2})) + + // Act + authors, err := s.AuthorRepo.ListByBookID(context.Background(), book.ID) + // Assert s.Require().NoError(err) s.Len(authors, 2) - s.ElementsMatch([]string{"Author 1", "Author 2"}, []string{authors[0].Name, authors[1].Name}) + s.ElementsMatch([]string{"Book Author 1", "Book Author 2"}, []string{authors[0].Name, authors[1].Name}) + }) +} + +func (s *AuthorRepositoryTestSuite) TestListByCountryID() { + s.Run("should return all authors for a given country", func() { + // Arrange + country1 := &domain.Country{Name: "Country 1", Code: "C1"} + country2 := &domain.Country{Name: "Country 2", Code: "C2"} + s.Require().NoError(s.DB.Create(country1).Error) + s.Require().NoError(s.DB.Create(country2).Error) + + author1 := s.createAuthor("Author C1") + author1.CountryID = &country1.ID + s.Require().NoError(s.DB.Save(author1).Error) + + author2 := s.createAuthor("Author C2") + author2.CountryID = &country2.ID + s.Require().NoError(s.DB.Save(author2).Error) + + // Act + authors, err := s.AuthorRepo.ListByCountryID(context.Background(), country1.ID) + + // Assert + s.Require().NoError(err) + s.Len(authors, 1) + s.Equal("Author C1", authors[0].Name) + }) +} + +func (s *AuthorRepositoryTestSuite) TestGetWithTranslations() { + s.Run("should return author with preloaded translations", func() { + // Arrange + author := s.createAuthor("Translated Author") + translation := &domain.Translation{ + TranslatableType: "authors", + TranslatableID: author.ID, + Language: "es", + Title: "Autor Traducido", + Content: "Una biografĂ­a.", + } + s.Require().NoError(s.DB.Create(translation).Error) + + // Act + foundAuthor, err := s.AuthorRepo.GetWithTranslations(context.Background(), author.ID) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(foundAuthor) + s.Require().Len(foundAuthor.Translations, 1) + s.Equal("es", foundAuthor.Translations[0].Language) + s.Equal("Una biografĂ­a.", foundAuthor.Translations[0].Content) }) } diff --git a/internal/data/sql/base_repository.go b/internal/data/sql/base_repository.go index e406bd2..c5503e8 100644 --- a/internal/data/sql/base_repository.go +++ b/internal/data/sql/base_repository.go @@ -14,16 +14,6 @@ import ( "gorm.io/gorm" ) -// Common repository errors -var ( - ErrEntityNotFound = errors.New("entity not found") - ErrInvalidID = errors.New("invalid ID: cannot be zero") - ErrInvalidInput = errors.New("invalid input parameters") - ErrDatabaseOperation = errors.New("database operation failed") - ErrContextRequired = errors.New("context is required") - ErrTransactionFailed = errors.New("transaction failed") -) - // BaseRepositoryImpl provides a default implementation of BaseRepository using GORM type BaseRepositoryImpl[T any] struct { db *gorm.DB @@ -43,7 +33,7 @@ func NewBaseRepositoryImpl[T any](db *gorm.DB, cfg *config.Config) *BaseReposito // validateContext ensures context is not nil func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { if ctx == nil { - return ErrContextRequired + return domain.ErrValidation } return nil } @@ -51,7 +41,7 @@ func (r *BaseRepositoryImpl[T]) validateContext(ctx context.Context) error { // validateID ensures ID is valid func (r *BaseRepositoryImpl[T]) validateID(id uint) error { if id == 0 { - return ErrInvalidID + return domain.ErrValidation } return nil } @@ -59,7 +49,7 @@ func (r *BaseRepositoryImpl[T]) validateID(id uint) error { // validateEntity ensures entity is not nil func (r *BaseRepositoryImpl[T]) validateEntity(entity *T) error { if entity == nil { - return ErrInvalidInput + return domain.ErrValidation } return nil } @@ -133,7 +123,7 @@ func (r *BaseRepositoryImpl[T]) Create(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to create entity") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity created successfully in %s", duration)) @@ -151,7 +141,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrInvalidOperation } start := time.Now() @@ -160,7 +150,7 @@ func (r *BaseRepositoryImpl[T]) CreateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to create entity in transaction") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity created successfully in transaction in %s", duration)) @@ -186,10 +176,10 @@ func (r *BaseRepositoryImpl[T]) GetByID(ctx context.Context, id uint) (*T, error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found in %s", id, duration)) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully in %s", id, duration)) @@ -216,10 +206,10 @@ func (r *BaseRepositoryImpl[T]) GetByIDWithOptions(ctx context.Context, id uint, if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with options in %s", id, duration)) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity by ID %d with options", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity with id %d retrieved successfully with options in %s", id, duration)) @@ -243,7 +233,7 @@ func (r *BaseRepositoryImpl[T]) Update(ctx context.Context, entity *T) error { if err != nil { log.Error(err, "Failed to update entity") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity updated successfully in %s", duration)) @@ -261,7 +251,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrInvalidOperation } start := time.Now() @@ -270,7 +260,7 @@ func (r *BaseRepositoryImpl[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, ent if err != nil { log.Error(err, "Failed to update entity in transaction") - return fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return fmt.Errorf("database operation failed: %w", err) } log.Debug(fmt.Sprintf("Entity updated successfully in transaction in %s", duration)) @@ -295,12 +285,12 @@ func (r *BaseRepositoryImpl[T]) Delete(ctx context.Context, id uint) error { if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d", id)) - return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) + return fmt.Errorf("database operation failed: %w", result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in %s", id, duration)) - return ErrEntityNotFound + return domain.ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in %s", id, duration)) @@ -318,7 +308,7 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id return err } if tx == nil { - return ErrTransactionFailed + return domain.ErrInvalidOperation } start := time.Now() @@ -328,12 +318,12 @@ func (r *BaseRepositoryImpl[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id if result.Error != nil { log.Error(result.Error, fmt.Sprintf("Failed to delete entity with id %d in transaction", id)) - return fmt.Errorf("%w: %v", ErrDatabaseOperation, result.Error) + return fmt.Errorf("database operation failed: %w", result.Error) } if result.RowsAffected == 0 { log.Debug(fmt.Sprintf("No entity with id %d found to delete in transaction in %s", id, duration)) - return ErrEntityNotFound + return domain.ErrEntityNotFound } log.Debug(fmt.Sprintf("Entity with id %d deleted successfully in transaction in %s", id, duration)) @@ -360,7 +350,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get total count if err := r.db.WithContext(ctx).Model(new(T)).Count(&totalCount).Error; err != nil { log.Error(err, "Failed to count entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } // Calculate offset @@ -369,7 +359,7 @@ func (r *BaseRepositoryImpl[T]) List(ctx context.Context, page, pageSize int) (* // Get paginated data if err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get paginated entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -410,7 +400,7 @@ func (r *BaseRepositoryImpl[T]) ListWithOptions(ctx context.Context, options *do if err := query.Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities with options") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -431,7 +421,7 @@ func (r *BaseRepositoryImpl[T]) ListAll(ctx context.Context) ([]T, error) { var entities []T if err := r.db.WithContext(ctx).Find(&entities).Error; err != nil { log.Error(err, "Failed to get all entities") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -452,7 +442,7 @@ func (r *BaseRepositoryImpl[T]) Count(ctx context.Context) (int64, error) { var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities") - return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return 0, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -475,7 +465,7 @@ func (r *BaseRepositoryImpl[T]) CountWithOptions(ctx context.Context, options *d if err := query.Model(new(T)).Count(&count).Error; err != nil { log.Error(err, "Failed to count entities with options") - return 0, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return 0, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -506,10 +496,10 @@ func (r *BaseRepositoryImpl[T]) FindWithPreload(ctx context.Context, preloads [] if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Debug(fmt.Sprintf("Entity with id %d not found with preloads in %s", id, time.Since(start))) - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } log.Error(err, fmt.Sprintf("Failed to get entity with id %d with preloads", id)) - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -541,7 +531,7 @@ func (r *BaseRepositoryImpl[T]) GetAllForSync(ctx context.Context, batchSize, of var entities []T if err := r.db.WithContext(ctx).Offset(offset).Limit(batchSize).Find(&entities).Error; err != nil { log.Error(err, "Failed to get entities for sync") - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -565,7 +555,7 @@ func (r *BaseRepositoryImpl[T]) Exists(ctx context.Context, id uint) (bool, erro var count int64 if err := r.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Count(&count).Error; err != nil { log.Error(err, fmt.Sprintf("Failed to check entity existence for id %d", id)) - return false, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return false, fmt.Errorf("database operation failed: %w", err) } duration := time.Since(start) @@ -587,7 +577,7 @@ func (r *BaseRepositoryImpl[T]) BeginTx(ctx context.Context) (*gorm.DB, error) { tx := r.db.WithContext(ctx).Begin() if tx.Error != nil { log.Error(tx.Error, "Failed to begin transaction") - return nil, fmt.Errorf("%w: %v", ErrTransactionFailed, tx.Error) + return nil, fmt.Errorf("transaction failed: %w", tx.Error) } log.Debug("Transaction started successfully") @@ -625,9 +615,9 @@ func (r *BaseRepositoryImpl[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) if err := tx.Commit().Error; err != nil { log.Error(err, "Failed to commit transaction") - return fmt.Errorf("%w: %v", ErrTransactionFailed, err) + return fmt.Errorf("transaction failed: %w", err) } log.Debug("Transaction committed successfully") return nil -} +} \ No newline at end of file diff --git a/internal/data/sql/base_repository_test.go b/internal/data/sql/base_repository_test.go index 56d0be8..f661da8 100644 --- a/internal/data/sql/base_repository_test.go +++ b/internal/data/sql/base_repository_test.go @@ -31,6 +31,7 @@ func (s *BaseRepositoryTestSuite) SetupSuite() { // SetupTest cleans the database before each test. func (s *BaseRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() s.DB.Exec("DELETE FROM test_entities") } @@ -76,13 +77,13 @@ func (s *BaseRepositoryTestSuite) TestCreate() { s.Run("should return error for nil entity", func() { err := s.repo.Create(context.Background(), nil) - s.ErrorIs(err, sql.ErrInvalidInput) + s.ErrorIs(err, domain.ErrValidation) }) s.Run("should return error for nil context", func() { //nolint:staticcheck // Testing behavior with nil context is intentional here. err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"}) - s.ErrorIs(err, sql.ErrContextRequired) + s.ErrorIs(err, domain.ErrValidation) }) } @@ -103,12 +104,12 @@ func (s *BaseRepositoryTestSuite) TestGetByID() { s.Run("should return ErrEntityNotFound for non-existent ID", func() { _, err := s.repo.GetByID(context.Background(), 99999) - s.ErrorIs(err, sql.ErrEntityNotFound) + s.ErrorIs(err, domain.ErrEntityNotFound) }) - s.Run("should return ErrInvalidID for zero ID", func() { + s.Run("should return ErrValidation for zero ID", func() { _, err := s.repo.GetByID(context.Background(), 0) - s.ErrorIs(err, sql.ErrInvalidID) + s.ErrorIs(err, domain.ErrValidation) }) } @@ -140,12 +141,12 @@ func (s *BaseRepositoryTestSuite) TestDelete() { // Assert s.Require().NoError(err) _, getErr := s.repo.GetByID(context.Background(), created.ID) - s.ErrorIs(getErr, sql.ErrEntityNotFound) + s.ErrorIs(getErr, domain.ErrEntityNotFound) }) s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() { err := s.repo.Delete(context.Background(), 99999) - s.ErrorIs(err, sql.ErrEntityNotFound) + s.ErrorIs(err, domain.ErrEntityNotFound) }) } @@ -261,6 +262,6 @@ func (s *BaseRepositoryTestSuite) TestWithTx() { s.ErrorIs(err, simulatedErr) _, getErr := s.repo.GetByID(context.Background(), createdID) - s.ErrorIs(getErr, sql.ErrEntityNotFound, "Entity should not exist after rollback") + s.ErrorIs(getErr, domain.ErrEntityNotFound, "Entity should not exist after rollback") }) } \ No newline at end of file diff --git a/internal/data/sql/book_repository.go b/internal/data/sql/book_repository.go index 5b17765..de68d1b 100644 --- a/internal/data/sql/book_repository.go +++ b/internal/data/sql/book_repository.go @@ -12,7 +12,7 @@ import ( ) type bookRepository struct { - domain.BaseRepository[domain.Book] + *BaseRepositoryImpl[domain.Book] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type bookRepository struct { // NewBookRepository creates a new BookRepository. func NewBookRepository(db *gorm.DB, cfg *config.Config) domain.BookRepository { return &bookRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Book](db, cfg), - db: db, - tracer: otel.Tracer("book.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Book](db, cfg), + db: db, + tracer: otel.Tracer("book.repository"), } } @@ -70,9 +70,9 @@ func (r *bookRepository) FindByISBN(ctx context.Context, isbn string) (*domain.B var book domain.Book if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&book).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } return &book, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/book_repository_test.go b/internal/data/sql/book_repository_test.go index bd381bf..b2065ba 100644 --- a/internal/data/sql/book_repository_test.go +++ b/internal/data/sql/book_repository_test.go @@ -24,7 +24,13 @@ func (s *BookRepositoryTestSuite) SetupSuite() { } func (s *BookRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() s.DB.Exec("DELETE FROM books") + s.DB.Exec("DELETE FROM authors") + s.DB.Exec("DELETE FROM publishers") + s.DB.Exec("DELETE FROM book_authors") + s.DB.Exec("DELETE FROM book_works") + s.DB.Exec("DELETE FROM works") } func (s *BookRepositoryTestSuite) createBook(title, isbn string) *domain.Book { @@ -40,6 +46,20 @@ func (s *BookRepositoryTestSuite) createBook(title, isbn string) *domain.Book { return book } +func (s *BookRepositoryTestSuite) createAuthor(name string) *domain.Author { + author := &domain.Author{Name: name} + err := s.DB.Create(author).Error + s.Require().NoError(err) + return author +} + +func (s *BookRepositoryTestSuite) createPublisher(name string) *domain.Publisher { + publisher := &domain.Publisher{Name: name} + err := s.DB.Create(publisher).Error + s.Require().NoError(err) + return publisher +} + func (s *BookRepositoryTestSuite) TestFindByISBN() { s.Run("should return a book by ISBN", func() { // Arrange @@ -68,4 +88,76 @@ func (s *BookRepositoryTestSuite) TestFindByISBN() { func TestBookRepository(t *testing.T) { suite.Run(t, new(BookRepositoryTestSuite)) +} + +func (s *BookRepositoryTestSuite) TestListByAuthorID() { + s.Run("should return all books for a given author", func() { + // Arrange + author1 := s.createAuthor("Test Author 1") + author2 := s.createAuthor("Test Author 2") + book1 := s.createBook("Book 1 by Author 1", "111") + book2 := s.createBook("Book 2 by Author 1", "222") + book3 := s.createBook("Book 3 by Author 2", "333") + + s.Require().NoError(s.DB.Model(&author1).Association("Books").Append([]*domain.Book{book1, book2})) + s.Require().NoError(s.DB.Model(&author2).Association("Books").Append(book3)) + + // Act + books, err := s.BookRepo.ListByAuthorID(context.Background(), author1.ID) + + // Assert + s.Require().NoError(err) + s.Len(books, 2) + s.ElementsMatch([]string{"Book 1 by Author 1", "Book 2 by Author 1"}, []string{books[0].Title, books[1].Title}) + }) +} + +func (s *BookRepositoryTestSuite) TestListByPublisherID() { + s.Run("should return all books for a given publisher", func() { + // Arrange + publisher1 := s.createPublisher("Publisher 1") + publisher2 := s.createPublisher("Publisher 2") + book1 := s.createBook("Book 1 from Publisher 1", "111") + book2 := s.createBook("Book 2 from Publisher 1", "222") + book3 := s.createBook("Book 3 from Publisher 2", "333") + + book1.PublisherID = &publisher1.ID + book2.PublisherID = &publisher1.ID + book3.PublisherID = &publisher2.ID + s.Require().NoError(s.DB.Save(book1).Error) + s.Require().NoError(s.DB.Save(book2).Error) + s.Require().NoError(s.DB.Save(book3).Error) + + // Act + books, err := s.BookRepo.ListByPublisherID(context.Background(), publisher1.ID) + + // Assert + s.Require().NoError(err) + s.Len(books, 2) + s.ElementsMatch([]string{"Book 1 from Publisher 1", "Book 2 from Publisher 1"}, []string{books[0].Title, books[1].Title}) + }) +} + +func (s *BookRepositoryTestSuite) TestListByWorkID() { + s.Run("should return all books associated with a given work", func() { + // Arrange + work1 := s.CreateTestWork(s.AdminCtx, "Work 1", "en", "content 1") + work2 := s.CreateTestWork(s.AdminCtx, "Work 2", "en", "content 2") + book1 := s.createBook("Book 1 for Work 1", "111") + book2 := s.createBook("Book 2 for Work 1", "222") + book3 := s.createBook("Book 3 for Work 2", "333") + + // Manually create the association in the join table + s.Require().NoError(s.DB.Exec("INSERT INTO book_works (book_id, work_id) VALUES (?, ?)", book1.ID, work1.ID).Error) + s.Require().NoError(s.DB.Exec("INSERT INTO book_works (book_id, work_id) VALUES (?, ?)", book2.ID, work1.ID).Error) + s.Require().NoError(s.DB.Exec("INSERT INTO book_works (book_id, work_id) VALUES (?, ?)", book3.ID, work2.ID).Error) + + // Act + books, err := s.BookRepo.ListByWorkID(context.Background(), work1.ID) + + // Assert + s.Require().NoError(err) + s.Len(books, 2) + s.ElementsMatch([]string{"Book 1 for Work 1", "Book 2 for Work 1"}, []string{books[0].Title, books[1].Title}) + }) } \ No newline at end of file diff --git a/internal/data/sql/category_repository.go b/internal/data/sql/category_repository.go index c656dd0..e37aca4 100644 --- a/internal/data/sql/category_repository.go +++ b/internal/data/sql/category_repository.go @@ -12,7 +12,7 @@ import ( ) type categoryRepository struct { - domain.BaseRepository[domain.Category] + *BaseRepositoryImpl[domain.Category] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type categoryRepository struct { // NewCategoryRepository creates a new CategoryRepository. func NewCategoryRepository(db *gorm.DB, cfg *config.Config) domain.CategoryRepository { return &categoryRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Category](db, cfg), - db: db, - tracer: otel.Tracer("category.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Category](db, cfg), + db: db, + tracer: otel.Tracer("category.repository"), } } @@ -33,7 +33,7 @@ func (r *categoryRepository) FindByName(ctx context.Context, name string) (*doma var category domain.Category if err := r.db.WithContext(ctx).Where("name = ?", name).First(&category).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -68,4 +68,4 @@ func (r *categoryRepository) ListByParentID(ctx context.Context, parentID *uint) } } return categories, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/category_repository_test.go b/internal/data/sql/category_repository_test.go index 0edef8a..eba484d 100644 --- a/internal/data/sql/category_repository_test.go +++ b/internal/data/sql/category_repository_test.go @@ -24,6 +24,7 @@ func (s *CategoryRepositoryTestSuite) SetupSuite() { } func (s *CategoryRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() s.DB.Exec("DELETE FROM work_categories") s.DB.Exec("DELETE FROM categories") s.DB.Exec("DELETE FROM works") @@ -64,7 +65,7 @@ func (s *CategoryRepositoryTestSuite) TestFindByName() { func (s *CategoryRepositoryTestSuite) TestListByWorkID() { s.Run("should return all categories for a given work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") cat1 := s.createCategory("Science Fiction", nil) cat2 := s.createCategory("Cyberpunk", &cat1.ID) diff --git a/internal/data/sql/copyright_repository.go b/internal/data/sql/copyright_repository.go index cd4a301..2f89603 100644 --- a/internal/data/sql/copyright_repository.go +++ b/internal/data/sql/copyright_repository.go @@ -12,7 +12,7 @@ import ( ) type copyrightRepository struct { - domain.BaseRepository[domain.Copyright] + *BaseRepositoryImpl[domain.Copyright] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type copyrightRepository struct { // NewCopyrightRepository creates a new CopyrightRepository. func NewCopyrightRepository(db *gorm.DB, cfg *config.Config) domain.CopyrightRepository { return ©rightRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Copyright](db, cfg), - db: db, - tracer: otel.Tracer("copyright.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Copyright](db, cfg), + db: db, + tracer: otel.Tracer("copyright.repository"), } } @@ -50,7 +50,7 @@ func (r *copyrightRepository) GetTranslationByLanguage(ctx context.Context, copy err := r.db.WithContext(ctx).Where("copyright_id = ? AND language_code = ?", copyrightID, languageCode).First(&translation).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -115,4 +115,4 @@ func (r *copyrightRepository) RemoveCopyrightFromSource(ctx context.Context, sou ctx, span := r.tracer.Start(ctx, "RemoveCopyrightFromSource") defer span.End() return r.db.WithContext(ctx).Exec("DELETE FROM source_copyrights WHERE source_id = ? AND copyright_id = ?", sourceID, copyrightID).Error -} +} \ No newline at end of file diff --git a/internal/data/sql/country_repository.go b/internal/data/sql/country_repository.go index 0c12e6d..a7d5545 100644 --- a/internal/data/sql/country_repository.go +++ b/internal/data/sql/country_repository.go @@ -10,15 +10,15 @@ import ( ) type countryRepository struct { - domain.BaseRepository[domain.Country] + *BaseRepositoryImpl[domain.Country] db *gorm.DB } // NewCountryRepository creates a new CountryRepository. func NewCountryRepository(db *gorm.DB, cfg *config.Config) domain.CountryRepository { return &countryRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Country](db, cfg), - db: db, + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Country](db, cfg), + db: db, } } @@ -27,7 +27,7 @@ func (r *countryRepository) GetByCode(ctx context.Context, code string) (*domain var country domain.Country if err := r.db.WithContext(ctx).Where("code = ?", code).First(&country).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -41,4 +41,4 @@ func (r *countryRepository) ListByContinent(ctx context.Context, continent strin return nil, err } return countries, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/edition_repository.go b/internal/data/sql/edition_repository.go index 57e28bc..aab821f 100644 --- a/internal/data/sql/edition_repository.go +++ b/internal/data/sql/edition_repository.go @@ -12,7 +12,7 @@ import ( ) type editionRepository struct { - domain.BaseRepository[domain.Edition] + *BaseRepositoryImpl[domain.Edition] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type editionRepository struct { // NewEditionRepository creates a new EditionRepository. func NewEditionRepository(db *gorm.DB, cfg *config.Config) domain.EditionRepository { return &editionRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Edition](db, cfg), - db: db, - tracer: otel.Tracer("edition.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Edition](db, cfg), + db: db, + tracer: otel.Tracer("edition.repository"), } } @@ -44,9 +44,9 @@ func (r *editionRepository) FindByISBN(ctx context.Context, isbn string) (*domai var edition domain.Edition if err := r.db.WithContext(ctx).Where("isbn = ?", isbn).First(&edition).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } return &edition, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/email_verification_repository.go b/internal/data/sql/email_verification_repository.go index 31d8326..f3f836b 100644 --- a/internal/data/sql/email_verification_repository.go +++ b/internal/data/sql/email_verification_repository.go @@ -13,7 +13,7 @@ import ( ) type emailVerificationRepository struct { - domain.BaseRepository[domain.EmailVerification] + *BaseRepositoryImpl[domain.EmailVerification] db *gorm.DB tracer trace.Tracer } @@ -21,9 +21,9 @@ type emailVerificationRepository struct { // NewEmailVerificationRepository creates a new EmailVerificationRepository. func NewEmailVerificationRepository(db *gorm.DB, cfg *config.Config) domain.EmailVerificationRepository { return &emailVerificationRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.EmailVerification](db, cfg), - db: db, - tracer: otel.Tracer("email_verification.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.EmailVerification](db, cfg), + db: db, + tracer: otel.Tracer("email_verification.repository"), } } @@ -34,7 +34,7 @@ func (r *emailVerificationRepository) GetByToken(ctx context.Context, token stri var verification domain.EmailVerification if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&verification).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/like_repository_test.go b/internal/data/sql/like_repository_test.go new file mode 100644 index 0000000..4eb06bf --- /dev/null +++ b/internal/data/sql/like_repository_test.go @@ -0,0 +1,129 @@ +package sql_test + +import ( + "context" + "testing" + "tercul/internal/data/sql" + "tercul/internal/domain" + "tercul/internal/platform/config" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type LikeRepositoryTestSuite struct { + testutil.IntegrationTestSuite + LikeRepo domain.LikeRepository + UserRepo domain.UserRepository + WorkRepo domain.WorkRepository +} + +func (s *LikeRepositoryTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(testutil.DefaultTestConfig()) + cfg, err := config.LoadConfig() + s.Require().NoError(err) + s.LikeRepo = sql.NewLikeRepository(s.DB, cfg) + s.UserRepo = sql.NewUserRepository(s.DB, cfg) + s.WorkRepo = sql.NewWorkRepository(s.DB, cfg) +} + +func (s *LikeRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + s.DB.Exec("DELETE FROM likes") + s.DB.Exec("DELETE FROM works") +} + +func TestLikeRepository(t *testing.T) { + suite.Run(t, new(LikeRepositoryTestSuite)) +} + +func (s *LikeRepositoryTestSuite) createUser(username string) *domain.User { + user := &domain.User{Username: username, Email: username + "@test.com"} + err := s.UserRepo.Create(context.Background(), user) + s.Require().NoError(err) + return user +} + +func (s *LikeRepositoryTestSuite) TestListByUserID() { + s.Run("should return all likes for a given user", func() { + // Arrange + user1 := s.createUser("user1") + user2 := s.createUser("user2") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") + + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, WorkID: &work.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, WorkID: &work.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user2.ID, WorkID: &work.ID})) + + // Act + likes, err := s.LikeRepo.ListByUserID(context.Background(), user1.ID) + + // Assert + s.Require().NoError(err) + s.Len(likes, 2) + }) +} + +func (s *LikeRepositoryTestSuite) TestListByWorkID() { + s.Run("should return all likes for a given work", func() { + // Arrange + user1 := s.createUser("user1") + work1 := s.CreateTestWork(s.AdminCtx, "Test Work 1", "en", "Test content") + work2 := s.CreateTestWork(s.AdminCtx, "Test Work 2", "en", "Test content") + + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, WorkID: &work1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, WorkID: &work1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, WorkID: &work2.ID})) + + // Act + likes, err := s.LikeRepo.ListByWorkID(context.Background(), work1.ID) + + // Assert + s.Require().NoError(err) + s.Len(likes, 2) + }) +} + +func (s *LikeRepositoryTestSuite) TestListByTranslationID() { + s.Run("should return all likes for a given translation", func() { + // Arrange + user1 := s.createUser("user1") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") + translation1 := s.CreateTestTranslation(work.ID, "es", "Contenido de prueba") + translation2 := s.CreateTestTranslation(work.ID, "fr", "Contenu de test") + + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, TranslationID: &translation1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, TranslationID: &translation1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, TranslationID: &translation2.ID})) + + // Act + likes, err := s.LikeRepo.ListByTranslationID(context.Background(), translation1.ID) + + // Assert + s.Require().NoError(err) + s.Len(likes, 2) + }) +} + +func (s *LikeRepositoryTestSuite) TestListByCommentID() { + s.Run("should return all likes for a given comment", func() { + // Arrange + user1 := s.createUser("user1") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") + comment1 := &domain.Comment{UserID: user1.ID, WorkID: &work.ID, Text: "Comment 1"} + comment2 := &domain.Comment{UserID: user1.ID, WorkID: &work.ID, Text: "Comment 2"} + s.Require().NoError(s.DB.Create(comment1).Error) + s.Require().NoError(s.DB.Create(comment2).Error) + + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, CommentID: &comment1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, CommentID: &comment1.ID})) + s.Require().NoError(s.LikeRepo.Create(context.Background(), &domain.Like{UserID: user1.ID, CommentID: &comment2.ID})) + + // Act + likes, err := s.LikeRepo.ListByCommentID(context.Background(), comment1.ID) + + // Assert + s.Require().NoError(err) + s.Len(likes, 2) + }) +} \ No newline at end of file diff --git a/internal/data/sql/monetization_repository_test.go b/internal/data/sql/monetization_repository_test.go index 946a8da..84e9f50 100644 --- a/internal/data/sql/monetization_repository_test.go +++ b/internal/data/sql/monetization_repository_test.go @@ -24,6 +24,7 @@ func (s *MonetizationRepositoryTestSuite) SetupSuite() { } func (s *MonetizationRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() s.DB.Exec("DELETE FROM work_monetizations") s.DB.Exec("DELETE FROM monetizations") s.DB.Exec("DELETE FROM works") @@ -32,7 +33,7 @@ func (s *MonetizationRepositoryTestSuite) SetupTest() { func (s *MonetizationRepositoryTestSuite) TestAddMonetizationToWork() { s.Run("should add a monetization to a work", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") monetization := &domain.Monetization{Amount: 10.0} s.Require().NoError(s.DB.Create(monetization).Error) diff --git a/internal/data/sql/password_reset_repository.go b/internal/data/sql/password_reset_repository.go index 6a81174..236a500 100644 --- a/internal/data/sql/password_reset_repository.go +++ b/internal/data/sql/password_reset_repository.go @@ -13,7 +13,7 @@ import ( ) type passwordResetRepository struct { - domain.BaseRepository[domain.PasswordReset] + *BaseRepositoryImpl[domain.PasswordReset] db *gorm.DB tracer trace.Tracer } @@ -21,9 +21,9 @@ type passwordResetRepository struct { // NewPasswordResetRepository creates a new PasswordResetRepository. func NewPasswordResetRepository(db *gorm.DB, cfg *config.Config) domain.PasswordResetRepository { return &passwordResetRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.PasswordReset](db, cfg), - db: db, - tracer: otel.Tracer("password_reset.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.PasswordReset](db, cfg), + db: db, + tracer: otel.Tracer("password_reset.repository"), } } @@ -34,7 +34,7 @@ func (r *passwordResetRepository) GetByToken(ctx context.Context, token string) var reset domain.PasswordReset if err := r.db.WithContext(ctx).Where("token = ? AND used = ? AND expires_at > ?", token, false, time.Now()).First(&reset).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } diff --git a/internal/data/sql/source_repository.go b/internal/data/sql/source_repository.go index 4702c8c..5c51a35 100644 --- a/internal/data/sql/source_repository.go +++ b/internal/data/sql/source_repository.go @@ -12,7 +12,7 @@ import ( ) type sourceRepository struct { - domain.BaseRepository[domain.Source] + *BaseRepositoryImpl[domain.Source] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type sourceRepository struct { // NewSourceRepository creates a new SourceRepository. func NewSourceRepository(db *gorm.DB, cfg *config.Config) domain.SourceRepository { return &sourceRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Source](db, cfg), - db: db, - tracer: otel.Tracer("source.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Source](db, cfg), + db: db, + tracer: otel.Tracer("source.repository"), } } @@ -46,9 +46,9 @@ func (r *sourceRepository) FindByURL(ctx context.Context, url string) (*domain.S var source domain.Source if err := r.db.WithContext(ctx).Where("url = ?", url).First(&source).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } return &source, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/tag_repository.go b/internal/data/sql/tag_repository.go index f61e975..0038e9d 100644 --- a/internal/data/sql/tag_repository.go +++ b/internal/data/sql/tag_repository.go @@ -12,7 +12,7 @@ import ( ) type tagRepository struct { - domain.BaseRepository[domain.Tag] + *BaseRepositoryImpl[domain.Tag] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type tagRepository struct { // NewTagRepository creates a new TagRepository. func NewTagRepository(db *gorm.DB, cfg *config.Config) domain.TagRepository { return &tagRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.Tag](db, cfg), - db: db, - tracer: otel.Tracer("tag.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.Tag](db, cfg), + db: db, + tracer: otel.Tracer("tag.repository"), } } @@ -33,7 +33,7 @@ func (r *tagRepository) FindByName(ctx context.Context, name string) (*domain.Ta var tag domain.Tag if err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -51,4 +51,4 @@ func (r *tagRepository) ListByWorkID(ctx context.Context, workID uint) ([]domain return nil, err } return tags, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/user_profile_repository.go b/internal/data/sql/user_profile_repository.go index 351adeb..2ba54a0 100644 --- a/internal/data/sql/user_profile_repository.go +++ b/internal/data/sql/user_profile_repository.go @@ -12,7 +12,7 @@ import ( ) type userProfileRepository struct { - domain.BaseRepository[domain.UserProfile] + *BaseRepositoryImpl[domain.UserProfile] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type userProfileRepository struct { // NewUserProfileRepository creates a new UserProfileRepository. func NewUserProfileRepository(db *gorm.DB, cfg *config.Config) domain.UserProfileRepository { return &userProfileRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.UserProfile](db, cfg), - db: db, - tracer: otel.Tracer("user_profile.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.UserProfile](db, cfg), + db: db, + tracer: otel.Tracer("user_profile.repository"), } } @@ -33,9 +33,9 @@ func (r *userProfileRepository) GetByUserID(ctx context.Context, userID uint) (* var profile domain.UserProfile if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&profile).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } return &profile, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/user_repository.go b/internal/data/sql/user_repository.go index a6bed79..b745c3e 100644 --- a/internal/data/sql/user_repository.go +++ b/internal/data/sql/user_repository.go @@ -12,7 +12,7 @@ import ( ) type userRepository struct { - domain.BaseRepository[domain.User] + *BaseRepositoryImpl[domain.User] db *gorm.DB tracer trace.Tracer } @@ -20,9 +20,9 @@ type userRepository struct { // NewUserRepository creates a new UserRepository. func NewUserRepository(db *gorm.DB, cfg *config.Config) domain.UserRepository { return &userRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.User](db, cfg), - db: db, - tracer: otel.Tracer("user.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.User](db, cfg), + db: db, + tracer: otel.Tracer("user.repository"), } } @@ -33,7 +33,7 @@ func (r *userRepository) FindByUsername(ctx context.Context, username string) (* var user domain.User if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -47,7 +47,7 @@ func (r *userRepository) FindByEmail(ctx context.Context, email string) (*domain var user domain.User if err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -63,4 +63,4 @@ func (r *userRepository) ListByRole(ctx context.Context, role domain.UserRole) ( return nil, err } return users, nil -} +} \ No newline at end of file diff --git a/internal/data/sql/user_repository_test.go b/internal/data/sql/user_repository_test.go new file mode 100644 index 0000000..92b27bb --- /dev/null +++ b/internal/data/sql/user_repository_test.go @@ -0,0 +1,114 @@ +package sql_test + +import ( + "context" + "testing" + "tercul/internal/data/sql" + "tercul/internal/domain" + "tercul/internal/platform/config" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type UserRepositoryTestSuite struct { + testutil.IntegrationTestSuite + UserRepo domain.UserRepository +} + +func (s *UserRepositoryTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(testutil.DefaultTestConfig()) + cfg, err := config.LoadConfig() + s.Require().NoError(err) + s.UserRepo = sql.NewUserRepository(s.DB, cfg) +} + +func (s *UserRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + s.DB.Exec("DELETE FROM users") +} + +func TestUserRepository(t *testing.T) { + suite.Run(t, new(UserRepositoryTestSuite)) +} + +func (s *UserRepositoryTestSuite) createUser(username, email string, role domain.UserRole) *domain.User { + user := &domain.User{ + Username: username, + Email: email, + Role: role, + } + err := s.UserRepo.Create(context.Background(), user) + s.Require().NoError(err) + return user +} + +func (s *UserRepositoryTestSuite) TestFindByUsername() { + s.Run("should find a user by username", func() { + // Arrange + expectedUser := s.createUser("testuser", "test@test.com", domain.UserRoleReader) + + // Act + foundUser, err := s.UserRepo.FindByUsername(context.Background(), "testuser") + + // Assert + s.Require().NoError(err) + s.Require().NotNil(foundUser) + s.Equal(expectedUser.ID, foundUser.ID) + s.Equal("testuser", foundUser.Username) + }) + + s.Run("should return error if user not found", func() { + _, err := s.UserRepo.FindByUsername(context.Background(), "nonexistent") + s.Require().Error(err) + s.ErrorIs(err, domain.ErrEntityNotFound) + }) +} + +func (s *UserRepositoryTestSuite) TestFindByEmail() { + s.Run("should find a user by email", func() { + // Arrange + expectedUser := s.createUser("testuser", "test@test.com", domain.UserRoleReader) + + // Act + foundUser, err := s.UserRepo.FindByEmail(context.Background(), "test@test.com") + + // Assert + s.Require().NoError(err) + s.Require().NotNil(foundUser) + s.Equal(expectedUser.ID, foundUser.ID) + s.Equal("test@test.com", foundUser.Email) + }) + + s.Run("should return error if user not found", func() { + _, err := s.UserRepo.FindByEmail(context.Background(), "nonexistent@test.com") + s.Require().Error(err) + s.ErrorIs(err, domain.ErrEntityNotFound) + }) +} + +func (s *UserRepositoryTestSuite) TestListByRole() { + s.Run("should return all users for a given role", func() { + // Arrange + s.createUser("reader1", "reader1@test.com", domain.UserRoleReader) + s.createUser("reader2", "reader2@test.com", domain.UserRoleReader) + s.createUser("admin1", "admin1@test.com", domain.UserRoleAdmin) + + // Act + readers, err := s.UserRepo.ListByRole(context.Background(), domain.UserRoleReader) + s.Require().NoError(err) + + admins, err := s.UserRepo.ListByRole(context.Background(), domain.UserRoleAdmin) + s.Require().NoError(err) + + // Assert + s.Len(readers, 2) + s.Len(admins, 1) + }) + + s.Run("should return empty slice if no users for role", func() { + users, err := s.UserRepo.ListByRole(context.Background(), domain.UserRoleContributor) + s.Require().NoError(err) + s.Len(users, 0) + }) +} \ No newline at end of file diff --git a/internal/data/sql/user_session_repository.go b/internal/data/sql/user_session_repository.go index a431822..27e7798 100644 --- a/internal/data/sql/user_session_repository.go +++ b/internal/data/sql/user_session_repository.go @@ -13,7 +13,7 @@ import ( ) type userSessionRepository struct { - domain.BaseRepository[domain.UserSession] + *BaseRepositoryImpl[domain.UserSession] db *gorm.DB tracer trace.Tracer } @@ -21,9 +21,9 @@ type userSessionRepository struct { // NewUserSessionRepository creates a new UserSessionRepository. func NewUserSessionRepository(db *gorm.DB, cfg *config.Config) domain.UserSessionRepository { return &userSessionRepository{ - BaseRepository: NewBaseRepositoryImpl[domain.UserSession](db, cfg), - db: db, - tracer: otel.Tracer("user_session.repository"), + BaseRepositoryImpl: NewBaseRepositoryImpl[domain.UserSession](db, cfg), + db: db, + tracer: otel.Tracer("user_session.repository"), } } @@ -34,7 +34,7 @@ func (r *userSessionRepository) GetByToken(ctx context.Context, token string) (* var session domain.UserSession if err := r.db.WithContext(ctx).Where("token = ?", token).First(&session).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } return nil, err } @@ -60,4 +60,4 @@ func (r *userSessionRepository) DeleteExpired(ctx context.Context) error { return err } return nil -} +} \ No newline at end of file diff --git a/internal/data/sql/work_repository.go b/internal/data/sql/work_repository.go index 2eaa1b3..71c1a50 100644 --- a/internal/data/sql/work_repository.go +++ b/internal/data/sql/work_repository.go @@ -185,9 +185,9 @@ func (r *workRepository) GetWithAssociationsInTx(ctx context.Context, tx *gorm.D } if err := query.First(&entity, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrEntityNotFound + return nil, domain.ErrEntityNotFound } - return nil, fmt.Errorf("%w: %v", ErrDatabaseOperation, err) + return nil, fmt.Errorf("database operation failed: %w", err) } return &entity, nil } diff --git a/internal/data/sql/work_repository_test.go b/internal/data/sql/work_repository_test.go index a075458..56ff4f7 100644 --- a/internal/data/sql/work_repository_test.go +++ b/internal/data/sql/work_repository_test.go @@ -23,6 +23,13 @@ func (s *WorkRepositoryTestSuite) SetupSuite() { s.WorkRepo = sql.NewWorkRepository(s.DB, cfg) } +func (s *WorkRepositoryTestSuite) SetupTest() { + s.IntegrationTestSuite.SetupTest() + s.DB.Exec("DELETE FROM work_copyrights") + s.DB.Exec("DELETE FROM copyrights") + s.DB.Exec("DELETE FROM works") +} + func (s *WorkRepositoryTestSuite) TestCreateWork() { s.Run("should create a new work with a copyright", func() { // Arrange @@ -67,7 +74,7 @@ func (s *WorkRepositoryTestSuite) TestGetWorkByID() { } s.Require().NoError(s.DB.Create(copyright).Error) - workModel := s.CreateTestWork("Test Work", "en", "Test content") + workModel := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") s.Require().NoError(s.DB.Model(workModel).Association("Copyrights").Append(copyright)) // Act @@ -98,7 +105,7 @@ func (s *WorkRepositoryTestSuite) TestUpdateWork() { s.Require().NoError(s.DB.Create(©right1).Error) s.Require().NoError(s.DB.Create(©right2).Error) - workModel := s.CreateTestWork("Original Title", "en", "Original content") + workModel := s.CreateTestWork(s.AdminCtx, "Original Title", "en", "Original content") s.Require().NoError(s.DB.Model(workModel).Association("Copyrights").Append(copyright1)) workModel.Title = "Updated Title" @@ -123,7 +130,7 @@ func (s *WorkRepositoryTestSuite) TestUpdateWork() { func (s *WorkRepositoryTestSuite) TestDeleteWork() { s.Run("should delete an existing work and its associations", func() { // Arrange - workModel := s.CreateTestWork("To Be Deleted", "en", "Content") + workModel := s.CreateTestWork(s.AdminCtx, "To Be Deleted", "en", "Content") copyright := &domain.Copyright{Name: "C1", Identificator: "C1"} s.Require().NoError(s.DB.Create(copyright).Error) s.Require().NoError(s.DB.Model(workModel).Association("Copyrights").Append(copyright)) diff --git a/internal/domain/errors.go b/internal/domain/errors.go index f806797..0f9e5df 100644 --- a/internal/domain/errors.go +++ b/internal/domain/errors.go @@ -2,17 +2,14 @@ package domain import "errors" -// Common domain-level errors that can be used across repositories and services. var ( ErrEntityNotFound = errors.New("entity not found") - ErrInvalidID = errors.New("invalid ID: cannot be zero") - ErrInvalidInput = errors.New("invalid input parameters") - ErrDatabaseOperation = errors.New("database operation failed") - ErrContextRequired = errors.New("context is required") - ErrTransactionFailed = errors.New("transaction failed") - ErrUnauthorized = errors.New("unauthorized") ErrForbidden = errors.New("forbidden") + ErrUnauthorized = errors.New("unauthorized") ErrValidation = errors.New("validation failed") - ErrConflict = errors.New("conflict with existing resource") ErrUserNotFound = errors.New("user not found") + ErrDuplicateEntity = errors.New("duplicate entity") + ErrOptimisticLock = errors.New("optimistic lock failure") + ErrInvalidOperation = errors.New("invalid operation") + ErrConflict = errors.New("conflict") ) \ No newline at end of file diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 7fbabf5..62c5303 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -241,6 +241,7 @@ type BaseRepository[T any] interface { // AuthorRepository defines CRUD methods specific to Author. type AuthorRepository interface { BaseRepository[Author] + FindByName(ctx context.Context, name string) (*Author, error) ListByWorkID(ctx context.Context, workID uint) ([]Author, error) ListByBookID(ctx context.Context, bookID uint) ([]Author, error) ListByCountryID(ctx context.Context, countryID uint) ([]Author, error) diff --git a/internal/jobs/linguistics/analysis_repository_test.go b/internal/jobs/linguistics/analysis_repository_test.go index 9dfac6c..551030c 100644 --- a/internal/jobs/linguistics/analysis_repository_test.go +++ b/internal/jobs/linguistics/analysis_repository_test.go @@ -28,7 +28,7 @@ func (s *AnalysisRepositoryTestSuite) SetupTest() { func (s *AnalysisRepositoryTestSuite) TestGetAnalysisData() { s.Run("should return the correct analysis data", func() { // Arrange - work := s.CreateTestWork("Test Work", "en", "Test content") + work := s.CreateTestWork(s.AdminCtx, "Test Work", "en", "Test content") textMetadata := &domain.TextMetadata{WorkID: work.ID, WordCount: 123} readabilityScore := &domain.ReadabilityScore{WorkID: work.ID, Score: 45.6} languageAnalysis := &domain.LanguageAnalysis{ diff --git a/internal/testutil/integration_test_utils.go b/internal/testutil/integration_test_utils.go index f2b581f..ba7cca8 100644 --- a/internal/testutil/integration_test_utils.go +++ b/internal/testutil/integration_test_utils.go @@ -48,9 +48,10 @@ func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pip // IntegrationTestSuite provides a comprehensive test environment with either in-memory SQLite or mock repositories type IntegrationTestSuite struct { suite.Suite - App *app.Application - DB *gorm.DB - AdminCtx context.Context + App *app.Application + DB *gorm.DB + AdminCtx context.Context + AdminToken string } // TestConfig holds configuration for the test environment @@ -118,7 +119,7 @@ func (s *IntegrationTestSuite) SetupSuite(testConfig *TestConfig) { &domain.Source{}, &domain.Copyright{}, &domain.Monetization{}, &domain.WorkStats{}, &domain.Trending{}, &domain.UserSession{}, &domain.Localization{}, &domain.LanguageAnalysis{}, &domain.TextMetadata{}, &domain.ReadabilityScore{}, - &domain.TranslationStats{}, &TestEntity{}, &domain.CollectionWork{}, + &domain.TranslationStats{}, &TestEntity{}, &domain.CollectionWork{}, &domain.BookWork{}, ) s.Require().NoError(err, "Failed to migrate database schema") @@ -137,7 +138,7 @@ func (s *IntegrationTestSuite) SetupSuite(testConfig *TestConfig) { analyticsService := analytics.NewService(repos.Analytics, analysisRepo, repos.Translation, repos.Work, sentimentProvider) jwtManager := platform_auth.NewJWTManager(cfg) - authzService := authz.NewService(repos.Work, repos.Translation) + authzService := authz.NewService(repos.Work, repos.Author, repos.User, repos.Translation) authorService := author.NewService(repos.Author) bookService := book.NewService(repos.Book, authzService) bookmarkService := bookmark.NewService(repos.Bookmark, analyticsService) @@ -152,7 +153,7 @@ func (s *IntegrationTestSuite) SetupSuite(testConfig *TestConfig) { userService := user.NewService(repos.User, authzService, repos.UserProfile) localizationService := localization.NewService(repos.Localization) authService := app_auth.NewService(repos.User, jwtManager) - workService := work.NewService(repos.Work, searchClient, authzService, analyticsService) + workService := work.NewService(repos.Work, repos.Author, repos.User, searchClient, authzService, analyticsService) searchService := app_search.NewService(searchClient, localizationService) s.App = app.NewApplication( @@ -174,21 +175,6 @@ func (s *IntegrationTestSuite) SetupSuite(testConfig *TestConfig) { searchService, analyticsService, ) - - // Create a default admin user for tests - adminUser := &domain.User{ - Username: "admin", - Email: "admin@test.com", - Role: domain.UserRoleAdmin, - Active: true, - } - _ = adminUser.SetPassword("password") - err = s.DB.Create(adminUser).Error - s.Require().NoError(err) - s.AdminCtx = ContextWithClaims(context.Background(), &platform_auth.Claims{ - UserID: adminUser.ID, - Role: string(adminUser.Role), - }) } // TearDownSuite cleans up the test suite @@ -213,10 +199,33 @@ func (s *IntegrationTestSuite) SetupTest() { s.DB.Exec("DELETE FROM work_stats") s.DB.Exec("DELETE FROM translation_stats") } + + // Create a default admin user for tests + adminUser := &domain.User{ + Username: "admin", + Email: "admin@test.com", + Role: domain.UserRoleAdmin, + Active: true, + } + _ = adminUser.SetPassword("password") + err := s.DB.Create(adminUser).Error + s.Require().NoError(err) + s.AdminCtx = ContextWithClaims(context.Background(), &platform_auth.Claims{ + UserID: adminUser.ID, + Role: string(adminUser.Role), + }) + + // Generate a token for the admin user + cfg, err := platform_config.LoadConfig() + s.Require().NoError(err) + jwtManager := platform_auth.NewJWTManager(cfg) + token, err := jwtManager.GenerateToken(adminUser) + s.Require().NoError(err) + s.AdminToken = token } // CreateTestWork creates a test work with optional content -func (s *IntegrationTestSuite) CreateTestWork(title, language string, content string) *domain.Work { +func (s *IntegrationTestSuite) CreateTestWork(ctx context.Context, title, language string, content string) *domain.Work { work := &domain.Work{ Title: title, TranslatableModel: domain.TranslatableModel{ @@ -225,7 +234,7 @@ func (s *IntegrationTestSuite) CreateTestWork(title, language string, content st } // Note: CreateWork command might not exist or need context. Assuming it does for now. // If CreateWork also requires auth, this context should be s.AdminCtx - createdWork, err := s.App.Work.Commands.CreateWork(s.AdminCtx, work) + createdWork, err := s.App.Work.Commands.CreateWork(ctx, work) s.Require().NoError(err) if content != "" { @@ -237,7 +246,7 @@ func (s *IntegrationTestSuite) CreateTestWork(title, language string, content st TranslatableType: "works", IsOriginalLanguage: true, // Assuming the first one is original } - _, err = s.App.Translation.Commands.CreateOrUpdateTranslation(s.AdminCtx, translationInput) + _, err = s.App.Translation.Commands.CreateOrUpdateTranslation(ctx, translationInput) s.Require().NoError(err) } return createdWork diff --git a/internal/testutil/mock_user_repository.go b/internal/testutil/mock_user_repository.go index a2dfccc..eee7944 100644 --- a/internal/testutil/mock_user_repository.go +++ b/internal/testutil/mock_user_repository.go @@ -2,182 +2,151 @@ package testutil import ( "context" - "strings" "tercul/internal/domain" + "github.com/stretchr/testify/mock" "gorm.io/gorm" ) -// MockUserRepository is a mock implementation of the UserRepository interface. +// MockUserRepository is a mock implementation of the UserRepository interface using testify/mock. type MockUserRepository struct { - Users []*domain.User + mock.Mock } -// NewMockUserRepository creates a new MockUserRepository. -func NewMockUserRepository() *MockUserRepository { - return &MockUserRepository{Users: []*domain.User{}} -} - -// Create adds a new user to the mock repository. func (m *MockUserRepository) Create(ctx context.Context, user *domain.User) error { - user.ID = uint(len(m.Users) + 1) - m.Users = append(m.Users, user) - return nil + args := m.Called(ctx, user) + return args.Error(0) } -// GetByID retrieves a user by their ID from the mock repository. func (m *MockUserRepository) GetByID(ctx context.Context, id uint) (*domain.User, error) { - for _, u := range m.Users { - if u.ID == id { - return u, nil - } + args := m.Called(ctx, id) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, gorm.ErrRecordNotFound + return args.Get(0).(*domain.User), args.Error(1) } -// FindByUsername retrieves a user by their username from the mock repository. func (m *MockUserRepository) FindByUsername(ctx context.Context, username string) (*domain.User, error) { - for _, u := range m.Users { - if strings.EqualFold(u.Username, username) { - return u, nil - } + args := m.Called(ctx, username) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, gorm.ErrRecordNotFound + return args.Get(0).(*domain.User), args.Error(1) } -// FindByEmail retrieves a user by their email from the mock repository. func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*domain.User, error) { - for _, u := range m.Users { - if strings.EqualFold(u.Email, email) { - return u, nil - } + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) } - return nil, gorm.ErrRecordNotFound + return args.Get(0).(*domain.User), args.Error(1) } -// ListByRole retrieves users by their role from the mock repository. func (m *MockUserRepository) ListByRole(ctx context.Context, role domain.UserRole) ([]domain.User, error) { - var users []domain.User - for _, u := range m.Users { - if u.Role == role { - users = append(users, *u) - } + args := m.Called(ctx, role) + if args.Get(0) == nil { + return nil, args.Error(1) } - return users, nil + return args.Get(0).([]domain.User), args.Error(1) } -// The rest of the BaseRepository methods can be stubbed out or implemented as needed. func (m *MockUserRepository) CreateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { - return m.Create(ctx, entity) + args := m.Called(ctx, tx, entity) + return args.Error(0) } + func (m *MockUserRepository) GetByIDWithOptions(ctx context.Context, id uint, options *domain.QueryOptions) (*domain.User, error) { - return m.GetByID(ctx, id) + args := m.Called(ctx, id, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) } + func (m *MockUserRepository) Update(ctx context.Context, entity *domain.User) error { - for i, u := range m.Users { - if u.ID == entity.ID { - m.Users[i] = entity - return nil - } - } - return gorm.ErrRecordNotFound + args := m.Called(ctx, entity) + return args.Error(0) } + func (m *MockUserRepository) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *domain.User) error { - return m.Update(ctx, entity) + args := m.Called(ctx, tx, entity) + return args.Error(0) } + func (m *MockUserRepository) Delete(ctx context.Context, id uint) error { - for i, u := range m.Users { - if u.ID == id { - m.Users = append(m.Users[:i], m.Users[i+1:]...) - return nil - } - } - return gorm.ErrRecordNotFound + args := m.Called(ctx, id) + return args.Error(0) } + func (m *MockUserRepository) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { - return m.Delete(ctx, id) + args := m.Called(ctx, tx, id) + return args.Error(0) } + func (m *MockUserRepository) List(ctx context.Context, page, pageSize int) (*domain.PaginatedResult[domain.User], error) { - start := (page - 1) * pageSize - end := start + pageSize - if start > len(m.Users) { - start = len(m.Users) + args := m.Called(ctx, page, pageSize) + if args.Get(0) == nil { + return nil, args.Error(1) } - if end > len(m.Users) { - end = len(m.Users) - } - - paginatedUsers := m.Users[start:end] - var users []domain.User - for _, u := range paginatedUsers { - users = append(users, *u) - } - - totalCount := int64(len(m.Users)) - totalPages := int(totalCount) / pageSize - if int(totalCount)%pageSize != 0 { - totalPages++ - } - - return &domain.PaginatedResult[domain.User]{ - Items: users, - TotalCount: totalCount, - Page: page, - PageSize: pageSize, - TotalPages: totalPages, - HasNext: page < totalPages, - HasPrev: page > 1, - }, nil + return args.Get(0).(*domain.PaginatedResult[domain.User]), args.Error(1) } func (m *MockUserRepository) ListWithOptions(ctx context.Context, options *domain.QueryOptions) ([]domain.User, error) { - // This is a mock implementation and doesn't handle options. - return m.ListAll(ctx) + args := m.Called(ctx, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]domain.User), args.Error(1) } func (m *MockUserRepository) ListAll(ctx context.Context) ([]domain.User, error) { - var users []domain.User - for _, u := range m.Users { - users = append(users, *u) + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) } - return users, nil + return args.Get(0).([]domain.User), args.Error(1) } func (m *MockUserRepository) Count(ctx context.Context) (int64, error) { - return int64(len(m.Users)), nil + args := m.Called(ctx) + return args.Get(0).(int64), args.Error(1) } func (m *MockUserRepository) CountWithOptions(ctx context.Context, options *domain.QueryOptions) (int64, error) { - // This is a mock implementation and doesn't handle options. - return m.Count(ctx) + args := m.Called(ctx, options) + return args.Get(0).(int64), args.Error(1) } func (m *MockUserRepository) FindWithPreload(ctx context.Context, preloads []string, id uint) (*domain.User, error) { - return m.GetByID(ctx, id) + args := m.Called(ctx, preloads, id) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.User), args.Error(1) } func (m *MockUserRepository) GetAllForSync(ctx context.Context, batchSize, offset int) ([]domain.User, error) { - start := offset - end := start + batchSize - if start > len(m.Users) { - return []domain.User{}, nil + args := m.Called(ctx, batchSize, offset) + if args.Get(0) == nil { + return nil, args.Error(1) } - if end > len(m.Users) { - end = len(m.Users) - } - var users []domain.User - for _, u := range m.Users[start:end] { - users = append(users, *u) - } - return users, nil + return args.Get(0).([]domain.User), args.Error(1) } + func (m *MockUserRepository) Exists(ctx context.Context, id uint) (bool, error) { - _, err := m.GetByID(ctx, id) - return err == nil, nil + args := m.Called(ctx, id) + return args.Bool(0), args.Error(1) } + func (m *MockUserRepository) BeginTx(ctx context.Context) (*gorm.DB, error) { - return nil, nil + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*gorm.DB), args.Error(1) } + func (m *MockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { - return fn(nil) + args := m.Called(ctx, fn) + return args.Error(0) } \ No newline at end of file