tercul-backend/internal/platform/search/weaviate_wrapper.go
Damir Mukimov d50722dad5
Some checks failed
Test / Integration Tests (push) Successful in 4s
Build / Build Binary (push) Failing after 2m9s
Docker Build / Build Docker Image (push) Failing after 2m32s
Test / Unit Tests (push) Failing after 3m12s
Lint / Go Lint (push) Failing after 1m0s
Refactor ID handling to use UUIDs across the application
- Updated database models and repositories to replace uint IDs with UUIDs.
- Modified test fixtures to generate and use UUIDs for authors, translations, users, and works.
- Adjusted mock implementations to align with the new UUID structure.
- Ensured all relevant functions and methods are updated to handle UUIDs correctly.
- Added necessary imports for UUID handling in various files.
2025-12-27 00:33:34 +01:00

409 lines
12 KiB
Go

package search
import (
"context"
"encoding/json"
"fmt"
"sort"
"strconv"
"tercul/internal/domain"
domainsearch "tercul/internal/domain/search"
"github.com/google/uuid"
"github.com/weaviate/weaviate-go-client/v5/weaviate"
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
"github.com/weaviate/weaviate/entities/models"
)
// weaviateWrapper is the Weaviate-backed value returned by NewWeaviateWrapper
// as a domainsearch.SearchClient.
type weaviateWrapper struct {
// client is the Weaviate client used for GraphQL queries and object creation.
client *weaviate.Client
// host is the Weaviate host string supplied to NewWeaviateWrapper.
host string
// searchAlpha is the configured hybrid-search alpha; hybridAlpha clamps it to [0, 1] before use.
searchAlpha float64
}
// NewWeaviateWrapper builds a Weaviate-backed domainsearch.SearchClient from
// the given client, host and hybrid-search alpha. The alpha is stored as-is;
// it is clamped to [0, 1] when a hybrid query is built.
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) domainsearch.SearchClient {
	wrapper := new(weaviateWrapper)
	wrapper.client = client
	wrapper.host = host
	wrapper.searchAlpha = searchAlpha
	return wrapper
}
// hybridAlpha reports the alpha to use for hybrid search: the configured
// searchAlpha limited to the range [0, 1] and narrowed to float32.
func (w *weaviateWrapper) hybridAlpha() float32 {
	alpha := w.searchAlpha
	switch {
	case alpha < 0:
		return 0
	case alpha > 1:
		return 1
	default:
		return float32(alpha)
	}
}
// Search performs a multi-class search against the Weaviate instance.
//
// The classes named in params.Filters.Types are queried one after another
// (Work, Translation, Author; all three when no type filter is given). The
// combined hits are ordered by descending score and then paginated in memory
// using params.Offset and params.Limit. TotalResults is the number of hits
// collected before pagination. The first failing class query aborts the whole
// search and its error is returned unchanged.
//
// A negative Offset or Limit is treated as 0 for slicing purposes instead of
// panicking; the returned Limit and Offset fields echo the caller's values.
func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
	// Determine which entity types to search for. If a type filter is present, use it.
	var searchTypes []string
	if len(params.Filters.Types) > 0 {
		searchTypes = params.Filters.Types
	} else {
		searchTypes = []string{"Work", "Translation", "Author"}
	}

	// Fixed query order keeps the merged slice, and therefore the stable sort
	// below, deterministic for a given set of responses.
	classSearches := []struct {
		className string
		run       func(context.Context, *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error)
	}{
		{"Work", w.searchWorks},
		{"Translation", w.searchTranslations},
		{"Author", w.searchAuthors},
	}

	allResults := make([]domainsearch.SearchResultItem, 0)
	for _, cs := range classSearches {
		if !contains(searchTypes, cs.className) {
			continue
		}
		items, err := cs.run(ctx, &params)
		if err != nil {
			return nil, err
		}
		allResults = append(allResults, items...)
	}

	// --- Sort by Relevance Score (descending) ---
	// Stable so that hits with equal scores keep their merge order and do not
	// jump between pages from one request to the next.
	sort.SliceStable(allResults, func(i, j int) bool {
		return allResults[i].Score > allResults[j].Score
	})
	totalResults := int64(len(allResults))

	// --- Paginate In-Memory ---
	start, limit := params.Offset, params.Limit
	if start < 0 {
		start = 0
	}
	if limit < 0 {
		limit = 0
	}
	paginatedResults := make([]domainsearch.SearchResultItem, 0)
	if start < len(allResults) {
		end := len(allResults)
		// Compare against the remaining length rather than computing
		// start+limit first, so a huge limit cannot overflow.
		if limit < end-start {
			end = start + limit
		}
		paginatedResults = allResults[start:end]
	}

	return &domainsearch.SearchResults{
		Results:      paginatedResults,
		TotalResults: totalResults,
		Limit:        params.Limit,
		Offset:       params.Offset,
	}, nil
}
// searchWorks queries the "Work" class, applying the search operator and
// filters derived from params, and converts the response into result items.
// Text search is directed at the title and description properties.
func (w *weaviateWrapper) searchWorks(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) {
	const className = "Work"

	// Requested properties plus the relevance score from _additional.
	propertyNames := []string{
		"db_id", "title", "description", "language",
		"status", "createdAt", "updatedAt", "tags",
	}
	selection := make([]graphql.Field, 0, len(propertyNames)+1)
	for _, name := range propertyNames {
		selection = append(selection, graphql.Field{Name: name})
	}
	selection = append(selection, graphql.Field{
		Name:   "_additional",
		Fields: []graphql.Field{{Name: "score"}},
	})

	query := w.client.GraphQL().Get().WithClassName(className).WithFields(selection...)
	w.addSearchArguments(query, params, className, []string{"title", "description"})

	response, err := query.Do(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to search works: %w", err)
	}
	return w.parseGraphQLResponse(response, className)
}
// searchTranslations queries the "Translation" class, applying the search
// operator and filters derived from params, and converts the response into
// result items. Text search is directed at the title and content properties.
func (w *weaviateWrapper) searchTranslations(
	ctx context.Context,
	params *domainsearch.SearchParams,
) ([]domainsearch.SearchResultItem, error) {
	const className = "Translation"

	// Requested properties plus the relevance score from _additional.
	propertyNames := []string{"db_id", "title", "content", "language", "status"}
	selection := make([]graphql.Field, 0, len(propertyNames)+1)
	for _, name := range propertyNames {
		selection = append(selection, graphql.Field{Name: name})
	}
	selection = append(selection, graphql.Field{
		Name:   "_additional",
		Fields: []graphql.Field{{Name: "score"}},
	})

	query := w.client.GraphQL().Get().WithClassName(className).WithFields(selection...)
	w.addSearchArguments(query, params, className, []string{"title", "content"})

	response, err := query.Do(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to search translations: %w", err)
	}
	return w.parseGraphQLResponse(response, className)
}
// searchAuthors queries the "Author" class, applying the search operator and
// filters derived from params, and converts the response into result items.
// Text search is directed at the name and biography properties.
//
// When a language filter is present no query is sent and an empty, non-nil
// slice is returned: the where-filter built for authors never contains a
// language condition, so a language-restricted search yields no authors.
func (w *weaviateWrapper) searchAuthors(ctx context.Context, params *domainsearch.SearchParams) ([]domainsearch.SearchResultItem, error) {
	// Decide this before building the query so no work is thrown away.
	if len(params.Filters.Languages) > 0 {
		return []domainsearch.SearchResultItem{}, nil
	}

	fields := []graphql.Field{
		{Name: "db_id"}, {Name: "name"}, {Name: "biography"},
		{Name: "_additional", Fields: []graphql.Field{{Name: "score"}}},
	}
	searcher := w.client.GraphQL().Get().WithClassName("Author").WithFields(fields...)
	w.addSearchArguments(searcher, params, "Author", []string{"name", "biography"})

	resp, err := searcher.Do(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to search authors: %w", err)
	}
	return w.parseGraphQLResponse(resp, "Author")
}
// addSearchArguments attaches the search operator and the where-filter for
// className to searcher. It relies on the builder methods modifying searcher
// in place (their return values are discarded).
//
// A search operator is added only when params carries a query string or at
// least one concept:
//   - SearchModeBM25: BM25 over searchFields using params.Query.
//   - SearchModeVector: nearText over params.Concepts; when no concepts were
//     supplied, params.Query is used as the single concept so the request
//     does not carry an empty concepts list.
//   - any other mode: hybrid search using params.Query and the clamped alpha.
//
// The where-filter from buildWhereFilter is added whenever it is non-nil,
// independently of the search operator.
func (w *weaviateWrapper) addSearchArguments(
	searcher *graphql.GetBuilder,
	params *domainsearch.SearchParams,
	className string,
	searchFields []string,
) {
	if params.Query != "" || len(params.Concepts) > 0 {
		switch params.Mode {
		case domainsearch.SearchModeBM25:
			bm25 := w.client.GraphQL().Bm25ArgBuilder().WithQuery(params.Query).WithProperties(searchFields...)
			searcher.WithBM25(bm25)
		case domainsearch.SearchModeVector:
			concepts := params.Concepts
			if len(concepts) == 0 {
				// The enclosing condition guarantees Query is non-empty here.
				concepts = []string{params.Query}
			}
			nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(concepts)
			searcher.WithNearText(nearText)
		default:
			hybrid := w.client.GraphQL().HybridArgumentBuilder().
				WithQuery(params.Query).
				WithAlpha(w.hybridAlpha())
			searcher.WithHybrid(hybrid)
		}
	}

	if where := buildWhereFilter(&params.Filters, className); where != nil {
		searcher.WithWhere(where)
	}
}
// buildWhereFilter assembles the GraphQL 'where' clause for one class from the
// search filters. It returns nil when searchFilters is nil or when none of the
// filters apply to className; otherwise it returns an And over the applicable
// conditions:
//   - Work: language (ContainsAny), createdAt >= DateFrom, createdAt <= DateTo,
//     tags (ContainsAny)
//   - Translation: language (ContainsAny)
//   - Author: name (ContainsAny of Authors)
func buildWhereFilter(searchFilters *domainsearch.SearchFilters, className string) *filters.WhereBuilder {
	if searchFilters == nil {
		return nil
	}

	var conditions []*filters.WhereBuilder

	// addAnyOf appends a ContainsAny text condition on property, skipping
	// empty value lists.
	addAnyOf := func(property string, values []string) {
		if len(values) == 0 {
			return
		}
		cond := filters.Where().
			WithPath([]string{property}).
			WithOperator(filters.ContainsAny).
			WithValueText(values...)
		conditions = append(conditions, cond)
	}

	switch className {
	case "Work":
		addAnyOf("language", searchFilters.Languages)
		if from := searchFilters.DateFrom; from != nil && !from.IsZero() {
			lower := filters.Where().
				WithPath([]string{"createdAt"}).
				WithOperator(filters.GreaterThanEqual).
				WithValueDate(*from)
			conditions = append(conditions, lower)
		}
		if to := searchFilters.DateTo; to != nil && !to.IsZero() {
			upper := filters.Where().
				WithPath([]string{"createdAt"}).
				WithOperator(filters.LessThanEqual).
				WithValueDate(*to)
			conditions = append(conditions, upper)
		}
		addAnyOf("tags", searchFilters.Tags)
	case "Translation":
		addAnyOf("language", searchFilters.Languages)
	case "Author":
		addAnyOf("name", searchFilters.Authors)
	}

	if len(conditions) == 0 {
		return nil
	}
	combined := filters.Where().WithOperator(filters.And)
	return combined.WithOperands(conditions)
}
// weaviateResult is a single decoded hit: the `_additional` metadata is kept
// in Additional and every other top-level key is collected in Properties.
type weaviateResult struct {
	// Additional holds the `_additional` object; Score is kept as the string
	// it is decoded from.
	Additional struct {
		Score string `json:"score"`
	} `json:"_additional"`
	// Properties holds all top-level keys except `_additional`.
	Properties map[string]interface{} `json:"-"`
}

