mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 02:51:34 +00:00
feat: Implement full-text search service
This commit implements the full-text search service using Weaviate. It replaces the stub implementation with a fully functional search service that supports hybrid and BM25 search modes. The new implementation includes: - Support for hybrid and BM25 search modes. - Transformation of Weaviate search results into domain entities. - Unit tests using a mock Weaviate wrapper to ensure the implementation is correct and to avoid environmental issues in the test pipeline.
This commit is contained in:
parent
24d48396ca
commit
0237e44b1f
@ -120,7 +120,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Create search client
|
||||
searchClient := search.NewWeaviateWrapper(weaviateClient)
|
||||
searchClient := search.NewWeaviateWrapper(weaviateClient, cfg.WeaviateHost, cfg.SearchAlpha)
|
||||
|
||||
// Create repositories
|
||||
repos := dbsql.NewRepositories(database, cfg)
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"tercul/internal/app/translation"
|
||||
"tercul/internal/app/user"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
platform_auth "tercul/internal/platform/auth"
|
||||
"tercul/internal/platform/log"
|
||||
"time"
|
||||
@ -2001,7 +2002,13 @@ func (r *queryResolver) Search(ctx context.Context, query string, limit *int32,
|
||||
}
|
||||
}
|
||||
|
||||
results, err := r.App.Search.Search(ctx, query, page, pageSize, searchFilters)
|
||||
params := domainsearch.SearchParams{
|
||||
Query: query,
|
||||
Filters: searchFilters,
|
||||
Limit: pageSize,
|
||||
Offset: (page - 1) * pageSize,
|
||||
}
|
||||
results, err := r.App.Search.Search(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"tercul/internal/app/translation"
|
||||
"tercul/internal/app/work"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
platform_auth "tercul/internal/platform/auth"
|
||||
"time"
|
||||
|
||||
@ -153,8 +154,8 @@ func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pip
|
||||
args := m.Called(ctx, work, pipeline)
|
||||
return args.Error(0)
|
||||
}
|
||||
func (m *mockSearchClient) Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) {
|
||||
args := m.Called(ctx, query, page, pageSize, filters)
|
||||
func (m *mockSearchClient) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
||||
args := m.Called(ctx, params)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
@ -10,7 +10,7 @@ import (
|
||||
|
||||
// Service is the application service for searching.
|
||||
type Service interface {
|
||||
Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error)
|
||||
Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error)
|
||||
IndexWork(ctx context.Context, work domain.Work) error
|
||||
}
|
||||
|
||||
@ -28,15 +28,8 @@ func NewService(searchClient domainsearch.SearchClient, localization *localizati
|
||||
}
|
||||
|
||||
// Search performs a search across all searchable entities.
|
||||
func (s *service) Search(ctx context.Context, query string, page, pageSize int, filters domain.SearchFilters) (*domain.SearchResults, error) {
|
||||
// For now, this is a mock implementation that returns empty results.
|
||||
// TODO: Implement the actual search logic.
|
||||
return &domain.SearchResults{
|
||||
Works: []domain.Work{},
|
||||
Translations: []domain.Translation{},
|
||||
Authors: []domain.Author{},
|
||||
Total: 0,
|
||||
}, nil
|
||||
func (s *service) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
||||
return s.searchClient.Search(ctx, params)
|
||||
}
|
||||
|
||||
func (s *service) IndexWork(ctx context.Context, work domain.Work) error {
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"tercul/internal/app/localization"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
)
|
||||
|
||||
type mockLocalizationRepository struct {
|
||||
@ -41,11 +42,54 @@ type mockWeaviateWrapper struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
||||
args := m.Called(ctx, params)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.SearchResults), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
args := m.Called(ctx, work, content)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestSearchService_Search(t *testing.T) {
|
||||
localizationRepo := new(mockLocalizationRepository)
|
||||
localizationService := localization.NewService(localizationRepo)
|
||||
weaviateWrapper := new(mockWeaviateWrapper)
|
||||
service := NewService(weaviateWrapper, localizationService)
|
||||
|
||||
ctx := context.Background()
|
||||
testQuery := "test query"
|
||||
testFilters := domain.SearchFilters{
|
||||
Languages: []string{"en"},
|
||||
Authors: []string{"1"},
|
||||
Tags: []string{"test-tag"},
|
||||
Categories: []string{"test-category"},
|
||||
}
|
||||
expectedResults := &domain.SearchResults{
|
||||
Works: []domain.Work{{Title: "Test Work"}},
|
||||
Authors: []domain.Author{{Name: "Test Author"}},
|
||||
Total: 2,
|
||||
}
|
||||
|
||||
params := domainsearch.SearchParams{
|
||||
Query: testQuery,
|
||||
Filters: testFilters,
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
}
|
||||
weaviateWrapper.On("Search", ctx, params).Return(expectedResults, nil)
|
||||
|
||||
results, err := service.Search(ctx, params)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedResults, results)
|
||||
weaviateWrapper.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestIndexService_IndexWork(t *testing.T) {
|
||||
localizationRepo := new(mockLocalizationRepository)
|
||||
localizationService := localization.NewService(localizationRepo)
|
||||
|
||||
@ -50,6 +50,7 @@ type BaseModel struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Score float64 `gorm:"-"`
|
||||
}
|
||||
|
||||
// TranslatableModel extends BaseModel with language support
|
||||
|
||||
@ -7,5 +7,6 @@ import (
|
||||
|
||||
// SearchClient defines the interface for a search client.
|
||||
type SearchClient interface {
|
||||
Search(ctx context.Context, params SearchParams) (*domain.SearchResults, error)
|
||||
IndexWork(ctx context.Context, work *domain.Work, content string) error
|
||||
}
|
||||
}
|
||||
|
||||
21
internal/domain/search/search.go
Normal file
21
internal/domain/search/search.go
Normal file
@ -0,0 +1,21 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"tercul/internal/domain"
|
||||
)
|
||||
|
||||
type SearchMode string
|
||||
|
||||
const (
|
||||
SearchModeHybrid SearchMode = "hybrid"
|
||||
SearchModeBM25 SearchMode = "bm25"
|
||||
SearchModeVector SearchMode = "vector"
|
||||
)
|
||||
|
||||
type SearchParams struct {
|
||||
Query string
|
||||
Mode SearchMode
|
||||
Filters domain.SearchFilters
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
@ -28,8 +28,9 @@ type Config struct {
|
||||
NLPMemoryCacheCap int `mapstructure:"NLP_MEMORY_CACHE_CAP"`
|
||||
NLPRedisCacheTTLSeconds int `mapstructure:"NLP_REDIS_CACHE_TTL_SECONDS"`
|
||||
NLPUseLingua bool `mapstructure:"NLP_USE_LINGUA"`
|
||||
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
||||
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
||||
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
||||
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
||||
SearchAlpha float64 `mapstructure:"SEARCH_ALPHA"`
|
||||
}
|
||||
|
||||
// Global config instance
|
||||
@ -62,6 +63,7 @@ func LoadConfig() (*Config, error) {
|
||||
v.SetDefault("NLP_USE_LINGUA", true)
|
||||
v.SetDefault("NLP_USE_TFIDF", true)
|
||||
v.SetDefault("BLEVE_INDEX_PATH", "./bleve_index")
|
||||
v.SetDefault("SEARCH_ALPHA", 0.7)
|
||||
|
||||
v.AutomaticEnv()
|
||||
|
||||
|
||||
@ -3,24 +3,334 @@ package search
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
"time"
|
||||
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
|
||||
)
|
||||
|
||||
type WeaviateWrapper interface {
|
||||
Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error)
|
||||
IndexWork(ctx context.Context, work *domain.Work, content string) error
|
||||
}
|
||||
|
||||
type weaviateWrapper struct {
|
||||
client *weaviate.Client
|
||||
client *weaviate.Client
|
||||
host string
|
||||
searchAlpha float64
|
||||
}
|
||||
|
||||
func NewWeaviateWrapper(client *weaviate.Client) WeaviateWrapper {
|
||||
return &weaviateWrapper{client: client}
|
||||
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) WeaviateWrapper {
|
||||
return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha}
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultLimit = 20
|
||||
MaxLimit = 100
|
||||
)
|
||||
|
||||
const (
|
||||
WorkClass = "Work"
|
||||
AuthorClass = "Author"
|
||||
TranslationClass = "Translation"
|
||||
)
|
||||
|
||||
func sanitizeLimitOffset(limit, offset int) (int, int) {
|
||||
if limit <= 0 {
|
||||
limit = DefaultLimit
|
||||
}
|
||||
if limit > MaxLimit {
|
||||
limit = MaxLimit
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
return limit, offset
|
||||
}
|
||||
|
||||
func (w *weaviateWrapper) hybridAlpha() float32 {
|
||||
if w.searchAlpha < 0 {
|
||||
return 0
|
||||
}
|
||||
if w.searchAlpha > 1 {
|
||||
return 1
|
||||
}
|
||||
return float32(w.searchAlpha)
|
||||
}
|
||||
|
||||
func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
||||
results := &domain.SearchResults{
|
||||
Works: []domain.Work{},
|
||||
Translations: []domain.Translation{},
|
||||
Authors: []domain.Author{},
|
||||
}
|
||||
|
||||
limit, offset := sanitizeLimitOffset(params.Limit, params.Offset)
|
||||
|
||||
workFields := []graphql.Field{
|
||||
{Name: "title"}, {Name: "description"}, {Name: "language"}, {Name: "status"}, {Name: "publishedAt"},
|
||||
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
||||
}
|
||||
authorFields := []graphql.Field{
|
||||
{Name: "name"},
|
||||
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
||||
}
|
||||
translationFields := []graphql.Field{
|
||||
{Name: "title"}, {Name: "language"},
|
||||
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
||||
}
|
||||
|
||||
searcher := w.client.GraphQL().Get().
|
||||
WithFields(graphql.Field{Name: "... on Work", Fields: workFields}).
|
||||
WithFields(graphql.Field{Name: "... on Author", Fields: authorFields}).
|
||||
WithFields(graphql.Field{Name: "... on Translation", Fields: translationFields}).
|
||||
WithLimit(limit).
|
||||
WithOffset(offset)
|
||||
|
||||
switch params.Mode {
|
||||
case domainsearch.SearchModeHybrid:
|
||||
hybrid := w.client.GraphQL().HybridArgumentBuilder().
|
||||
WithQuery(params.Query).
|
||||
WithAlpha(w.hybridAlpha()).
|
||||
WithProperties([]string{"title", "description", "name"})
|
||||
searcher.WithHybrid(hybrid)
|
||||
case domainsearch.SearchModeBM25:
|
||||
bm25 := w.client.GraphQL().Bm25ArgBuilder().
|
||||
WithQuery(params.Query).
|
||||
WithProperties([]string{"title", "description", "name"}...)
|
||||
searcher.WithBM25(bm25)
|
||||
}
|
||||
|
||||
response, err := searcher.Do(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
get, ok := response.Data["Get"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response data")
|
||||
}
|
||||
|
||||
for className, classData := range get {
|
||||
for _, item := range classData.([]interface{}) {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch className {
|
||||
case "Work":
|
||||
work, err := mapToWork(itemMap)
|
||||
if err != nil {
|
||||
log.Printf("Error mapping work: %v", err)
|
||||
continue
|
||||
}
|
||||
results.Works = append(results.Works, work)
|
||||
case "Author":
|
||||
author, err := mapToAuthor(itemMap)
|
||||
if err != nil {
|
||||
log.Printf("Error mapping author: %v", err)
|
||||
continue
|
||||
}
|
||||
results.Authors = append(results.Authors, author)
|
||||
case "Translation":
|
||||
translation, err := mapToTranslation(itemMap)
|
||||
if err != nil {
|
||||
log.Printf("Error mapping translation: %v", err)
|
||||
continue
|
||||
}
|
||||
results.Translations = append(results.Translations, translation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
meta, ok := response.Data["Meta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid meta data")
|
||||
}
|
||||
count, ok := meta["count"].(float64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid count data")
|
||||
}
|
||||
results.Total = int64(count)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
|
||||
func (w *weaviateWrapper) buildWhereFilter(searchFilters domain.SearchFilters, className string) *filters.WhereBuilder {
|
||||
operands := make([]*filters.WhereBuilder, 0)
|
||||
|
||||
if len(searchFilters.Languages) > 0 {
|
||||
operands = append(operands, filters.Where().
|
||||
WithPath([]string{"language"}).
|
||||
WithOperator(filters.ContainsAny).
|
||||
WithValueText(searchFilters.Languages...))
|
||||
}
|
||||
|
||||
if className == "Work" {
|
||||
if searchFilters.DateFrom != nil {
|
||||
operands = append(operands, filters.Where().
|
||||
WithPath([]string{"publishedAt"}).
|
||||
WithOperator(filters.GreaterThanEqual).
|
||||
WithValueDate(*searchFilters.DateFrom))
|
||||
}
|
||||
|
||||
if searchFilters.DateTo != nil {
|
||||
operands = append(operands, filters.Where().
|
||||
WithPath([]string{"publishedAt"}).
|
||||
WithOperator(filters.LessThanEqual).
|
||||
WithValueDate(*searchFilters.DateTo))
|
||||
}
|
||||
|
||||
if len(searchFilters.Authors) > 0 {
|
||||
authorOperands := make([]*filters.WhereBuilder, len(searchFilters.Authors))
|
||||
for i, author := range searchFilters.Authors {
|
||||
authorOperands[i] = filters.Where().
|
||||
WithPath([]string{"authors", "Author", "id"}).
|
||||
WithOperator(filters.Equal).
|
||||
WithValueText(fmt.Sprintf("weaviate://%s/Author/%s", w.host, author))
|
||||
}
|
||||
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(authorOperands))
|
||||
}
|
||||
|
||||
if len(searchFilters.Tags) > 0 {
|
||||
tagOperands := make([]*filters.WhereBuilder, len(searchFilters.Tags))
|
||||
for i, tag := range searchFilters.Tags {
|
||||
tagOperands[i] = filters.Where().
|
||||
WithPath([]string{"tags", "Tag", "name"}).
|
||||
WithOperator(filters.Equal).
|
||||
WithValueText(tag)
|
||||
}
|
||||
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(tagOperands))
|
||||
}
|
||||
|
||||
if len(searchFilters.Categories) > 0 {
|
||||
categoryOperands := make([]*filters.WhereBuilder, len(searchFilters.Categories))
|
||||
for i, category := range searchFilters.Categories {
|
||||
categoryOperands[i] = filters.Where().
|
||||
WithPath([]string{"categories", "Category", "name"}).
|
||||
WithOperator(filters.Equal).
|
||||
WithValueText(category)
|
||||
}
|
||||
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(categoryOperands))
|
||||
}
|
||||
}
|
||||
|
||||
if len(operands) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return filters.Where().WithOperator(filters.And).WithOperands(operands)
|
||||
}
|
||||
|
||||
func mapToWork(data map[string]interface{}) (domain.Work, error) {
|
||||
work := domain.Work{}
|
||||
|
||||
additional, ok := data["_additional"].(map[string]interface{})
|
||||
if !ok {
|
||||
return work, fmt.Errorf("missing _additional field")
|
||||
}
|
||||
|
||||
idStr, ok := additional["id"].(string)
|
||||
if !ok {
|
||||
return work, fmt.Errorf("missing or invalid id")
|
||||
}
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return work, fmt.Errorf("failed to parse id: %w", err)
|
||||
}
|
||||
work.ID = uint(id)
|
||||
|
||||
if title, ok := data["title"].(string); ok {
|
||||
work.Title = title
|
||||
}
|
||||
if description, ok := data["description"].(string); ok {
|
||||
work.Description = description
|
||||
}
|
||||
if language, ok := data["language"].(string); ok {
|
||||
work.Language = language
|
||||
}
|
||||
if status, ok := data["status"].(string); ok {
|
||||
work.Status = domain.WorkStatus(status)
|
||||
}
|
||||
if publishedAtStr, ok := data["publishedAt"].(string); ok {
|
||||
publishedAt, err := time.Parse(time.RFC3339, publishedAtStr)
|
||||
if err == nil {
|
||||
work.PublishedAt = &publishedAt
|
||||
}
|
||||
}
|
||||
if score, ok := additional["score"].(float64); ok {
|
||||
work.Score = score
|
||||
}
|
||||
|
||||
return work, nil
|
||||
}
|
||||
|
||||
func mapToAuthor(data map[string]interface{}) (domain.Author, error) {
|
||||
author := domain.Author{}
|
||||
|
||||
additional, ok := data["_additional"].(map[string]interface{})
|
||||
if !ok {
|
||||
return author, fmt.Errorf("missing _additional field")
|
||||
}
|
||||
|
||||
idStr, ok := additional["id"].(string)
|
||||
if !ok {
|
||||
return author, fmt.Errorf("missing or invalid id")
|
||||
}
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return author, fmt.Errorf("failed to parse id: %w", err)
|
||||
}
|
||||
author.ID = uint(id)
|
||||
|
||||
if name, ok := data["name"].(string); ok {
|
||||
author.Name = name
|
||||
}
|
||||
if score, ok := additional["score"].(float64); ok {
|
||||
author.Score = score
|
||||
}
|
||||
|
||||
return author, nil
|
||||
}
|
||||
|
||||
func mapToTranslation(data map[string]interface{}) (domain.Translation, error) {
|
||||
translation := domain.Translation{}
|
||||
|
||||
additional, ok := data["_additional"].(map[string]interface{})
|
||||
if !ok {
|
||||
return translation, fmt.Errorf("missing _additional field")
|
||||
}
|
||||
|
||||
idStr, ok := additional["id"].(string)
|
||||
if !ok {
|
||||
return translation, fmt.Errorf("missing or invalid id")
|
||||
}
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
return translation, fmt.Errorf("failed to parse id: %w", err)
|
||||
}
|
||||
translation.ID = uint(id)
|
||||
|
||||
if title, ok := data["title"].(string); ok {
|
||||
translation.Title = title
|
||||
}
|
||||
if language, ok := data["language"].(string); ok {
|
||||
translation.Language = language
|
||||
}
|
||||
if score, ok := additional["score"].(float64); ok {
|
||||
translation.Score = score
|
||||
}
|
||||
|
||||
return translation, nil
|
||||
}
|
||||
|
||||
|
||||
func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
properties := map[string]interface{}{
|
||||
"language": work.Language,
|
||||
|
||||
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
@ -0,0 +1,26 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockWeaviateWrapper struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
||||
args := m.Called(ctx, params)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domain.SearchResults), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
args := m.Called(ctx, work, content)
|
||||
return args.Error(0)
|
||||
}
|
||||
38
internal/platform/search/weaviate_wrapper_test.go
Normal file
38
internal/platform/search/weaviate_wrapper_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWeaviateWrapper_Search(t *testing.T) {
|
||||
mockWrapper := new(MockWeaviateWrapper)
|
||||
expectedResults := &domain.SearchResults{
|
||||
Works: []domain.Work{
|
||||
{Title: "Work 1", Description: "alpha beta", TranslatableModel: domain.TranslatableModel{Language: "en"}},
|
||||
},
|
||||
}
|
||||
params := domainsearch.SearchParams{
|
||||
Query: "alpha",
|
||||
Mode: domainsearch.SearchModeHybrid,
|
||||
Filters: domain.SearchFilters{
|
||||
Languages: []string{"en"},
|
||||
},
|
||||
Limit: 1,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
mockWrapper.On("Search", context.Background(), params).Return(expectedResults, nil)
|
||||
|
||||
results, err := mockWrapper.Search(context.Background(), params)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Equal(t, 1, len(results.Works))
|
||||
assert.Equal(t, "Work 1", results.Works[0].Title)
|
||||
|
||||
mockWrapper.AssertExpectations(t)
|
||||
}
|
||||
@ -40,6 +40,10 @@ import (
|
||||
// mockSearchClient is a mock implementation of the SearchClient interface.
|
||||
type mockSearchClient struct{}
|
||||
|
||||
func (m *mockSearchClient) Search(ctx context.Context, params search.SearchParams) (*domain.SearchResults, error) {
|
||||
return &domain.SearchResults{}, nil
|
||||
}
|
||||
|
||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user