mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 05:11:34 +00:00
This commit implements the full-text search service using Weaviate. It replaces the stub implementation with a fully functional search service that supports hybrid and BM25 search modes. The new implementation includes: - Support for hybrid and BM25 search modes. - Transformation of Weaviate search results into domain entities. - Unit tests using a mock Weaviate wrapper to ensure the implementation is correct and to avoid environmental issues in the test pipeline.
355 lines
9.6 KiB
Go
355 lines
9.6 KiB
Go
package search
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"strconv"
|
|
"tercul/internal/domain"
|
|
domainsearch "tercul/internal/domain/search"
|
|
"time"
|
|
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/filters"
|
|
"github.com/weaviate/weaviate-go-client/v5/weaviate/graphql"
|
|
)
|
|
|
|
type WeaviateWrapper interface {
|
|
Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error)
|
|
IndexWork(ctx context.Context, work *domain.Work, content string) error
|
|
}
|
|
|
|
type weaviateWrapper struct {
|
|
client *weaviate.Client
|
|
host string
|
|
searchAlpha float64
|
|
}
|
|
|
|
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) WeaviateWrapper {
|
|
return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha}
|
|
}
|
|
|
|
// Page-size bounds enforced by sanitizeLimitOffset.
const (
	// DefaultLimit replaces a non-positive requested limit.
	DefaultLimit = 20
	// MaxLimit is the largest limit a caller may request.
	MaxLimit = 100
)
|
|
|
|
// Weaviate class names under which the searchable entities are stored.
const (
	// WorkClass is the class holding objects mapped to domain.Work.
	WorkClass = "Work"
	// AuthorClass is the class holding objects mapped to domain.Author.
	AuthorClass = "Author"
	// TranslationClass is the class holding objects mapped to domain.Translation.
	TranslationClass = "Translation"
)
|
|
|
|
func sanitizeLimitOffset(limit, offset int) (int, int) {
|
|
if limit <= 0 {
|
|
limit = DefaultLimit
|
|
}
|
|
if limit > MaxLimit {
|
|
limit = MaxLimit
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
return limit, offset
|
|
}
|
|
|
|
func (w *weaviateWrapper) hybridAlpha() float32 {
|
|
if w.searchAlpha < 0 {
|
|
return 0
|
|
}
|
|
if w.searchAlpha > 1 {
|
|
return 1
|
|
}
|
|
return float32(w.searchAlpha)
|
|
}
|
|
|
|
func (w *weaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domain.SearchResults, error) {
|
|
results := &domain.SearchResults{
|
|
Works: []domain.Work{},
|
|
Translations: []domain.Translation{},
|
|
Authors: []domain.Author{},
|
|
}
|
|
|
|
limit, offset := sanitizeLimitOffset(params.Limit, params.Offset)
|
|
|
|
workFields := []graphql.Field{
|
|
{Name: "title"}, {Name: "description"}, {Name: "language"}, {Name: "status"}, {Name: "publishedAt"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
|
}
|
|
authorFields := []graphql.Field{
|
|
{Name: "name"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
|
}
|
|
translationFields := []graphql.Field{
|
|
{Name: "title"}, {Name: "language"},
|
|
{Name: "_additional", Fields: []graphql.Field{{Name: "id"}, {Name: "score"}, {Name: "meta", Fields: []graphql.Field{{Name: "count"}}}}},
|
|
}
|
|
|
|
searcher := w.client.GraphQL().Get().
|
|
WithFields(graphql.Field{Name: "... on Work", Fields: workFields}).
|
|
WithFields(graphql.Field{Name: "... on Author", Fields: authorFields}).
|
|
WithFields(graphql.Field{Name: "... on Translation", Fields: translationFields}).
|
|
WithLimit(limit).
|
|
WithOffset(offset)
|
|
|
|
switch params.Mode {
|
|
case domainsearch.SearchModeHybrid:
|
|
hybrid := w.client.GraphQL().HybridArgumentBuilder().
|
|
WithQuery(params.Query).
|
|
WithAlpha(w.hybridAlpha()).
|
|
WithProperties([]string{"title", "description", "name"})
|
|
searcher.WithHybrid(hybrid)
|
|
case domainsearch.SearchModeBM25:
|
|
bm25 := w.client.GraphQL().Bm25ArgBuilder().
|
|
WithQuery(params.Query).
|
|
WithProperties([]string{"title", "description", "name"}...)
|
|
searcher.WithBM25(bm25)
|
|
}
|
|
|
|
response, err := searcher.Do(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
get, ok := response.Data["Get"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid response data")
|
|
}
|
|
|
|
for className, classData := range get {
|
|
for _, item := range classData.([]interface{}) {
|
|
itemMap, ok := item.(map[string]interface{})
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch className {
|
|
case "Work":
|
|
work, err := mapToWork(itemMap)
|
|
if err != nil {
|
|
log.Printf("Error mapping work: %v", err)
|
|
continue
|
|
}
|
|
results.Works = append(results.Works, work)
|
|
case "Author":
|
|
author, err := mapToAuthor(itemMap)
|
|
if err != nil {
|
|
log.Printf("Error mapping author: %v", err)
|
|
continue
|
|
}
|
|
results.Authors = append(results.Authors, author)
|
|
case "Translation":
|
|
translation, err := mapToTranslation(itemMap)
|
|
if err != nil {
|
|
log.Printf("Error mapping translation: %v", err)
|
|
continue
|
|
}
|
|
results.Translations = append(results.Translations, translation)
|
|
}
|
|
}
|
|
}
|
|
|
|
meta, ok := response.Data["Meta"].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid meta data")
|
|
}
|
|
count, ok := meta["count"].(float64)
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid count data")
|
|
}
|
|
results.Total = int64(count)
|
|
|
|
return results, nil
|
|
}
|
|
|
|
|
|
func (w *weaviateWrapper) buildWhereFilter(searchFilters domain.SearchFilters, className string) *filters.WhereBuilder {
|
|
operands := make([]*filters.WhereBuilder, 0)
|
|
|
|
if len(searchFilters.Languages) > 0 {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"language"}).
|
|
WithOperator(filters.ContainsAny).
|
|
WithValueText(searchFilters.Languages...))
|
|
}
|
|
|
|
if className == "Work" {
|
|
if searchFilters.DateFrom != nil {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"publishedAt"}).
|
|
WithOperator(filters.GreaterThanEqual).
|
|
WithValueDate(*searchFilters.DateFrom))
|
|
}
|
|
|
|
if searchFilters.DateTo != nil {
|
|
operands = append(operands, filters.Where().
|
|
WithPath([]string{"publishedAt"}).
|
|
WithOperator(filters.LessThanEqual).
|
|
WithValueDate(*searchFilters.DateTo))
|
|
}
|
|
|
|
if len(searchFilters.Authors) > 0 {
|
|
authorOperands := make([]*filters.WhereBuilder, len(searchFilters.Authors))
|
|
for i, author := range searchFilters.Authors {
|
|
authorOperands[i] = filters.Where().
|
|
WithPath([]string{"authors", "Author", "id"}).
|
|
WithOperator(filters.Equal).
|
|
WithValueText(fmt.Sprintf("weaviate://%s/Author/%s", w.host, author))
|
|
}
|
|
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(authorOperands))
|
|
}
|
|
|
|
if len(searchFilters.Tags) > 0 {
|
|
tagOperands := make([]*filters.WhereBuilder, len(searchFilters.Tags))
|
|
for i, tag := range searchFilters.Tags {
|
|
tagOperands[i] = filters.Where().
|
|
WithPath([]string{"tags", "Tag", "name"}).
|
|
WithOperator(filters.Equal).
|
|
WithValueText(tag)
|
|
}
|
|
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(tagOperands))
|
|
}
|
|
|
|
if len(searchFilters.Categories) > 0 {
|
|
categoryOperands := make([]*filters.WhereBuilder, len(searchFilters.Categories))
|
|
for i, category := range searchFilters.Categories {
|
|
categoryOperands[i] = filters.Where().
|
|
WithPath([]string{"categories", "Category", "name"}).
|
|
WithOperator(filters.Equal).
|
|
WithValueText(category)
|
|
}
|
|
operands = append(operands, filters.Where().WithOperator(filters.Or).WithOperands(categoryOperands))
|
|
}
|
|
}
|
|
|
|
if len(operands) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return filters.Where().WithOperator(filters.And).WithOperands(operands)
|
|
}
|
|
|
|
func mapToWork(data map[string]interface{}) (domain.Work, error) {
|
|
work := domain.Work{}
|
|
|
|
additional, ok := data["_additional"].(map[string]interface{})
|
|
if !ok {
|
|
return work, fmt.Errorf("missing _additional field")
|
|
}
|
|
|
|
idStr, ok := additional["id"].(string)
|
|
if !ok {
|
|
return work, fmt.Errorf("missing or invalid id")
|
|
}
|
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
|
if err != nil {
|
|
return work, fmt.Errorf("failed to parse id: %w", err)
|
|
}
|
|
work.ID = uint(id)
|
|
|
|
if title, ok := data["title"].(string); ok {
|
|
work.Title = title
|
|
}
|
|
if description, ok := data["description"].(string); ok {
|
|
work.Description = description
|
|
}
|
|
if language, ok := data["language"].(string); ok {
|
|
work.Language = language
|
|
}
|
|
if status, ok := data["status"].(string); ok {
|
|
work.Status = domain.WorkStatus(status)
|
|
}
|
|
if publishedAtStr, ok := data["publishedAt"].(string); ok {
|
|
publishedAt, err := time.Parse(time.RFC3339, publishedAtStr)
|
|
if err == nil {
|
|
work.PublishedAt = &publishedAt
|
|
}
|
|
}
|
|
if score, ok := additional["score"].(float64); ok {
|
|
work.Score = score
|
|
}
|
|
|
|
return work, nil
|
|
}
|
|
|
|
func mapToAuthor(data map[string]interface{}) (domain.Author, error) {
|
|
author := domain.Author{}
|
|
|
|
additional, ok := data["_additional"].(map[string]interface{})
|
|
if !ok {
|
|
return author, fmt.Errorf("missing _additional field")
|
|
}
|
|
|
|
idStr, ok := additional["id"].(string)
|
|
if !ok {
|
|
return author, fmt.Errorf("missing or invalid id")
|
|
}
|
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
|
if err != nil {
|
|
return author, fmt.Errorf("failed to parse id: %w", err)
|
|
}
|
|
author.ID = uint(id)
|
|
|
|
if name, ok := data["name"].(string); ok {
|
|
author.Name = name
|
|
}
|
|
if score, ok := additional["score"].(float64); ok {
|
|
author.Score = score
|
|
}
|
|
|
|
return author, nil
|
|
}
|
|
|
|
func mapToTranslation(data map[string]interface{}) (domain.Translation, error) {
|
|
translation := domain.Translation{}
|
|
|
|
additional, ok := data["_additional"].(map[string]interface{})
|
|
if !ok {
|
|
return translation, fmt.Errorf("missing _additional field")
|
|
}
|
|
|
|
idStr, ok := additional["id"].(string)
|
|
if !ok {
|
|
return translation, fmt.Errorf("missing or invalid id")
|
|
}
|
|
id, err := strconv.ParseUint(idStr, 10, 64)
|
|
if err != nil {
|
|
return translation, fmt.Errorf("failed to parse id: %w", err)
|
|
}
|
|
translation.ID = uint(id)
|
|
|
|
if title, ok := data["title"].(string); ok {
|
|
translation.Title = title
|
|
}
|
|
if language, ok := data["language"].(string); ok {
|
|
translation.Language = language
|
|
}
|
|
if score, ok := additional["score"].(float64); ok {
|
|
translation.Score = score
|
|
}
|
|
|
|
return translation, nil
|
|
}
|
|
|
|
|
|
func (w *weaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
|
properties := map[string]interface{}{
|
|
"language": work.Language,
|
|
"title": work.Title,
|
|
"description": work.Description,
|
|
"status": work.Status,
|
|
"createdAt": work.CreatedAt.Format(time.RFC3339),
|
|
"updatedAt": work.UpdatedAt.Format(time.RFC3339),
|
|
}
|
|
if content != "" {
|
|
properties["content"] = content
|
|
}
|
|
|
|
_, err := w.client.Data().Creator().
|
|
WithClassName("Work").
|
|
WithID(fmt.Sprintf("%d", work.ID)).
|
|
WithProperties(properties).
|
|
Do(ctx)
|
|
|
|
return err
|
|
}
|