mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 05:11:34 +00:00
Add configurable SearchAlpha parameter for hybrid search tuning - Added SearchAlpha config parameter (default: 0.7) for tuning BM25 vs vector search balance - Updated NewWeaviateWrapper to accept host and searchAlpha parameters - Enhanced hybrid search with configurable alpha parameter via WithAlpha() - Fixed all type mismatches in mocks and tests to use domainsearch.SearchResults - Updated GraphQL resolver to use new SearchResults structure with SearchResultItem - All tests and vet checks passing Closes #30
396 lines
12 KiB
Go
396 lines
12 KiB
Go
package search
|
|
|
|
import (
	"context"
	"crypto/sha1"
	"encoding/json"
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"
	"tercul/internal/domain"
	domainsearch "tercul/internal/domain/search"

	"github.com/weaviate/weaviate-go-client/v5/weaviate"
	"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
	"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
	"github.com/weaviate/weaviate/entities/models"
)
|
|
|
|
type weaviateWrapper struct {
|
|
client *weaviate.Client
|
|
host string
|
|
searchAlpha float64
|
|
}
|
|
|
|
// NewWeaviateWrapper creates a new WeaviateWrapper that implements the SearchClient interface.
|
|
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) domainsearch.SearchClient {
|
|
return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha}
|
|
}
|
|
|
|
// hybridAlpha returns the alpha value for hybrid search, clamped between 0 and 1.
|
|
func (w *weaviateWrapper) hybridAlpha() float32 {
|
|
if w.searchAlpha < 0 {
|
|
return 0
|
|
}
|
|
if w.searchAlpha > 1 {
|
|
return 1
|
|
}
|
|
return float32(w.searchAlpha)
|
|
}
|
|
|
|
// Search performs a multi-class search against the Weaviate instance.
|
|
func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
|
allResults := make([]domainsearch.SearchResultItem, 0)
|
|
|
|
// Determine which entity types to search for. If a type filter is present, use it.
|
|
var searchTypes []string
|
|
if len(params.Filters.Types) > 0 {
|
|
searchTypes = params.Filters.Types
|
|
} else {
|
|
searchTypes = []string{"Work", "Translation", "Author"}
|
|
}
|
|
|
|
if contains(searchTypes, "Work") {
|
|
workResults, err := w.searchWorks(ctx, ¶ms)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allResults = append(allResults, workResults...)
|
|
}
|
|
|
|
if contains(searchTypes, "Translation") {
|
|
translationResults, err := w.searchTranslations(ctx, ¶ms)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allResults = append(allResults, translationResults...)
|
|
}
|
|
|
|
if contains(searchTypes, "Author") {
|
|
authorResults, err := w.searchAuthors(ctx, ¶ms)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allResults = append(allResults, authorResults...)
|
|
}
|
|
|
|
// --- Sort by Relevance Score ---
|
|
sort.Slice(allResults, func(i, j int) bool {
|
|
return allResults[i].Score > allResults[j].Score // Descending order
|
|
})
|
|
|
|
totalResults := int64(len(allResults))
|
|
|
|
// --- Paginate In-Memory ---
|
|
paginatedResults := make([]domainsearch.SearchResultItem, 0)
|
|
start := params.Offset
|
|
end := params.Offset + params.Limit
|
|
|
|
if start < len(allResults) {
|
|
if end > len(allResults) {
|
|
end = len(allResults)
|
|
}
|
|
paginatedResults = allResults[start:end]
|
|
}
|
|
|
|
return &domainsearch.SearchResults{
|
|
Results: paginatedResults,
|
|
TotalResults: totalResults,
|
|
Limit: params.Limit,
|
|
Offset: params.Offset,
|
|
}, nil
|
|
}
|
|
|
|
func (w *weaviateWrapper) searchWorks(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) {
|
|
fields := []graphql.Field{
|
|
{Name: "db_id"}, {Name: "title"}, {Name: "description"}, {Name: "language"},
|
|
{Name: "status"}, {Name: "createdAt"}, {Name: "updatedAt"}, {Name: "tags"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "score"}}},
|
|
}
|
|
searcher := w.client.GraphQL().Get().WithClassName("Work").WithFields(fields...)
|
|
w.addSearchArguments(searcher, params, "Work", []string{"title", "description"})
|
|
resp, err := searcher.Do(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search works: %w", err)
|
|
}
|
|
return w.parseGraphQLResponse(resp, "Work")
|
|
}
|
|
|
|
func (w *weaviateWrapper) searchTranslations(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) {
|
|
fields := []graphql.Field{
|
|
{Name: "db_id"}, {Name: "title"}, {Name: "content"}, {Name: "language"}, {Name: "status"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "score"}}},
|
|
}
|
|
searcher := w.client.GraphQL().Get().WithClassName("Translation").WithFields(fields...)
|
|
w.addSearchArguments(searcher, params, "Translation", []string{"title", "content"})
|
|
resp, err := searcher.Do(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search translations: %w", err)
|
|
}
|
|
return w.parseGraphQLResponse(resp, "Translation")
|
|
}
|
|
|
|
func (w *weaviateWrapper) searchAuthors(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) {
|
|
fields := []graphql.Field{
|
|
{Name: "db_id"}, {Name: "name"}, {Name: "biography"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "score"}}},
|
|
}
|
|
searcher := w.client.GraphQL().Get().WithClassName("Author").WithFields(fields...)
|
|
w.addSearchArguments(searcher, params, "Author", []string{"name", "biography"})
|
|
// Authors should not be filtered by language
|
|
if len(params.Filters.Languages) > 0 {
|
|
return []domainsearch.SearchResultItem{}, nil
|
|
}
|
|
resp, err := searcher.Do(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search authors: %w", err)
|
|
}
|
|
return w.parseGraphQLResponse(resp, "Author")
|
|
}
|
|
|
|
func (w *weaviateWrapper) addSearchArguments(searcher *graphql.GetBuilder, params *domainsearch.SearchParams, className string, searchFields []string) {
|
|
if params.Query != "" || len(params.Concepts) > 0 {
|
|
switch params.Mode {
|
|
case domainsearch.SearchModeBM25:
|
|
bm25 := w.client.GraphQL().Bm25ArgBuilder().WithQuery(params.Query).WithProperties(searchFields...)
|
|
searcher.WithBM25(bm25)
|
|
case domainsearch.SearchModeVector:
|
|
nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(params.Concepts)
|
|
searcher.WithNearText(nearText)
|
|
default:
|
|
hybrid := w.client.GraphQL().HybridArgumentBuilder().
|
|
WithQuery(params.Query).
|
|
WithAlpha(w.hybridAlpha())
|
|
searcher.WithHybrid(hybrid)
|
|
}
|
|
}
|
|
where := buildWhereFilter(¶ms.Filters, className)
|
|
if where != nil {
|
|
searcher.WithWhere(where)
|
|
}
|
|
}
|
|
|
|
// buildWhereFilter constructs the 'where' clause for the GraphQL query based on the search parameters.
|
|
func buildWhereFilter(searchFilters *domainsearch.SearchFilters, className string) *filters.WhereBuilder {
|
|
if searchFilters == nil {
|
|
return nil
|
|
}
|
|
|
|
operands := make([]*filters.WhereBuilder, 0)
|
|
|
|
if (className == "Work" || className == "Translation") && len(searchFilters.Languages) > 0 {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"language"}).
|
|
WithOperator(filters.ContainsAny).
|
|
WithValueText(searchFilters.Languages...))
|
|
}
|
|
|
|
if className == "Work" {
|
|
if searchFilters.DateFrom != nil && !searchFilters.DateFrom.IsZero() {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"createdAt"}).
|
|
WithOperator(filters.GreaterThanEqual).
|
|
WithValueDate(*searchFilters.DateFrom))
|
|
}
|
|
if searchFilters.DateTo != nil && !searchFilters.DateTo.IsZero() {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"createdAt"}).
|
|
WithOperator(filters.LessThanEqual).
|
|
WithValueDate(*searchFilters.DateTo))
|
|
}
|
|
if len(searchFilters.Tags) > 0 {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"tags"}).
|
|
WithOperator(filters.ContainsAny).
|
|
WithValueText(searchFilters.Tags...))
|
|
}
|
|
}
|
|
|
|
if className == "Author" && len(searchFilters.Authors) > 0 {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"name"}).
|
|
WithOperator(filters.ContainsAny).
|
|
WithValueText(searchFilters.Authors...))
|
|
}
|
|
|
|
if len(operands) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return filters.Where().WithOperator(filters.And).WithOperands(operands)
|
|
}
|
|
|
|
// weaviateResult is one raw hit from a Weaviate GraphQL Get query. The
// "_additional" metadata is decoded into Additional; every other top-level key
// is kept, as generic JSON values, in Properties.
type weaviateResult struct {
	Additional struct {
		// Score is the ranking score exactly as Weaviate reports it (a string;
		// callers parse it).
		Score string `json:"score"`
	} `json:"_additional"`
	// Properties is filled by UnmarshalJSON, not by struct tags.
	Properties map[string]interface{} `json:"-"`
}