// UnmarshalJSON decodes a hit object. The `_additional` entry, when present,
// is re-encoded and decoded into r.Additional; the remaining keys become
// r.Properties. On error r.Properties is left untouched.
func (r *weaviateResult) UnmarshalJSON(data []byte) error {
	var fields map[string]interface{}
	if err := json.Unmarshal(data, &fields); err != nil {
		return err
	}

	additional, present := fields["_additional"]
	delete(fields, "_additional")

	if present {
		encoded, err := json.Marshal(additional)
		if err != nil {
			return err
		}
		if err = json.Unmarshal(encoded, &r.Additional); err != nil {
			return err
		}
	}

	r.Properties = fields
	return nil
}
// workWithDBIDAndStringTags is a temporary decoding target that handles the
// mismatch between Weaviate's string array for tags and the domain's []*Tag.
// DBID is a plain string because db_id carries a UUID; a json.Number field
// rejects any JSON string that is not a valid number literal.
type workWithDBIDAndStringTags struct {
	domain.Work
	DBID string   `json:"db_id"`
	Tags []string `json:"tags"` // This captures the tags as strings
}

// translationWithDBID decodes a Translation hit together with its db_id.
type translationWithDBID struct {
	domain.Translation
	DBID string `json:"db_id"`
}

// authorWithDBID decodes an Author hit together with its db_id and biography.
type authorWithDBID struct {
	domain.Author
	DBID      string `json:"db_id"`
	Biography string `json:"biography"`
}

