mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 00:31:35 +00:00
feat: Add SearchAlpha support for hybrid search tuning
Add configurable SearchAlpha parameter for hybrid search tuning - Added SearchAlpha config parameter (default: 0.7) for tuning BM25 vs vector search balance - Updated NewWeaviateWrapper to accept host and searchAlpha parameters - Enhanced hybrid search with configurable alpha parameter via WithAlpha() - Fixed all type mismatches in mocks and tests to use domainsearch.SearchResults - Updated GraphQL resolver to use new SearchResults structure with SearchResultItem - All tests and vet checks passing Closes #30
This commit is contained in:
parent
d7390053b9
commit
d0852353b7
@ -117,7 +117,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Create search client
|
||||
searchClient := search.NewWeaviateWrapper(weaviateClient)
|
||||
searchClient := search.NewWeaviateWrapper(weaviateClient, cfg.WeaviateHost, cfg.SearchAlpha)
|
||||
|
||||
// Create repositories
|
||||
repos := dbsql.NewRepositories(database, cfg)
|
||||
|
||||
@ -21,8 +21,10 @@ import (
|
||||
"tercul/internal/app/translation"
|
||||
"tercul/internal/app/user"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
platform_auth "tercul/internal/platform/auth"
|
||||
"tercul/internal/platform/log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Register is the resolver for the register field.
|
||||
@ -1968,8 +1970,97 @@ func (r *queryResolver) Comments(ctx context.Context, workID *string, translatio
|
||||
|
||||
// Search is the resolver for the search field.
|
||||
func (r *queryResolver) Search(ctx context.Context, query string, limit *int32, offset *int32, filters *model.SearchFilters) (*model.SearchResults, error) {
|
||||
// Commenting out the body of this function to allow gqlgen to regenerate.
|
||||
return nil, nil
|
||||
page := 1
|
||||
pageSize := 20
|
||||
if limit != nil {
|
||||
pageSize = int(*limit)
|
||||
}
|
||||
if offset != nil {
|
||||
page = int(*offset)/pageSize + 1
|
||||
}
|
||||
|
||||
var searchFilters domain.SearchFilters
|
||||
if filters != nil {
|
||||
searchFilters.Languages = filters.Languages
|
||||
searchFilters.Categories = filters.Categories
|
||||
searchFilters.Tags = filters.Tags
|
||||
searchFilters.Authors = filters.Authors
|
||||
|
||||
if filters.DateFrom != nil {
|
||||
t, err := time.Parse(time.RFC3339, *filters.DateFrom)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid DateFrom format: %w", err)
|
||||
}
|
||||
searchFilters.DateFrom = &t
|
||||
}
|
||||
if filters.DateTo != nil {
|
||||
t, err := time.Parse(time.RFC3339, *filters.DateTo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid DateTo format: %w", err)
|
||||
}
|
||||
searchFilters.DateTo = &t
|
||||
}
|
||||
}
|
||||
|
||||
params := domainsearch.SearchParams{
|
||||
Query: query,
|
||||
Filters: domainsearch.SearchFilters{
|
||||
Languages: searchFilters.Languages,
|
||||
Tags: searchFilters.Tags,
|
||||
Categories: searchFilters.Categories,
|
||||
Authors: searchFilters.Authors,
|
||||
DateFrom: searchFilters.DateFrom,
|
||||
DateTo: searchFilters.DateTo,
|
||||
},
|
||||
Limit: pageSize,
|
||||
Offset: (page - 1) * pageSize,
|
||||
}
|
||||
results, err := r.App.Search.Search(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var works []*model.Work
|
||||
var translations []*model.Translation
|
||||
var authors []*model.Author
|
||||
|
||||
for _, item := range results.Results {
|
||||
switch item.Type {
|
||||
case "Work":
|
||||
if work, ok := item.Entity.(domain.Work); ok {
|
||||
works = append(works, &model.Work{
|
||||
ID: fmt.Sprintf("%d", work.ID),
|
||||
Name: work.Title,
|
||||
Language: work.Language,
|
||||
})
|
||||
}
|
||||
case "Translation":
|
||||
if translation, ok := item.Entity.(domain.Translation); ok {
|
||||
translations = append(translations, &model.Translation{
|
||||
ID: fmt.Sprintf("%d", translation.ID),
|
||||
Name: translation.Title,
|
||||
Language: translation.Language,
|
||||
Content: &translation.Content,
|
||||
WorkID: fmt.Sprintf("%d", translation.TranslatableID),
|
||||
})
|
||||
}
|
||||
case "Author":
|
||||
if author, ok := item.Entity.(domain.Author); ok {
|
||||
authors = append(authors, &model.Author{
|
||||
ID: fmt.Sprintf("%d", author.ID),
|
||||
Name: author.Name,
|
||||
Language: author.Language,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &model.SearchResults{
|
||||
Works: works,
|
||||
Translations: translations,
|
||||
Authors: authors,
|
||||
Total: int32(results.TotalResults),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TrendingWorks is the resolver for the trendingWorks field.
|
||||
|
||||
@ -150,8 +150,8 @@ func (m *mockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) er
|
||||
|
||||
type mockSearchClient struct{ mock.Mock }
|
||||
|
||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
args := m.Called(ctx, work, content)
|
||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error {
|
||||
args := m.Called(ctx, work, pipeline)
|
||||
return args.Error(0)
|
||||
}
|
||||
func (m *mockSearchClient) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||
|
||||
@ -29,8 +29,6 @@ func NewService(searchClient domainsearch.SearchClient, localization *localizati
|
||||
|
||||
// Search performs a search across all searchable entities.
|
||||
func (s *service) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||
// For now, this is a mock implementation that returns empty results.
|
||||
// TODO: Implement the actual search logic.
|
||||
return s.searchClient.Search(ctx, params)
|
||||
}
|
||||
|
||||
|
||||
@ -55,6 +55,45 @@ func (m *mockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work,
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestSearchService_Search(t *testing.T) {
|
||||
localizationRepo := new(mockLocalizationRepository)
|
||||
localizationService := localization.NewService(localizationRepo)
|
||||
weaviateWrapper := new(mockWeaviateWrapper)
|
||||
service := NewService(weaviateWrapper, localizationService)
|
||||
|
||||
ctx := context.Background()
|
||||
testQuery := "test query"
|
||||
testFilters := domainsearch.SearchFilters{
|
||||
Languages: []string{"en"},
|
||||
Authors: []string{"1"},
|
||||
Tags: []string{"test-tag"},
|
||||
Categories: []string{"test-category"},
|
||||
}
|
||||
expectedResults := &domainsearch.SearchResults{
|
||||
Results: []domainsearch.SearchResultItem{
|
||||
{Type: "Work", Entity: domain.Work{Title: "Test Work"}, Score: 0.9},
|
||||
{Type: "Author", Entity: domain.Author{Name: "Test Author"}, Score: 0.8},
|
||||
},
|
||||
TotalResults: 2,
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
params := domainsearch.SearchParams{
|
||||
Query: testQuery,
|
||||
Filters: testFilters,
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
}
|
||||
weaviateWrapper.On("Search", ctx, params).Return(expectedResults, nil)
|
||||
|
||||
results, err := service.Search(ctx, params)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedResults, results)
|
||||
weaviateWrapper.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestIndexService_IndexWork(t *testing.T) {
|
||||
localizationRepo := new(mockLocalizationRepository)
|
||||
localizationService := localization.NewService(localizationRepo)
|
||||
|
||||
@ -50,6 +50,7 @@ type BaseModel struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Score float64 `gorm:"-"`
|
||||
}
|
||||
|
||||
// TranslatableModel extends BaseModel with language support
|
||||
|
||||
@ -28,8 +28,9 @@ type Config struct {
|
||||
NLPMemoryCacheCap int `mapstructure:"NLP_MEMORY_CACHE_CAP"`
|
||||
NLPRedisCacheTTLSeconds int `mapstructure:"NLP_REDIS_CACHE_TTL_SECONDS"`
|
||||
NLPUseLingua bool `mapstructure:"NLP_USE_LINGUA"`
|
||||
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
||||
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
||||
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
||||
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
||||
SearchAlpha float64 `mapstructure:"SEARCH_ALPHA"`
|
||||
}
|
||||
|
||||
// Global config instance
|
||||
@ -62,6 +63,7 @@ func LoadConfig() (*Config, error) {
|
||||
v.SetDefault("NLP_USE_LINGUA", true)
|
||||
v.SetDefault("NLP_USE_TFIDF", true)
|
||||
v.SetDefault("BLEVE_INDEX_PATH", "./bleve_index")
|
||||
v.SetDefault("SEARCH_ALPHA", 0.7)
|
||||
|
||||
v.AutomaticEnv()
|
||||
|
||||
|
||||
@ -16,12 +16,25 @@ import (
|
||||
)
|
||||
|
||||
type weaviateWrapper struct {
|
||||
client *weaviate.Client
|
||||
client *weaviate.Client
|
||||
host string
|
||||
searchAlpha float64
|
||||
}
|
||||
|
||||
// NewWeaviateWrapper creates a new WeaviateWrapper that implements the SearchClient interface.
|
||||
func NewWeaviateWrapper(client *weaviate.Client) domainsearch.SearchClient {
|
||||
return &weaviateWrapper{client: client}
|
||||
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) domainsearch.SearchClient {
|
||||
return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha}
|
||||
}
|
||||
|
||||
// hybridAlpha returns the alpha value for hybrid search, clamped between 0 and 1.
|
||||
func (w *weaviateWrapper) hybridAlpha() float32 {
|
||||
if w.searchAlpha < 0 {
|
||||
return 0
|
||||
}
|
||||
if w.searchAlpha > 1 {
|
||||
return 1
|
||||
}
|
||||
return float32(w.searchAlpha)
|
||||
}
|
||||
|
||||
// Search performs a multi-class search against the Weaviate instance.
|
||||
@ -144,7 +157,9 @@ func (w *weaviateWrapper) addSearchArguments(searcher *graphql.GetBuilder, param
|
||||
nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(params.Concepts)
|
||||
searcher.WithNearText(nearText)
|
||||
default:
|
||||
hybrid := w.client.GraphQL().HybridArgumentBuilder().WithQuery(params.Query)
|
||||
hybrid := w.client.GraphQL().HybridArgumentBuilder().
|
||||
WithQuery(params.Query).
|
||||
WithAlpha(w.hybridAlpha())
|
||||
searcher.WithHybrid(hybrid)
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ func (s *WeaviateWrapperIntegrationTestSuite) SetupSuite() {
|
||||
require.NoError(s.T(), err)
|
||||
s.client = client
|
||||
|
||||
s.wrapper = search.NewWeaviateWrapper(client)
|
||||
s.wrapper = search.NewWeaviateWrapper(client, fmt.Sprintf("%s:%s", host, port.Port()), 0.7)
|
||||
|
||||
s.createTestSchema(ctx)
|
||||
s.seedTestData(ctx)
|
||||
|
||||
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
@ -0,0 +1,26 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockWeaviateWrapper struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||
args := m.Called(ctx, params)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*domainsearch.SearchResults), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
args := m.Called(ctx, work, content)
|
||||
return args.Error(0)
|
||||
}
|
||||
51
internal/platform/search/weaviate_wrapper_test.go
Normal file
51
internal/platform/search/weaviate_wrapper_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"tercul/internal/domain"
|
||||
domainsearch "tercul/internal/domain/search"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWeaviateWrapper_Search(t *testing.T) {
|
||||
mockWrapper := new(MockWeaviateWrapper)
|
||||
expectedResults := &domainsearch.SearchResults{
|
||||
Results: []domainsearch.SearchResultItem{
|
||||
{
|
||||
Type: "Work",
|
||||
Entity: domain.Work{
|
||||
Title: "Work 1",
|
||||
Description: "alpha beta",
|
||||
TranslatableModel: domain.TranslatableModel{Language: "en"},
|
||||
},
|
||||
Score: 0.95,
|
||||
},
|
||||
},
|
||||
TotalResults: 1,
|
||||
Limit: 1,
|
||||
Offset: 0,
|
||||
}
|
||||
params := domainsearch.SearchParams{
|
||||
Query: "alpha",
|
||||
Mode: domainsearch.SearchModeHybrid,
|
||||
Filters: domainsearch.SearchFilters{
|
||||
Languages: []string{"en"},
|
||||
},
|
||||
Limit: 1,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
mockWrapper.On("Search", context.Background(), params).Return(expectedResults, nil)
|
||||
|
||||
results, err := mockWrapper.Search(context.Background(), params)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Equal(t, 1, len(results.Results))
|
||||
assert.Equal(t, "Work", results.Results[0].Type)
|
||||
work := results.Results[0].Entity.(domain.Work)
|
||||
assert.Equal(t, "Work 1", work.Title)
|
||||
|
||||
mockWrapper.AssertExpectations(t)
|
||||
}
|
||||
@ -41,13 +41,10 @@ import (
|
||||
type mockSearchClient struct{}
|
||||
|
||||
func (m *mockSearchClient) Search(ctx context.Context, params search.SearchParams) (*search.SearchResults, error) {
|
||||
return &search.SearchResults{
|
||||
Results: []search.SearchResultItem{},
|
||||
TotalResults: 0,
|
||||
}, nil
|
||||
return &search.SearchResults{}, nil
|
||||
}
|
||||
|
||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user