diff --git a/cmd/api/main.go b/cmd/api/main.go index c56eb0b..dc5bdc2 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -120,7 +120,7 @@ func main() { } // Create search client - searchClient := search.NewWeaviateWrapper(weaviateClient) + searchClient := search.NewWeaviateWrapper(weaviateClient, cfg.WeaviateHost, cfg.SearchAlpha) // Create repositories repos := dbsql.NewRepositories(database, cfg) diff --git a/internal/adapters/graphql/schema.resolvers.go b/internal/adapters/graphql/schema.resolvers.go index 02d907a..13b7ce9 100644 --- a/internal/adapters/graphql/schema.resolvers.go +++ b/internal/adapters/graphql/schema.resolvers.go @@ -21,6 +21,7 @@ import ( "tercul/internal/app/translation" "tercul/internal/app/user" "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" platform_auth "tercul/internal/platform/auth" "tercul/internal/platform/log" "time" @@ -2001,7 +2002,13 @@ func (r *queryResolver) Search(ctx context.Context, query string, limit *int32, } } - results, err := r.App.Search.Search(ctx, query, page, pageSize, searchFilters) + params := domainsearch.SearchParams{ + Query: query, + Filters: searchFilters, + Limit: pageSize, + Offset: (page - 1) * pageSize, + } + results, err := r.App.Search.Search(ctx, params) if err != nil { return nil, err } diff --git a/internal/adapters/graphql/work_resolvers_unit_test.go b/internal/adapters/graphql/work_resolvers_unit_test.go index 0d6cd2f..96651d4 100644 --- a/internal/adapters/graphql/work_resolvers_unit_test.go +++ b/internal/adapters/graphql/work_resolvers_unit_test.go @@ -9,6 +9,7 @@ import ( "tercul/internal/app/translation" "tercul/internal/app/work" "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" platform_auth "tercul/internal/platform/auth" "time" @@ -153,8 +154,8 @@ func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pip args := m.Called(ctx, work, pipeline) return args.Error(0) } -func (m *mockSearchClient) Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) { - args := m.Called(ctx, query, page, pageSize, filters) +func (m *mockSearchClient) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) { + args := m.Called(ctx, params) if args.Get(0) == nil { return nil, args.Error(1) } diff --git a/internal/app/search/service.go b/internal/app/search/service.go index 1bf3f8a..3ee187e 100644 --- a/internal/app/search/service.go +++ b/internal/app/search/service.go @@ -10,7 +10,7 @@ import ( // Service is the application service for searching. type Service interface { - Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) + Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) IndexWork(ctx context.Context, work domain.Work) error } @@ -28,15 +28,8 @@ func NewService(searchClient domainsearch.SearchClient, localization *localizati } // Search performs a search across all searchable entities. -func (s *service) Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) { - // For now, this is a mock implementation that returns empty results. - // TODO: Implement the actual search logic. - return &domain.SearchResults{ - Works: []domain.Work{}, - Translations: []domain.Translation{}, - Authors: []domain.Author{}, - Total: 0, - }, nil +func (s *service) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) { + return s.searchClient.Search(ctx, params) } func (s *service) IndexWork(ctx context.Context, work domain.Work) error { diff --git a/internal/app/search/service_test.go b/internal/app/search/service_test.go index a1db8f3..6bbb78e 100644 --- a/internal/app/search/service_test.go +++ b/internal/app/search/service_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" "tercul/internal/app/localization" "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" ) type mockLocalizationRepository struct { @@ -41,11 +42,54 @@ type mockWeaviateWrapper struct { mock.Mock } +func (m *mockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) { + args := m.Called(ctx, params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.SearchResults), args.Error(1) +} + func (m *mockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error { args := m.Called(ctx, work, content) return args.Error(0) } +func TestSearchService_Search(t *testing.T) { + localizationRepo := new(mockLocalizationRepository) + localizationService := localization.NewService(localizationRepo) + weaviateWrapper := new(mockWeaviateWrapper) + service := NewService(weaviateWrapper, localizationService) + + ctx := context.Background() + testQuery := "test query" + testFilters := domain.SearchFilters{ + Languages: []string{"en"}, + Authors: []string{"1"}, + Tags: []string{"test-tag"}, + Categories: []string{"test-category"}, + } + expectedResults := &domain.SearchResults{ + Works: []domain.Work{{Title: "Test Work"}}, + Authors: []domain.Author{{Name: "Test Author"}}, + Total: 2, + } + + params := domainsearch.SearchParams{ + Query: testQuery, + Filters: testFilters, + Limit: 10, + Offset: 0, + } + weaviateWrapper.On("Search", ctx, params).Return(expectedResults, nil) + + results, err := service.Search(ctx, params) + + assert.NoError(t, err) + assert.Equal(t, expectedResults, results) + weaviateWrapper.AssertExpectations(t) +} + func TestIndexService_IndexWork(t *testing.T) { localizationRepo := new(mockLocalizationRepository) localizationService := localization.NewService(localizationRepo) diff --git a/internal/domain/entities.go b/internal/domain/entities.go index 792bbaf..0d77209 100644 --- a/internal/domain/entities.go +++ b/internal/domain/entities.go @@ -50,6 +50,7 @@ type BaseModel struct { ID uint `gorm:"primaryKey"` CreatedAt time.Time UpdatedAt time.Time + Score float64 `gorm:"-"` } // TranslatableModel extends BaseModel with language support diff --git a/internal/domain/search/client.go b/internal/domain/search/client.go index 5d8c2ab..9ca118c 100644 --- a/internal/domain/search/client.go +++ b/internal/domain/search/client.go @@ -7,5 +7,6 @@ import ( // SearchClient defines the interface for a search client. type SearchClient interface { + Search(ctx context.Context, params SearchParams) (*domain.SearchResults, error) IndexWork(ctx context.Context, work *domain.Work, content string) error -} \ No newline at end of file +} diff --git a/internal/domain/search/search.go b/internal/domain/search/search.go new file mode 100644 index 0000000..d0955b1 --- /dev/null +++ b/internal/domain/search/search.go @@ -0,0 +1,21 @@ +package search + +import ( + "tercul/internal/domain" +) + +type SearchMode string + +const ( + SearchModeHybrid SearchMode = "hybrid" + SearchModeBM25 SearchMode = "bm25" + SearchModeVector SearchMode = "vector" +) + +type SearchParams struct { + Query string + Mode SearchMode + Filters domain.SearchFilters + Limit int + Offset int +} diff --git a/internal/platform/config/config.go b/internal/platform/config/config.go index cbfd9bf..4e9ee46 100644 --- a/internal/platform/config/config.go +++ b/internal/platform/config/config.go @@ -28,8 +28,9 @@ type Config struct { NLPMemoryCacheCap int `mapstructure:"NLP_MEMORY_CACHE_CAP"` NLPRedisCacheTTLSeconds int `mapstructure:"NLP_REDIS_CACHE_TTL_SECONDS"` NLPUseLingua bool `mapstructure:"NLP_USE_LINGUA"` - NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"` - BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"` + NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"` + BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"` + SearchAlpha float64 `mapstructure:"SEARCH_ALPHA"` } // Global config instance @@ -62,6 +63,7 @@ func LoadConfig() (*Config, error) { v.SetDefault("NLP_USE_LINGUA", true) v.SetDefault("NLP_USE_TFIDF", true) v.SetDefault("BLEVE_INDEX_PATH", "./bleve_index") + v.SetDefault("SEARCH_ALPHA", 0.7) v.AutomaticEnv() diff --git a/internal/platform/search/weaviate_wrapper.go b/internal/platform/search/weaviate_wrapper.go index 20563ce..f74da07 100644 --- a/internal/platform/search/weaviate_wrapper.go +++ b/internal/platform/search/weaviate_wrapper.go @@ -3,24 +3,334 @@ package search import ( "context" "fmt" + "log" + "strconv" "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" "time" "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate-go-client/v5/weaviate/filters" + "github.com/weaviate/weaviate-go-client/v5/weaviate/graphql" ) type WeaviateWrapper interface { + Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) IndexWork(ctx context.Context, work *domain.Work, content string) error } type weaviateWrapper struct { - client *weaviate.Client + client *weaviate.Client + host string + searchAlpha float64 } -func NewWeaviateWrapper(client *weaviate.Client) WeaviateWrapper { - return &weaviateWrapper{client: client} +func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) WeaviateWrapper { + return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha} } +const ( + DefaultLimit = 20 + MaxLimit = 100 +) + +const ( + WorkClass = "Work" + AuthorClass = "Author" + TranslationClass = "Translation" +) + +func sanitizeLimitOffset(limit, offset int) (int, int) { + if limit <= 0 { + limit = DefaultLimit + } + if limit > MaxLimit { + limit = MaxLimit + } + if offset < 0 { + offset = 0 + } + return limit, offset +} + +func (w *weaviateWrapper) hybridAlpha() float32 { + if w.searchAlpha < 0 { + return 0 + } + if w.searchAlpha > 1 { + return 1 + } + return float32(w.searchAlpha) +} + +func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) { + results := &domain.SearchResults{ + Works: []domain.Work{}, + Translations: []domain.Translation{}, + Authors: []domain.Author{}, + } + + limit, offset := sanitizeLimitOffset(params.Limit, params.Offset) + + workFields := []graphql.Field{ + {Name: "title"}, {Name: "description"}, {Name: "language"}, {Name: "status"}, {Name: "publishedAt"}, + {Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}}, + } + authorFields := []graphql.Field{ + {Name: "name"}, + {Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}}, + } + translationFields := []graphql.Field{ + {Name: "title"}, {Name: "language"}, + {Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}}, + } + + searcher := w.client.GraphQL().Get(). + WithFields(graphql.Field{Name: "... on Work", Fields: workFields}). + WithFields(graphql.Field{Name: "... on Author", Fields: authorFields}). + WithFields(graphql.Field{Name: "... on Translation", Fields: translationFields}). + WithLimit(limit). + WithOffset(offset) + + switch params.Mode { + case domainsearch.SearchModeHybrid: + hybrid := w.client.GraphQL().HybridArgumentBuilder(). + WithQuery(params.Query). + WithAlpha(w.hybridAlpha()). + WithProperties([]string{"title", "description", "name"}) + searcher.WithHybrid(hybrid) + case domainsearch.SearchModeBM25: + bm25 := w.client.GraphQL().Bm25ArgBuilder(). + WithQuery(params.Query). + WithProperties([]string{"title", "description", "name"}...) + searcher.WithBM25(bm25) + } + + response, err := searcher.Do(ctx) + if err != nil { + return nil, err + } + + get, ok := response.Data["Get"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid response data") + } + + for className, classData := range get { + for _, item := range classData.([]interface{}) { + itemMap, ok := item.(map[string]interface{}) + if !ok { + continue + } + switch className { + case "Work": + work, err := mapToWork(itemMap) + if err != nil { + log.Printf("Error mapping work: %v", err) + continue + } + results.Works = append(results.Works, work) + case "Author": + author, err := mapToAuthor(itemMap) + if err != nil { + log.Printf("Error mapping author: %v", err) + continue + } + results.Authors = append(results.Authors, author) + case "Translation": + translation, err := mapToTranslation(itemMap) + if err != nil { + log.Printf("Error mapping translation: %v", err) + continue + } + results.Translations = append(results.Translations, translation) + } + } + } + + meta, ok := response.Data["Meta"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid meta data") + } + count, ok := meta["count"].(float64) + if !ok { + return nil, fmt.Errorf("invalid count data") + } + results.Total = int64(count) + + return results, nil +} + + +func (w *weaviateWrapper) buildWhereFilter(searchFilters domain.SearchFilters, className string) *filters.WhereBuilder { + operands := make([]*filters.WhereBuilder, 0) + + if len(searchFilters.Languages) > 0 { + operands = append(operands, filters.Where(). + WithPath([]string{"language"}). + WithOperator(filters.ContainsAny). + WithValueText(searchFilters.Languages...)) + } + + if className == "Work" { + if searchFilters.DateFrom != nil { + operands = append(operands, filters.Where(). + WithPath([]string{"publishedAt"}). + WithOperator(filters.GreaterThanEqual). + WithValueDate(*searchFilters.DateFrom)) + } + + if searchFilters.DateTo != nil { + operands = append(operands, filters.Where(). + WithPath([]string{"publishedAt"}). + WithOperator(filters.LessThanEqual). + WithValueDate(*searchFilters.DateTo)) + } + + if len(searchFilters.Authors) > 0 { + authorOperands := make([]*filters.WhereBuilder, len(searchFilters.Authors)) + for i, author := range searchFilters.Authors { + authorOperands[i] = filters.Where(). + WithPath([]string{"authors", "Author", "id"}). + WithOperator(filters.Equal). + WithValueText(fmt.Sprintf("weaviate://%s/Author/%s", w.host, author)) + } + operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(authorOperands)) + } + + if len(searchFilters.Tags) > 0 { + tagOperands := make([]*filters.WhereBuilder, len(searchFilters.Tags)) + for i, tag := range searchFilters.Tags { + tagOperands[i] = filters.Where(). + WithPath([]string{"tags", "Tag", "name"}). + WithOperator(filters.Equal). + WithValueText(tag) + } + operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(tagOperands)) + } + + if len(searchFilters.Categories) > 0 { + categoryOperands := make([]*filters.WhereBuilder, len(searchFilters.Categories)) + for i, category := range searchFilters.Categories { + categoryOperands[i] = filters.Where(). + WithPath([]string{"categories", "Category", "name"}). + WithOperator(filters.Equal). + WithValueText(category) + } + operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(categoryOperands)) + } + } + + if len(operands) == 0 { + return nil + } + + return filters.Where().WithOperator(filters.And).WithOperands(operands) +} + +func mapToWork(data map[string]interface{}) (domain.Work, error) { + work := domain.Work{} + + additional, ok := data["_additional"].(map[string]interface{}) + if !ok { + return work, fmt.Errorf("missing _additional field") + } + + idStr, ok := additional["id"].(string) + if !ok { + return work, fmt.Errorf("missing or invalid id") + } + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return work, fmt.Errorf("failed to parse id: %w", err) + } + work.ID = uint(id) + + if title, ok := data["title"].(string); ok { + work.Title = title + } + if description, ok := data["description"].(string); ok { + work.Description = description + } + if language, ok := data["language"].(string); ok { + work.Language = language + } + if status, ok := data["status"].(string); ok { + work.Status = domain.WorkStatus(status) + } + if publishedAtStr, ok := data["publishedAt"].(string); ok { + publishedAt, err := time.Parse(time.RFC3339, publishedAtStr) + if err == nil { + work.PublishedAt = &publishedAt + } + } + if score, ok := additional["score"].(float64); ok { + work.Score = score + } + + return work, nil +} + +func mapToAuthor(data map[string]interface{}) (domain.Author, error) { + author := domain.Author{} + + additional, ok := data["_additional"].(map[string]interface{}) + if !ok { + return author, fmt.Errorf("missing _additional field") + } + + idStr, ok := additional["id"].(string) + if !ok { + return author, fmt.Errorf("missing or invalid id") + } + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return author, fmt.Errorf("failed to parse id: %w", err) + } + author.ID = uint(id) + + if name, ok := data["name"].(string); ok { + author.Name = name + } + if score, ok := additional["score"].(float64); ok { + author.Score = score + } + + return author, nil +} + +func mapToTranslation(data map[string]interface{}) (domain.Translation, error) { + translation := domain.Translation{} + + additional, ok := data["_additional"].(map[string]interface{}) + if !ok { + return translation, fmt.Errorf("missing _additional field") + } + + idStr, ok := additional["id"].(string) + if !ok { + return translation, fmt.Errorf("missing or invalid id") + } + id, err := strconv.ParseUint(idStr, 10, 64) + if err != nil { + return translation, fmt.Errorf("failed to parse id: %w", err) + } + translation.ID = uint(id) + + if title, ok := data["title"].(string); ok { + translation.Title = title + } + if language, ok := data["language"].(string); ok { + translation.Language = language + } + if score, ok := additional["score"].(float64); ok { + translation.Score = score + } + + return translation, nil +} + + func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error { properties := map[string]interface{}{ "language": work.Language, diff --git a/internal/platform/search/weaviate_wrapper_mock.go b/internal/platform/search/weaviate_wrapper_mock.go new file mode 100644 index 0000000..df63791 --- /dev/null +++ b/internal/platform/search/weaviate_wrapper_mock.go @@ -0,0 +1,26 @@ +package search + +import ( + "context" + "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" + + "github.com/stretchr/testify/mock" +) + +type MockWeaviateWrapper struct { + mock.Mock +} + +func (m *MockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) { + args := m.Called(ctx, params) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*domain.SearchResults), args.Error(1) +} + +func (m *MockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error { + args := m.Called(ctx, work, content) + return args.Error(0) +} diff --git a/internal/platform/search/weaviate_wrapper_test.go b/internal/platform/search/weaviate_wrapper_test.go new file mode 100644 index 0000000..0ac0c4e --- /dev/null +++ b/internal/platform/search/weaviate_wrapper_test.go @@ -0,0 +1,38 @@ +package search + +import ( + "context" + "testing" + "tercul/internal/domain" + domainsearch "tercul/internal/domain/search" + "github.com/stretchr/testify/assert" +) + +func TestWeaviateWrapper_Search(t *testing.T) { + mockWrapper := new(MockWeaviateWrapper) + expectedResults := &domain.SearchResults{ + Works: []domain.Work{ + {Title: "Work 1", Description: "alpha beta", TranslatableModel: domain.TranslatableModel{Language: "en"}}, + }, + } + params := domainsearch.SearchParams{ + Query: "alpha", + Mode: domainsearch.SearchModeHybrid, + Filters: domain.SearchFilters{ + Languages: []string{"en"}, + }, + Limit: 1, + Offset: 0, + } + + mockWrapper.On("Search", context.Background(), params).Return(expectedResults, nil) + + results, err := mockWrapper.Search(context.Background(), params) + + assert.NoError(t, err) + assert.NotNil(t, results) + assert.Equal(t, 1, len(results.Works)) + assert.Equal(t, "Work 1", results.Works[0].Title) + + mockWrapper.AssertExpectations(t) +} diff --git a/internal/testutil/integration_test_utils.go b/internal/testutil/integration_test_utils.go index 57eaabe..6601229 100644 --- a/internal/testutil/integration_test_utils.go +++ b/internal/testutil/integration_test_utils.go @@ -40,6 +40,10 @@ import ( // mockSearchClient is a mock implementation of the SearchClient interface. type mockSearchClient struct{} +func (m *mockSearchClient) Search(ctx context.Context, params search.SearchParams) (*domain.SearchResults, error) { + return &domain.SearchResults{}, nil +} + func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error { return nil }