// parseGraphQLResponse converts the hits for className in a GraphQL Get
// response into search result items.
//
// It returns an error when resp is nil, when the response has no "Get" object
// (the first GraphQL error message, if any, is included), when a hit cannot be
// decoded, when its db_id is not a valid UUID, or when className is not one of
// "Work", "Translation" or "Author". A response whose "Get" object has no list
// for className yields a nil slice and no error. A missing or unparsable score
// is reported as 0.
//
//nolint:gocyclo // Complex parsing logic
func (w *weaviateWrapper) parseGraphQLResponse(resp *models.GraphQLResponse, className string) ([]domainsearch.SearchResultItem, error) {
	var results []domainsearch.SearchResultItem
	if resp == nil {
		return nil, fmt.Errorf("unexpected response format: nil response")
	}
	get, ok := resp.Data["Get"].(map[string]interface{})
	if !ok {
		// Surface the server's explanation when there is one.
		if len(resp.Errors) > 0 && resp.Errors[0] != nil {
			return nil, fmt.Errorf("unexpected response format: 'Get' not found: %s", resp.Errors[0].Message)
		}
		return nil, fmt.Errorf("unexpected response format: 'Get' not found")
	}
	classData, ok := get[className].([]interface{})
	if !ok {
		return results, nil
	}

	for _, itemData := range classData {
		// Round-trip through JSON to split _additional from the properties.
		itemBytes, err := json.Marshal(itemData)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal search item: %w", err)
		}
		var tempResult weaviateResult
		if err := json.Unmarshal(itemBytes, &tempResult); err != nil {
			return nil, fmt.Errorf("failed to unmarshal search item: %w", err)
		}

		score, err := strconv.ParseFloat(tempResult.Additional.Score, 64)
		if err != nil {
			// In BM25, score can be nil if not found, treat as 0
			score = 0
		}

		propBytes, err := json.Marshal(tempResult.Properties)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal item properties: %w", err)
		}

		var entity interface{}
		switch className {
		case "Work":
			var tempWork workWithDBIDAndStringTags
			if err := json.Unmarshal(propBytes, &tempWork); err != nil {
				return nil, fmt.Errorf("failed to unmarshal work: %w", err)
			}
			id, err := uuid.Parse(tempWork.DBID)
			if err != nil {
				return nil, fmt.Errorf("failed to parse work db_id as UUID: %w", err)
			}
			tempWork.Work.ID = id

			// Convert []string to []*domain.Tag
			finalTags := make([]*domain.Tag, len(tempWork.Tags))
			for i, tagName := range tempWork.Tags {
				finalTags[i] = &domain.Tag{Name: tagName}
			}
			tempWork.Work.Tags = finalTags
			entity = tempWork.Work
		case "Translation":
			var translation translationWithDBID
			if err := json.Unmarshal(propBytes, &translation); err != nil {
				return nil, fmt.Errorf("failed to unmarshal translation: %w", err)
			}
			id, err := uuid.Parse(translation.DBID)
			if err != nil {
				return nil, fmt.Errorf("failed to parse translation db_id as UUID: %w", err)
			}
			translation.Translation.ID = id
			entity = translation.Translation
		case "Author":
			var author authorWithDBID
			if err := json.Unmarshal(propBytes, &author); err != nil {
				return nil, fmt.Errorf("failed to unmarshal author: %w", err)
			}
			id, err := uuid.Parse(author.DBID)
			if err != nil {
				return nil, fmt.Errorf("failed to parse author db_id as UUID: %w", err)
			}
			author.Author.ID = id
			entity = author.Author
		default:
			return nil, fmt.Errorf("unknown class name for parsing: %s", className)
		}

		results = append(results, domainsearch.SearchResultItem{
			Type:   className,
			Entity: entity,
			Score:  score,
		})
	}
	return results, nil
}
// IndexWork creates a "Work" object in Weaviate for the given work, using the
// work's UUID both as the object ID and as the db_id property. content is
// stored in the "content" property. Tags are stored as a list of tag names;
// nil tag entries are skipped.
//
// It returns an error when work is nil or when the create request fails; the
// underlying client error is wrapped and remains available via errors.Is/As.
func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
	if work == nil {
		return fmt.Errorf("failed to index work: work is nil")
	}

	// Convert []*domain.Tag to []string
	tags := make([]string, 0, len(work.Tags))
	for _, tag := range work.Tags {
		if tag == nil {
			continue
		}
		tags = append(tags, tag.Name)
	}

	workID := work.ID.String()
	properties := map[string]interface{}{
		// Stored as the UUID's string form, which is what the search path
		// parses back with uuid.Parse.
		"db_id":       workID,
		"title":       work.Title,
		"description": work.Description,
		"language":    work.Language,
		"status":      work.Status,
		"createdAt":   work.CreatedAt,
		"updatedAt":   work.UpdatedAt,
		"tags":        tags,
		"content":     content, // Assuming content is passed in
	}

	_, err := w.client.Data().Creator().
		WithClassName("Work").
		WithID(workID).
		WithProperties(properties).
		Do(ctx)
	if err != nil {
		return fmt.Errorf("failed to index work %s: %w", workID, err)
	}
	return nil
}
// contains reports whether item occurs in slice. A nil or empty slice never
// contains anything.
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}