package repositories_test import ( "context" "errors" "tercul/internal/models" repositories2 "tercul/internal/repositories" "testing" "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "gorm.io/gorm" "tercul/internal/testutil" ) // TestModel is a simple entity used for cached repository tests type TestModel struct { models.BaseModel Name string Description string } // MockCache is a mock implementation of the Cache interface type MockCache struct { mock.Mock } func (m *MockCache) Get(ctx context.Context, key string, value interface{}) error { args := m.Called(ctx, key, value) return args.Error(0) } func (m *MockCache) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { args := m.Called(ctx, key, value, expiration) return args.Error(0) } func (m *MockCache) Delete(ctx context.Context, key string) error { args := m.Called(ctx, key) return args.Error(0) } func (m *MockCache) Clear(ctx context.Context) error { args := m.Called(ctx) return args.Error(0) } func (m *MockCache) GetMulti(ctx context.Context, keys []string) (map[string][]byte, error) { args := m.Called(ctx, keys) return args.Get(0).(map[string][]byte), args.Error(1) } func (m *MockCache) SetMulti(ctx context.Context, items map[string]interface{}, expiration time.Duration) error { args := m.Called(ctx, items, expiration) return args.Error(0) } // MockRepository is a mock implementation of the BaseRepository interface type MockRepository[T any] struct { mock.Mock } func (m *MockRepository[T]) Create(ctx context.Context, entity *T) error { args := m.Called(ctx, entity) return args.Error(0) } func (m *MockRepository[T]) CreateInTx(ctx context.Context, tx *gorm.DB, entity *T) error { return nil } func (m *MockRepository[T]) GetByID(ctx context.Context, id uint) (*T, error) { args := m.Called(ctx, id) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*T), args.Error(1) } func (m *MockRepository[T]) GetByIDWithOptions(ctx context.Context, id uint, options *repositories2.QueryOptions) (*T, error) { return nil, nil } func (m *MockRepository[T]) Update(ctx context.Context, entity *T) error { args := m.Called(ctx, entity) return args.Error(0) } func (m *MockRepository[T]) UpdateInTx(ctx context.Context, tx *gorm.DB, entity *T) error { return nil } func (m *MockRepository[T]) Delete(ctx context.Context, id uint) error { args := m.Called(ctx, id) return args.Error(0) } func (m *MockRepository[T]) DeleteInTx(ctx context.Context, tx *gorm.DB, id uint) error { return nil } func (m *MockRepository[T]) List(ctx context.Context, page, pageSize int) (*repositories2.PaginatedResult[T], error) { args := m.Called(ctx, page, pageSize) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*repositories2.PaginatedResult[T]), args.Error(1) } func (m *MockRepository[T]) ListWithOptions(ctx context.Context, options *repositories2.QueryOptions) ([]T, error) { var z []T return z, nil } func (m *MockRepository[T]) ListAll(ctx context.Context) ([]T, error) { args := m.Called(ctx) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).([]T), args.Error(1) } func (m *MockRepository[T]) GetAllForSync(ctx context.Context, batchSize, offset int) ([]T, error) { args := m.Called(ctx, batchSize, offset) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).([]T), args.Error(1) } func (m *MockRepository[T]) Count(ctx context.Context) (int64, error) { args := m.Called(ctx) return args.Get(0).(int64), args.Error(1) } func (m *MockRepository[T]) CountWithOptions(ctx context.Context, options *repositories2.QueryOptions) (int64, error) { return 0, nil } func (m *MockRepository[T]) FindWithPreload(ctx context.Context, preloads []string, id uint) (*T, error) { args := m.Called(ctx, preloads, id) if args.Get(0) == nil { return nil, args.Error(1) } return args.Get(0).(*T), args.Error(1) } func (m *MockRepository[T]) Exists(ctx context.Context, id uint) (bool, error) { return false, nil } func (m *MockRepository[T]) BeginTx(ctx context.Context) (*gorm.DB, error) { return nil, nil } func (m *MockRepository[T]) WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error { return nil } // CachedRepositorySuite is a test suite for the CachedRepository type CachedRepositorySuite struct { testutil.BaseSuite mockRepo *MockRepository[TestModel] mockCache *MockCache repo *repositories2.CachedRepository[TestModel] } // SetupTest sets up each test func (s *CachedRepositorySuite) SetupTest() { s.mockRepo = new(MockRepository[TestModel]) s.mockCache = new(MockCache) s.repo = repositories2.NewCachedRepository[TestModel]( s.mockRepo, s.mockCache, nil, "test_model", 1*time.Hour, ) } // TestGetByID tests the GetByID method with cache hit func (s *CachedRepositorySuite) TestGetByIDCacheHit() { // Setup id := uint(1) expectedModel := &TestModel{ BaseModel: models.BaseModel{ ID: id, }, Name: "Test Model", Description: "This is a test model", } // Mock cache hit s.mockCache.On("Get", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { // Set the value to simulate cache hit value := args.Get(2).(*TestModel) *value = *expectedModel }). Return(nil) // Execute ctx := context.Background() result, err := s.repo.GetByID(ctx, id) // Assert s.Require().NoError(err) s.Require().NotNil(result) s.Equal(expectedModel.ID, result.ID) s.Equal(expectedModel.Name, result.Name) s.Equal(expectedModel.Description, result.Description) // Verify mocks s.mockCache.AssertCalled(s.T(), "Get", mock.Anything, mock.Anything, mock.Anything) s.mockRepo.AssertNotCalled(s.T(), "GetByID", mock.Anything, mock.Anything) } // TestGetByID tests the GetByID method with cache miss func (s *CachedRepositorySuite) TestGetByIDCacheMiss() { // Setup id := uint(1) expectedModel := &TestModel{ BaseModel: models.BaseModel{ ID: id, }, Name: "Test Model", Description: "This is a test model", } // Mock cache miss s.mockCache.On("Get", mock.Anything, mock.Anything, mock.Anything). Return(errors.New("cache miss")) // Mock repository s.mockRepo.On("GetByID", mock.Anything, id). Return(expectedModel, nil) // Mock cache set s.mockCache.On("Set", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil) // Execute ctx := context.Background() result, err := s.repo.GetByID(ctx, id) // Assert s.Require().NoError(err) s.Require().NotNil(result) s.Equal(expectedModel.ID, result.ID) s.Equal(expectedModel.Name, result.Name) s.Equal(expectedModel.Description, result.Description) // Verify mocks s.mockCache.AssertCalled(s.T(), "Get", mock.Anything, mock.Anything, mock.Anything) s.mockRepo.AssertCalled(s.T(), "GetByID", mock.Anything, id) s.mockCache.AssertCalled(s.T(), "Set", mock.Anything, mock.Anything, mock.Anything, mock.Anything) } // TestCreate tests the Create method func (s *CachedRepositorySuite) TestCreate() { // Setup model := &TestModel{ Name: "Test Model", Description: "This is a test model", } // Mock repository s.mockRepo.On("Create", mock.Anything, model). Return(nil) // Execute ctx := context.Background() err := s.repo.Create(ctx, model) // Assert s.Require().NoError(err) // Verify mocks s.mockRepo.AssertCalled(s.T(), "Create", mock.Anything, model) } // TestUpdate tests the Update method func (s *CachedRepositorySuite) TestUpdate() { // Setup model := &TestModel{ BaseModel: models.BaseModel{ ID: 1, }, Name: "Test Model", Description: "This is a test model", } // Mock repository s.mockRepo.On("Update", mock.Anything, model). Return(nil) // Execute ctx := context.Background() // Expect cache delete during update invalidation s.mockCache.On("Delete", mock.Anything, mock.Anything).Return(nil) err := s.repo.Update(ctx, model) // Assert s.Require().NoError(err) // Verify mocks s.mockRepo.AssertCalled(s.T(), "Update", mock.Anything, model) } // TestDelete tests the Delete method func (s *CachedRepositorySuite) TestDelete() { // Setup id := uint(1) // Mock repository and cache delete s.mockRepo.On("Delete", mock.Anything, id).Return(nil) s.mockCache.On("Delete", mock.Anything, mock.Anything).Return(nil) // Execute ctx := context.Background() err := s.repo.Delete(ctx, id) // Assert s.Require().NoError(err) // Verify mocks s.mockRepo.AssertCalled(s.T(), "Delete", mock.Anything, id) } // TestList tests the List method with cache hit func (s *CachedRepositorySuite) TestListCacheHit() { // Setup page := 1 pageSize := 10 expectedResult := &repositories2.PaginatedResult[TestModel]{ Items: []TestModel{ { BaseModel: models.BaseModel{ ID: 1, }, Name: "Test Model 1", Description: "This is test model 1", }, { BaseModel: models.BaseModel{ ID: 2, }, Name: "Test Model 2", Description: "This is test model 2", }, }, TotalCount: 2, Page: page, PageSize: pageSize, TotalPages: 1, } // Mock cache hit s.mockCache.On("Get", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { // Set the value to simulate cache hit value := args.Get(2).(*repositories2.PaginatedResult[TestModel]) *value = *expectedResult }). Return(nil) // Execute ctx := context.Background() result, err := s.repo.List(ctx, page, pageSize) // Assert s.Require().NoError(err) s.Require().NotNil(result) s.Equal(expectedResult.TotalCount, result.TotalCount) s.Equal(expectedResult.Page, result.Page) s.Equal(expectedResult.PageSize, result.PageSize) s.Equal(expectedResult.TotalPages, result.TotalPages) s.Equal(len(expectedResult.Items), len(result.Items)) // Verify mocks s.mockCache.AssertCalled(s.T(), "Get", mock.Anything, mock.Anything, mock.Anything) s.mockRepo.AssertNotCalled(s.T(), "List", mock.Anything, mock.Anything, mock.Anything) } // TestList tests the List method with cache miss func (s *CachedRepositorySuite) TestListCacheMiss() { // Setup page := 1 pageSize := 10 expectedResult := &repositories2.PaginatedResult[TestModel]{ Items: []TestModel{ { BaseModel: models.BaseModel{ ID: 1, }, Name: "Test Model 1", Description: "This is test model 1", }, { BaseModel: models.BaseModel{ ID: 2, }, Name: "Test Model 2", Description: "This is test model 2", }, }, TotalCount: 2, Page: page, PageSize: pageSize, TotalPages: 1, } // Mock cache miss s.mockCache.On("Get", mock.Anything, mock.Anything, mock.Anything). Return(errors.New("cache miss")) // Mock repository s.mockRepo.On("List", mock.Anything, page, pageSize). Return(expectedResult, nil) // Mock cache set s.mockCache.On("Set", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil) // Execute ctx := context.Background() result, err := s.repo.List(ctx, page, pageSize) // Assert s.Require().NoError(err) s.Require().NotNil(result) s.Equal(expectedResult.TotalCount, result.TotalCount) s.Equal(expectedResult.Page, result.Page) s.Equal(expectedResult.PageSize, result.PageSize) s.Equal(expectedResult.TotalPages, result.TotalPages) s.Equal(len(expectedResult.Items), len(result.Items)) // Verify mocks s.mockCache.AssertCalled(s.T(), "Get", mock.Anything, mock.Anything, mock.Anything) s.mockRepo.AssertCalled(s.T(), "List", mock.Anything, page, pageSize) s.mockCache.AssertCalled(s.T(), "Set", mock.Anything, mock.Anything, mock.Anything, mock.Anything) } // TestCachedRepositorySuite runs the test suite func TestCachedRepositorySuite(t *testing.T) { suite.Run(t, new(CachedRepositorySuite)) }