// UnmarshalJSON decodes one hit object. "_additional" (when present) is
// decoded into r.Additional and removed; all remaining keys become
// r.Properties. On error r.Properties is left untouched.
func (r *weaviateResult) UnmarshalJSON(data []byte) error {
	var fields map[string]interface{}
	if err := json.Unmarshal(data, &fields); err != nil {
		return err
	}

	extra, found := fields["_additional"]
	if found {
		// Re-encode the generic value so the typed Additional struct can decode it.
		encoded, err := json.Marshal(extra)
		if err != nil {
			return err
		}
		if err = json.Unmarshal(encoded, &r.Additional); err != nil {
			return err
		}
		delete(fields, "_additional")
	}

	r.Properties = fields
	return nil
}
|
|
|
|
// Temporary struct to handle the mismatch between Weaviate's string array for tags
|
|
// and the domain's []*Tag struct.
|
|
type workWithDBIDAndStringTags struct {
|
|
domain.Work
|
|
DBID json.Number `json:"db_id"`
|
|
Tags []string `json:"tags"` // This captures the tags as strings
|
|
}
|
|
|
|
type translationWithDBID struct {
|
|
domain.Translation
|
|
DBID json.Number `json:"db_id"`
|
|
}
|
|
type authorWithDBID struct {
|
|
domain.Author
|
|
DBID json.Number `json:"db_id"`
|
|
Biography string `json:"biography"`
|
|
}
|
|
|
|
func (w *weaviateWrapper) parseGraphQLResponse(resp *models.GraphQLResponse, className string) ([]domainsearch.SearchResultItem, error) {
|
|
var results []domainsearch.SearchResultItem
|
|
|
|
get, ok := resp.Data["Get"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected response format: 'Get' not found")
|
|
}
|
|
|
|
classData, ok := get[className].([]interface{})
|
|
if !ok {
|
|
return results, nil
|
|
}
|
|
|
|
for _, itemData := range classData {
|
|
itemBytes, err := json.Marshal(itemData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal search item: %w", err)
|
|
}
|
|
|
|
var tempResult weaviateResult
|
|
if err := json.Unmarshal(itemBytes, &tempResult); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal search item: %w", err)
|
|
}
|
|
|
|
score, err := strconv.ParseFloat(tempResult.Additional.Score, 64)
|
|
if err != nil {
|
|
// In BM25, score can be nil if not found, treat as 0
|
|
score = 0
|
|
}
|
|
|
|
propBytes, err := json.Marshal(tempResult.Properties)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal item properties: %w", err)
|
|
}
|
|
|
|
var entity interface{}
|
|
switch className {
|
|
case "Work":
|
|
var tempWork workWithDBIDAndStringTags
|
|
if err := json.Unmarshal(propBytes, &tempWork); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal work: %w", err)
|
|
}
|
|
id, err := tempWork.DBID.Int64()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse work db_id: %w", err)
|
|
}
|
|
tempWork.Work.ID = uint(id)
|
|
|
|
// Convert []string to []*domain.Tag
|
|
finalTags := make([]*domain.Tag, len(tempWork.Tags))
|
|
for i, tagName := range tempWork.Tags {
|
|
finalTags[i] = &domain.Tag{Name: tagName}
|
|
}
|
|
tempWork.Work.Tags = finalTags
|
|
|
|
entity = tempWork.Work
|
|
case "Translation":
|
|
var translation translationWithDBID
|
|
if err := json.Unmarshal(propBytes, &translation); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal translation: %w", err)
|
|
}
|
|
id, err := translation.DBID.Int64()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse translation db_id: %w", err)
|
|
}
|
|
translation.Translation.ID = uint(id)
|
|
entity = translation.Translation
|
|
case "Author":
|
|
var author authorWithDBID
|
|
if err := json.Unmarshal(propBytes, &author); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal author: %w", err)
|
|
}
|
|
id, err := author.DBID.Int64()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse author db_id: %w", err)
|
|
}
|
|
author.Author.ID = uint(id)
|
|
entity = author.Author
|
|
default:
|
|
return nil, fmt.Errorf("unknown class name for parsing: %s", className)
|
|
}
|
|
|
|
results = append(results, domainsearch.SearchResultItem{
|
|
Type: className,
|
|
Entity: entity,
|
|
Score: score,
|
|
})
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
|
// Convert []*domain.Tag to []string
|
|
tags := make([]string, len(work.Tags))
|
|
for i, tag := range work.Tags {
|
|
tags[i] = tag.Name
|
|
}
|
|
|
|
properties := map[string]interface{}{
|
|
"db_id": work.ID,
|
|
"title": work.Title,
|
|
"description": work.Description,
|
|
"language": work.Language,
|
|
"status": work.Status,
|
|
"createdAt": work.CreatedAt,
|
|
"updatedAt": work.UpdatedAt,
|
|
"tags": tags,
|
|
"content": content, // Assuming content is passed in
|
|
}
|
|
|
|
_, err := w.client.Data().Creator().
|
|
WithClassName("Work").
|
|
WithID(strconv.FormatUint(uint64(work.ID), 10)).
|
|
WithProperties(properties).
|
|
Do(ctx)
|
|
|
|
return err
|
|
}
|
|
|
|
// contains reports whether item occurs in slice, using exact
// (case-sensitive) string comparison. A nil or empty slice contains nothing.
func contains(slice []string, item string) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}
|