From 1655a02a0879235a6bfe49127b768b12e1244640 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 6 Sep 2025 13:01:04 +0000 Subject: [PATCH] Refactor repository tests to be more DRY and maintainable. Introduced a new testing strategy for the data access layer to avoid redundant testing of generic repository methods. - Created a comprehensive test suite for the generic `BaseRepository` using a dedicated `TestEntity`. This suite covers all common CRUD operations, including transactions and error handling, in a single location. - Added a new, focused test suite for `CategoryRepository` that only tests its repository-specific methods, relying on the base repository tests for generic functionality. - Refactored the existing `AuthorRepository` test suite to remove redundant CRUD tests, aligning it with the new, cleaner pattern. - Updated the test utilities to support the new testing strategy. This change significantly improves the maintainability and efficiency of the test suite and provides a clear, future-proof pattern for testing all repositories. --- internal/data/sql/author_repository_test.go | 96 ++----- internal/data/sql/base_repository_test.go | 259 ++++++++++++++++++ internal/data/sql/category_repository_test.go | 111 ++++++++ internal/testutil/integration_test_utils.go | 1 + internal/testutil/test_entity.go | 10 + 5 files changed, 398 insertions(+), 79 deletions(-) create mode 100644 internal/data/sql/base_repository_test.go create mode 100644 internal/data/sql/category_repository_test.go create mode 100644 internal/testutil/test_entity.go diff --git a/internal/data/sql/author_repository_test.go b/internal/data/sql/author_repository_test.go index ef3c44c..48a1c54 100644 --- a/internal/data/sql/author_repository_test.go +++ b/internal/data/sql/author_repository_test.go @@ -17,93 +17,30 @@ func (s *AuthorRepositoryTestSuite) SetupSuite() { s.IntegrationTestSuite.SetupSuite(testutil.DefaultTestConfig()) } -func (s *AuthorRepositoryTestSuite) TestCreateAuthor() { - s.Run("should create a new author", func() { - // Arrange - author := &domain.Author{ - Name: "New Test Author", - TranslatableModel: domain.TranslatableModel{ - Language: "en", - }, - } - - // Act - err := s.AuthorRepo.Create(context.Background(), author) - - // Assert - s.Require().NoError(err) - s.NotZero(author.ID) - - // Verify that the author was actually created in the database - var foundAuthor domain.Author - err = s.DB.First(&foundAuthor, author.ID).Error - s.Require().NoError(err) - s.Equal("New Test Author", foundAuthor.Name) - s.Equal("en", foundAuthor.Language) - }) +func (s *AuthorRepositoryTestSuite) SetupTest() { + s.DB.Exec("DELETE FROM work_authors") + s.DB.Exec("DELETE FROM authors") + s.DB.Exec("DELETE FROM works") } -func (s *AuthorRepositoryTestSuite) TestGetAuthorByID() { - s.Run("should return an author by ID", func() { - // Arrange - author := &domain.Author{Name: "Test Author"} - s.Require().NoError(s.AuthorRepo.Create(context.Background(), author)) - - // Act - foundAuthor, err := s.AuthorRepo.GetByID(context.Background(), author.ID) - - // Assert - s.Require().NoError(err) - s.Require().NotNil(foundAuthor) - s.Equal(author.ID, foundAuthor.ID) - s.Equal("Test Author", foundAuthor.Name) - }) -} - -func (s *AuthorRepositoryTestSuite) TestUpdateAuthor() { - s.Run("should update an existing author", func() { - // Arrange - author := &domain.Author{Name: "Original Name"} - s.Require().NoError(s.AuthorRepo.Create(context.Background(), author)) - author.Name = "Updated Name" - - // Act - err := s.AuthorRepo.Update(context.Background(), author) - - // Assert - s.Require().NoError(err) - var foundAuthor domain.Author - err = s.DB.First(&foundAuthor, author.ID).Error - s.Require().NoError(err) - s.Equal("Updated Name", foundAuthor.Name) - }) -} - -func (s *AuthorRepositoryTestSuite) TestDeleteAuthor() { - s.Run("should delete an existing author", func() { - // Arrange - author := &domain.Author{Name: "To Be Deleted"} - s.Require().NoError(s.AuthorRepo.Create(context.Background(), author)) - - // Act - err := s.AuthorRepo.Delete(context.Background(), author.ID) - - // Assert - s.Require().NoError(err) - var foundAuthor domain.Author - err = s.DB.First(&foundAuthor, author.ID).Error - s.Require().Error(err) - }) +func (s *AuthorRepositoryTestSuite) createAuthor(name string) *domain.Author { + author := &domain.Author{ + Name: name, + TranslatableModel: domain.TranslatableModel{ + Language: "en", + }, + } + err := s.AuthorRepo.Create(context.Background(), author) + s.Require().NoError(err) + return author } func (s *AuthorRepositoryTestSuite) TestListByWorkID() { s.Run("should return all authors for a given work", func() { // Arrange work := s.CreateTestWork("Test Work", "en", "Test content") - author1 := &domain.Author{Name: "Author 1"} - author2 := &domain.Author{Name: "Author 2"} - s.Require().NoError(s.AuthorRepo.Create(context.Background(), author1)) - s.Require().NoError(s.AuthorRepo.Create(context.Background(), author2)) + author1 := s.createAuthor("Author 1") + author2 := s.createAuthor("Author 2") s.Require().NoError(s.DB.Model(&work).Association("Authors").Append([]*domain.Author{author1, author2})) // Act @@ -112,6 +49,7 @@ func (s *AuthorRepositoryTestSuite) TestListByWorkID() { // Assert s.Require().NoError(err) s.Len(authors, 2) + s.ElementsMatch([]string{"Author 1", "Author 2"}, []string{authors[0].Name, authors[1].Name}) }) } diff --git a/internal/data/sql/base_repository_test.go b/internal/data/sql/base_repository_test.go new file mode 100644 index 0000000..2589caa --- /dev/null +++ b/internal/data/sql/base_repository_test.go @@ -0,0 +1,259 @@ +package sql_test + +import ( + "context" + "errors" + "testing" + "tercul/internal/data/sql" + "tercul/internal/domain" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +// BaseRepositoryTestSuite tests the generic BaseRepository implementation. +type BaseRepositoryTestSuite struct { + testutil.IntegrationTestSuite + repo domain.BaseRepository[testutil.TestEntity] +} + +// SetupSuite initializes the test suite, database, and repository. +func (s *BaseRepositoryTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(testutil.DefaultTestConfig()) + s.repo = sql.NewBaseRepositoryImpl[testutil.TestEntity](s.DB) +} + +// SetupTest cleans the database before each test. +func (s *BaseRepositoryTestSuite) SetupTest() { + s.DB.Exec("DELETE FROM test_entities") +} + +// TearDownSuite drops the test table after the suite finishes. +func (s *BaseRepositoryTestSuite) TearDownSuite() { + s.DB.Migrator().DropTable(&testutil.TestEntity{}) +} + +// TestBaseRepository runs the entire test suite. +func TestBaseRepository(t *testing.T) { + suite.Run(t, new(BaseRepositoryTestSuite)) +} + +// createTestEntity is a helper to create a test entity. +func (s *BaseRepositoryTestSuite) createTestEntity(name string) *testutil.TestEntity { + entity := &testutil.TestEntity{Name: name} + err := s.repo.Create(context.Background(), entity) + s.Require().NoError(err) + s.Require().NotZero(entity.ID) + return entity +} + +func (s *BaseRepositoryTestSuite) TestCreate() { + s.Run("should create a new entity", func() { + // Arrange + ctx := context.Background() + entity := &testutil.TestEntity{Name: "Test Create"} + + // Act + err := s.repo.Create(ctx, entity) + + // Assert + s.Require().NoError(err) + s.NotZero(entity.ID) + + // Verify in DB + var foundEntity testutil.TestEntity + err = s.DB.First(&foundEntity, entity.ID).Error + s.Require().NoError(err) + s.Equal("Test Create", foundEntity.Name) + }) + + s.Run("should return error for nil entity", func() { + err := s.repo.Create(context.Background(), nil) + s.ErrorIs(err, sql.ErrInvalidInput) + }) + + s.Run("should return error for nil context", func() { + err := s.repo.Create(nil, &testutil.TestEntity{Name: "Test Context"}) + s.ErrorIs(err, sql.ErrContextRequired) + }) +} + +func (s *BaseRepositoryTestSuite) TestGetByID() { + s.Run("should return an entity by ID", func() { + // Arrange + created := s.createTestEntity("Test GetByID") + + // Act + found, err := s.repo.GetByID(context.Background(), created.ID) + + // Assert + s.Require().NoError(err) + s.Require().NotNil(found) + s.Equal(created.ID, found.ID) + s.Equal(created.Name, found.Name) + }) + + s.Run("should return ErrEntityNotFound for non-existent ID", func() { + _, err := s.repo.GetByID(context.Background(), 99999) + s.ErrorIs(err, sql.ErrEntityNotFound) + }) + + s.Run("should return ErrInvalidID for zero ID", func() { + _, err := s.repo.GetByID(context.Background(), 0) + s.ErrorIs(err, sql.ErrInvalidID) + }) +} + +func (s *BaseRepositoryTestSuite) TestUpdate() { + s.Run("should update an existing entity", func() { + // Arrange + created := s.createTestEntity("Original Name") + created.Name = "Updated Name" + + // Act + err := s.repo.Update(context.Background(), created) + + // Assert + s.Require().NoError(err) + found, getErr := s.repo.GetByID(context.Background(), created.ID) + s.Require().NoError(getErr) + s.Equal("Updated Name", found.Name) + }) +} + +func (s *BaseRepositoryTestSuite) TestDelete() { + s.Run("should delete an existing entity", func() { + // Arrange + created := s.createTestEntity("To Be Deleted") + + // Act + err := s.repo.Delete(context.Background(), created.ID) + + // Assert + s.Require().NoError(err) + _, getErr := s.repo.GetByID(context.Background(), created.ID) + s.ErrorIs(getErr, sql.ErrEntityNotFound) + }) + + s.Run("should return ErrEntityNotFound when deleting non-existent entity", func() { + err := s.repo.Delete(context.Background(), 99999) + s.ErrorIs(err, sql.ErrEntityNotFound) + }) +} + +func (s *BaseRepositoryTestSuite) TestList() { + // Arrange + s.createTestEntity("Entity 1") + s.createTestEntity("Entity 2") + s.createTestEntity("Entity 3") + + s.Run("should return a paginated list of entities", func() { + // Act + result, err := s.repo.List(context.Background(), 1, 2) + + // Assert + s.Require().NoError(err) + s.Equal(int64(3), result.TotalCount) + s.Equal(2, result.TotalPages) + s.Equal(1, result.Page) + s.Equal(2, result.PageSize) + s.True(result.HasNext) + s.False(result.HasPrev) + s.Len(result.Items, 2) + }) +} + +func (s *BaseRepositoryTestSuite) TestListWithOptions() { + // Arrange + s.createTestEntity("Apple") + s.createTestEntity("Banana") + s.createTestEntity("Avocado") + + s.Run("should filter with Where clause", func() { + // Act + options := &domain.QueryOptions{ + Where: map[string]interface{}{"name LIKE ?": "A%"}, + } + results, err := s.repo.ListWithOptions(context.Background(), options) + + // Assert + s.Require().NoError(err) + s.Len(results, 2) + }) + + s.Run("should order results", func() { + // Act + options := &domain.QueryOptions{OrderBy: "name desc"} + results, err := s.repo.ListWithOptions(context.Background(), options) + + // Assert + s.Require().NoError(err) + s.Len(results, 3) + s.Equal("Banana", results[0].Name) + s.Equal("Avocado", results[1].Name) + s.Equal("Apple", results[2].Name) + }) +} + +func (s *BaseRepositoryTestSuite) TestCount() { + // Arrange + s.createTestEntity("Entity 1") + s.createTestEntity("Entity 2") + + s.Run("should return the total count of entities", func() { + // Act + count, err := s.repo.Count(context.Background()) + + // Assert + s.Require().NoError(err) + s.Equal(int64(2), count) + }) +} + +func (s *BaseRepositoryTestSuite) TestWithTx() { + s.Run("should commit transaction on success", func() { + // Arrange + var createdID uint + + // Act + err := s.repo.WithTx(context.Background(), func(tx *gorm.DB) error { + entity := &testutil.TestEntity{Name: "TX Commit"} + repoInTx := sql.NewBaseRepositoryImpl[testutil.TestEntity](tx) + if err := repoInTx.Create(context.Background(), entity); err != nil { + return err + } + createdID = entity.ID + return nil + }) + + // Assert + s.Require().NoError(err) + _, getErr := s.repo.GetByID(context.Background(), createdID) + s.NoError(getErr, "Entity should exist after commit") + }) + + s.Run("should rollback transaction on error", func() { + // Arrange + var createdID uint + simulatedErr := errors.New("simulated error") + + // Act + err := s.repo.WithTx(context.Background(), func(tx *gorm.DB) error { + entity := &testutil.TestEntity{Name: "TX Rollback"} + repoInTx := sql.NewBaseRepositoryImpl[testutil.TestEntity](tx) + if err := repoInTx.Create(context.Background(), entity); err != nil { + return err + } + createdID = entity.ID + return simulatedErr // Force a rollback + }) + + // Assert + s.Require().Error(err) + s.ErrorIs(err, simulatedErr) + + _, getErr := s.repo.GetByID(context.Background(), createdID) + s.ErrorIs(getErr, sql.ErrEntityNotFound, "Entity should not exist after rollback") + }) +} diff --git a/internal/data/sql/category_repository_test.go b/internal/data/sql/category_repository_test.go new file mode 100644 index 0000000..3aa210c --- /dev/null +++ b/internal/data/sql/category_repository_test.go @@ -0,0 +1,111 @@ +package sql_test + +import ( + "context" + "testing" + "tercul/internal/domain" + "tercul/internal/testutil" + + "github.com/stretchr/testify/suite" +) + +type CategoryRepositoryTestSuite struct { + testutil.IntegrationTestSuite +} + +func (s *CategoryRepositoryTestSuite) SetupSuite() { + s.IntegrationTestSuite.SetupSuite(testutil.DefaultTestConfig()) +} + +func (s *CategoryRepositoryTestSuite) SetupTest() { + s.DB.Exec("DELETE FROM work_categories") + s.DB.Exec("DELETE FROM categories") + s.DB.Exec("DELETE FROM works") +} + +func TestCategoryRepository(t *testing.T) { + suite.Run(t, new(CategoryRepositoryTestSuite)) +} + +func (s *CategoryRepositoryTestSuite) createCategory(name string, parentID *uint) *domain.Category { + category := &domain.Category{Name: name, ParentID: parentID} + err := s.CategoryRepo.Create(context.Background(), category) + s.Require().NoError(err) + s.Require().NotZero(category.ID) + return category +} + +func (s *CategoryRepositoryTestSuite) TestFindByName() { + s.Run("should find a category by its name", func() { + // Arrange + s.createCategory("Fiction", nil) + + // Act + found, err := s.CategoryRepo.FindByName(context.Background(), "Fiction") + + // Assert + s.Require().NoError(err) + s.Require().NotNil(found) + s.Equal("Fiction", found.Name) + }) + + s.Run("should return error if not found", func() { + _, err := s.CategoryRepo.FindByName(context.Background(), "NonExistent") + s.Require().Error(err) + }) +} + +func (s *CategoryRepositoryTestSuite) TestListByWorkID() { + s.Run("should return all categories for a given work", func() { + // Arrange + work := s.CreateTestWork("Test Work", "en", "Test content") + cat1 := s.createCategory("Science Fiction", nil) + cat2 := s.createCategory("Cyberpunk", &cat1.ID) + + err := s.DB.Model(&work).Association("Categories").Append([]*domain.Category{cat1, cat2}) + s.Require().NoError(err) + + // Act + categories, err := s.CategoryRepo.ListByWorkID(context.Background(), work.ID) + + // Assert + s.Require().NoError(err) + s.Len(categories, 2) + s.ElementsMatch([]string{"Science Fiction", "Cyberpunk"}, []string{categories[0].Name, categories[1].Name}) + }) +} + +func (s *CategoryRepositoryTestSuite) TestListByParentID() { + s.Run("should return top-level categories when parent ID is nil", func() { + // Arrange + s.createCategory("Root 1", nil) + s.createCategory("Root 2", nil) + child := s.createCategory("Child 1", &[]uint{1}[0]) // Create a child to ensure it's not returned + + // Act + categories, err := s.CategoryRepo.ListByParentID(context.Background(), nil) + + // Assert + s.Require().NoError(err) + s.Len(categories, 2) + s.NotContains(categories, child) + }) + + s.Run("should return child categories for a given parent ID", func() { + // Arrange + parent := s.createCategory("Parent", nil) + s.createCategory("Sub-Child 1", &parent.ID) + s.createCategory("Sub-Child 2", &parent.ID) + s.createCategory("Another Parent", nil) + + // Act + categories, err := s.CategoryRepo.ListByParentID(context.Background(), &parent.ID) + + // Assert + s.Require().NoError(err) + s.Len(categories, 2) + for _, cat := range categories { + s.Equal(parent.ID, *cat.ParentID) + } + }) +} diff --git a/internal/testutil/integration_test_utils.go b/internal/testutil/integration_test_utils.go index cd7c38e..172f23b 100644 --- a/internal/testutil/integration_test_utils.go +++ b/internal/testutil/integration_test_utils.go @@ -171,6 +171,7 @@ func (s *IntegrationTestSuite) setupInMemoryDB(config *TestConfig) { &domain.TextMetadata{}, &domain.PoeticAnalysis{}, &domain.TranslationField{}, + &TestEntity{}, // Add TestEntity for generic repository tests ); err != nil { s.T().Fatalf("Failed to run migrations: %v", err) } diff --git a/internal/testutil/test_entity.go b/internal/testutil/test_entity.go new file mode 100644 index 0000000..035228f --- /dev/null +++ b/internal/testutil/test_entity.go @@ -0,0 +1,10 @@ +package testutil + +import "gorm.io/gorm" + +// TestEntity is a simple struct used for testing the generic BaseRepository. +// It is not used in the main application. +type TestEntity struct { + gorm.Model + Name string +}