package search import ( "context" "encoding/json" "fmt" "sort" "strconv" "tercul/internal/domain" domainsearch "tercul/internal/domain/search" "github.com/weaviate/weaviate-go-client/v5/weaviate" "github.com/weaviate/weaviate-go-client/v5/weaviate/filters" "github.com/weaviate/weaviate-go-client/v5/weaviate/graphql" "github.com/weaviate/weaviate/entities/models" ) type weaviateWrapper struct { client *weaviate.Client host string searchAlpha float64 } // NewWeaviateWrapper creates a new WeaviateWrapper that implements the SearchClient interface. func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) domainsearch.SearchClient { return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha} } // hybridAlpha returns the alpha value for hybrid search, clamped between 0 and 1. func (w *weaviateWrapper) hybridAlpha() float32 { if w.searchAlpha < 0 { return 0 } if w.searchAlpha > 1 { return 1 } return float32(w.searchAlpha) } // Search performs a multi-class search against the Weaviate instance. func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) { allResults := make([]domainsearch.SearchResultItem, 0) // Determine which entity types to search for. If a type filter is present, use it. var searchTypes []string if len(params.Filters.Types) > 0 { searchTypes = params.Filters.Types } else { searchTypes = []string{"Work", "Translation", "Author"} } if contains(searchTypes, "Work") { workResults, err := w.searchWorks(ctx, ¶ms) if err != nil { return nil, err } allResults = append(allResults, workResults...) } if contains(searchTypes, "Translation") { translationResults, err := w.searchTranslations(ctx, ¶ms) if err != nil { return nil, err } allResults = append(allResults, translationResults...) } if contains(searchTypes, "Author") { authorResults, err := w.searchAuthors(ctx, ¶ms) if err != nil { return nil, err } allResults = append(allResults, authorResults...) } // --- Sort by Relevance Score --- sort.Slice(allResults, func(i, j int) bool { return allResults[i].Score > allResults[j].Score // Descending order }) totalResults := int64(len(allResults)) // --- Paginate In-Memory --- paginatedResults := make([]domainsearch.SearchResultItem, 0) start := params.Offset end := params.Offset + params.Limit if start < len(allResults) { if end > len(allResults) { end = len(allResults) } paginatedResults = allResults[start:end] } return &domainsearch.SearchResults{ Results: paginatedResults, TotalResults: totalResults, Limit: params.Limit, Offset: params.Offset, }, nil } func (w *weaviateWrapper) searchWorks(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) { fields := []graphql.Field{ {Name: "db_id"}, {Name: "title"}, {Name: "description"}, {Name: "language"}, {Name: "status"}, {Name: "createdAt"}, {Name: "updatedAt"}, {Name: "tags"}, {Name: "_additional", Fields: []graphql.Field{{Name: "score"}}}, } searcher := w.client.GraphQL().Get().WithClassName("Work").WithFields(fields...) w.addSearchArguments(searcher, params, "Work", []string{"title", "description"}) resp, err := searcher.Do(ctx) if err != nil { return nil, fmt.Errorf("failed to search works: %w", err) } return w.parseGraphQLResponse(resp, "Work") } func (w *weaviateWrapper) searchTranslations(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) { fields := []graphql.Field{ {Name: "db_id"}, {Name: "title"}, {Name: "content"}, {Name: "language"}, {Name: "status"}, {Name: "_additional", Fields: []graphql.Field{{Name: "score"}}}, } searcher := w.client.GraphQL().Get().WithClassName("Translation").WithFields(fields...) w.addSearchArguments(searcher, params, "Translation", []string{"title", "content"}) resp, err := searcher.Do(ctx) if err != nil { return nil, fmt.Errorf("failed to search translations: %w", err) } return w.parseGraphQLResponse(resp, "Translation") } func (w *weaviateWrapper) searchAuthors(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) { fields := []graphql.Field{ {Name: "db_id"}, {Name: "name"}, {Name: "biography"}, {Name: "_additional", Fields: []graphql.Field{{Name: "score"}}}, } searcher := w.client.GraphQL().Get().WithClassName("Author").WithFields(fields...) w.addSearchArguments(searcher, params, "Author", []string{"name", "biography"}) // Authors should not be filtered by language if len(params.Filters.Languages) > 0 { return []domainsearch.SearchResultItem{}, nil } resp, err := searcher.Do(ctx) if err != nil { return nil, fmt.Errorf("failed to search authors: %w", err) } return w.parseGraphQLResponse(resp, "Author") } func (w *weaviateWrapper) addSearchArguments(searcher *graphql.GetBuilder, params *domainsearch.SearchParams, className string, searchFields []string) { if params.Query != "" || len(params.Concepts) > 0 { switch params.Mode { case domainsearch.SearchModeBM25: bm25 := w.client.GraphQL().Bm25ArgBuilder().WithQuery(params.Query).WithProperties(searchFields...) searcher.WithBM25(bm25) case domainsearch.SearchModeVector: nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(params.Concepts) searcher.WithNearText(nearText) default: hybrid := w.client.GraphQL().HybridArgumentBuilder(). WithQuery(params.Query). WithAlpha(w.hybridAlpha()) searcher.WithHybrid(hybrid) } } where := buildWhereFilter(¶ms.Filters, className) if where != nil { searcher.WithWhere(where) } } // buildWhereFilter constructs the 'where' clause for the GraphQL query based on the search parameters. func buildWhereFilter(searchFilters *domainsearch.SearchFilters, className string) *filters.WhereBuilder { if searchFilters == nil { return nil } operands := make([]*filters.WhereBuilder, 0) if (className == "Work" || className == "Translation") && len(searchFilters.Languages) > 0 { operands = append(operands, filters.Where(). WithPath([]string{"language"}). WithOperator(filters.ContainsAny). WithValueText(searchFilters.Languages...)) } if className == "Work" { if searchFilters.DateFrom != nil && !searchFilters.DateFrom.IsZero() { operands = append(operands, filters.Where(). WithPath([]string{"createdAt"}). WithOperator(filters.GreaterThanEqual). WithValueDate(*searchFilters.DateFrom)) } if searchFilters.DateTo != nil && !searchFilters.DateTo.IsZero() { operands = append(operands, filters.Where(). WithPath([]string{"createdAt"}). WithOperator(filters.LessThanEqual). WithValueDate(*searchFilters.DateTo)) } if len(searchFilters.Tags) > 0 { operands = append(operands, filters.Where(). WithPath([]string{"tags"}). WithOperator(filters.ContainsAny). WithValueText(searchFilters.Tags...)) } } if className == "Author" && len(searchFilters.Authors) > 0 { operands = append(operands, filters.Where(). WithPath([]string{"name"}). WithOperator(filters.ContainsAny). WithValueText(searchFilters.Authors...)) } if len(operands) == 0 { return nil } return filters.Where().WithOperator(filters.And).WithOperands(operands) } type weaviateResult struct { Additional struct { Score string `json:"score"` } `json:"_additional"` Properties map[string]interface{} `json:"-"` } func (r *weaviateResult) UnmarshalJSON(data []byte) error { var raw map[string]interface{} if err := json.Unmarshal(data, &raw); err != nil { return err } if add, ok := raw["_additional"]; ok { addBytes, err := json.Marshal(add) if err != nil { return err } if err := json.Unmarshal(addBytes, &r.Additional); err != nil { return err } } delete(raw, "_additional") r.Properties = raw return nil } // Temporary struct to handle the mismatch between Weaviate's string array for tags // and the domain's []*Tag struct. type workWithDBIDAndStringTags struct { domain.Work DBID json.Number `json:"db_id"` Tags []string `json:"tags"` // This captures the tags as strings } type translationWithDBID struct { domain.Translation DBID json.Number `json:"db_id"` } type authorWithDBID struct { domain.Author DBID json.Number `json:"db_id"` Biography string `json:"biography"` } func (w *weaviateWrapper) parseGraphQLResponse(resp *models.GraphQLResponse, className string) ([]domainsearch.SearchResultItem, error) { var results []domainsearch.SearchResultItem get, ok := resp.Data["Get"].(map[string]interface{}) if !ok { return nil, fmt.Errorf("unexpected response format: 'Get' not found") } classData, ok := get[className].([]interface{}) if !ok { return results, nil } for _, itemData := range classData { itemBytes, err := json.Marshal(itemData) if err != nil { return nil, fmt.Errorf("failed to marshal search item: %w", err) } var tempResult weaviateResult if err := json.Unmarshal(itemBytes, &tempResult); err != nil { return nil, fmt.Errorf("failed to unmarshal search item: %w", err) } score, err := strconv.ParseFloat(tempResult.Additional.Score, 64) if err != nil { // In BM25, score can be nil if not found, treat as 0 score = 0 } propBytes, err := json.Marshal(tempResult.Properties) if err != nil { return nil, fmt.Errorf("failed to marshal item properties: %w", err) } var entity interface{} switch className { case "Work": var tempWork workWithDBIDAndStringTags if err := json.Unmarshal(propBytes, &tempWork); err != nil { return nil, fmt.Errorf("failed to unmarshal work: %w", err) } id, err := tempWork.DBID.Int64() if err != nil { return nil, fmt.Errorf("failed to parse work db_id: %w", err) } tempWork.Work.ID = uint(id) // Convert []string to []*domain.Tag finalTags := make([]*domain.Tag, len(tempWork.Tags)) for i, tagName := range tempWork.Tags { finalTags[i] = &domain.Tag{Name: tagName} } tempWork.Work.Tags = finalTags entity = tempWork.Work case "Translation": var translation translationWithDBID if err := json.Unmarshal(propBytes, &translation); err != nil { return nil, fmt.Errorf("failed to unmarshal translation: %w", err) } id, err := translation.DBID.Int64() if err != nil { return nil, fmt.Errorf("failed to parse translation db_id: %w", err) } translation.Translation.ID = uint(id) entity = translation.Translation case "Author": var author authorWithDBID if err := json.Unmarshal(propBytes, &author); err != nil { return nil, fmt.Errorf("failed to unmarshal author: %w", err) } id, err := author.DBID.Int64() if err != nil { return nil, fmt.Errorf("failed to parse author db_id: %w", err) } author.Author.ID = uint(id) entity = author.Author default: return nil, fmt.Errorf("unknown class name for parsing: %s", className) } results = append(results, domainsearch.SearchResultItem{ Type: className, Entity: entity, Score: score, }) } return results, nil } func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error { // Convert []*domain.Tag to []string tags := make([]string, len(work.Tags)) for i, tag := range work.Tags { tags[i] = tag.Name } properties := map[string]interface{}{ "db_id": work.ID, "title": work.Title, "description": work.Description, "language": work.Language, "status": work.Status, "createdAt": work.CreatedAt, "updatedAt": work.UpdatedAt, "tags": tags, "content": content, // Assuming content is passed in } _, err := w.client.Data().Creator(). WithClassName("Work"). WithID(strconv.FormatUint(uint64(work.ID), 10)). WithProperties(properties). Do(ctx) return err } func contains(slice []string, item string) bool { for _, s := range slice { if s == item { return true } } return false }