mirror of
https://github.com/SamyRai/tercul-backend.git
synced 2025-12-27 05:11:34 +00:00
feat: Add SearchAlpha support for hybrid search tuning
Add configurable SearchAlpha parameter for hybrid search tuning - Added SearchAlpha config parameter (default: 0.7) for tuning BM25 vs vector search balance - Updated NewWeaviateWrapper to accept host and searchAlpha parameters - Enhanced hybrid search with configurable alpha parameter via WithAlpha() - Fixed all type mismatches in mocks and tests to use domainsearch.SearchResults - Updated GraphQL resolver to use new SearchResults structure with SearchResultItem - All tests and vet checks passing Closes #30
This commit is contained in:
parent
d7390053b9
commit
d0852353b7
@ -117,7 +117,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create search client
|
// Create search client
|
||||||
searchClient := search.NewWeaviateWrapper(weaviateClient)
|
searchClient := search.NewWeaviateWrapper(weaviateClient, cfg.WeaviateHost, cfg.SearchAlpha)
|
||||||
|
|
||||||
// Create repositories
|
// Create repositories
|
||||||
repos := dbsql.NewRepositories(database, cfg)
|
repos := dbsql.NewRepositories(database, cfg)
|
||||||
|
|||||||
@ -21,8 +21,10 @@ import (
|
|||||||
"tercul/internal/app/translation"
|
"tercul/internal/app/translation"
|
||||||
"tercul/internal/app/user"
|
"tercul/internal/app/user"
|
||||||
"tercul/internal/domain"
|
"tercul/internal/domain"
|
||||||
|
domainsearch "tercul/internal/domain/search"
|
||||||
platform_auth "tercul/internal/platform/auth"
|
platform_auth "tercul/internal/platform/auth"
|
||||||
"tercul/internal/platform/log"
|
"tercul/internal/platform/log"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Register is the resolver for the register field.
|
// Register is the resolver for the register field.
|
||||||
@ -1968,8 +1970,97 @@ func (r *queryResolver) Comments(ctx context.Context, workID *string, translatio
|
|||||||
|
|
||||||
// Search is the resolver for the search field.
|
// Search is the resolver for the search field.
|
||||||
func (r *queryResolver) Search(ctx context.Context, query string, limit *int32, offset *int32, filters *model.SearchFilters) (*model.SearchResults, error) {
|
func (r *queryResolver) Search(ctx context.Context, query string, limit *int32, offset *int32, filters *model.SearchFilters) (*model.SearchResults, error) {
|
||||||
// Commenting out the body of this function to allow gqlgen to regenerate.
|
page := 1
|
||||||
return nil, nil
|
pageSize := 20
|
||||||
|
if limit != nil {
|
||||||
|
pageSize = int(*limit)
|
||||||
|
}
|
||||||
|
if offset != nil {
|
||||||
|
page = int(*offset)/pageSize + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var searchFilters domain.SearchFilters
|
||||||
|
if filters != nil {
|
||||||
|
searchFilters.Languages = filters.Languages
|
||||||
|
searchFilters.Categories = filters.Categories
|
||||||
|
searchFilters.Tags = filters.Tags
|
||||||
|
searchFilters.Authors = filters.Authors
|
||||||
|
|
||||||
|
if filters.DateFrom != nil {
|
||||||
|
t, err := time.Parse(time.RFC3339, *filters.DateFrom)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid DateFrom format: %w", err)
|
||||||
|
}
|
||||||
|
searchFilters.DateFrom = &t
|
||||||
|
}
|
||||||
|
if filters.DateTo != nil {
|
||||||
|
t, err := time.Parse(time.RFC3339, *filters.DateTo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid DateTo format: %w", err)
|
||||||
|
}
|
||||||
|
searchFilters.DateTo = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := domainsearch.SearchParams{
|
||||||
|
Query: query,
|
||||||
|
Filters: domainsearch.SearchFilters{
|
||||||
|
Languages: searchFilters.Languages,
|
||||||
|
Tags: searchFilters.Tags,
|
||||||
|
Categories: searchFilters.Categories,
|
||||||
|
Authors: searchFilters.Authors,
|
||||||
|
DateFrom: searchFilters.DateFrom,
|
||||||
|
DateTo: searchFilters.DateTo,
|
||||||
|
},
|
||||||
|
Limit: pageSize,
|
||||||
|
Offset: (page - 1) * pageSize,
|
||||||
|
}
|
||||||
|
results, err := r.App.Search.Search(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var works []*model.Work
|
||||||
|
var translations []*model.Translation
|
||||||
|
var authors []*model.Author
|
||||||
|
|
||||||
|
for _, item := range results.Results {
|
||||||
|
switch item.Type {
|
||||||
|
case "Work":
|
||||||
|
if work, ok := item.Entity.(domain.Work); ok {
|
||||||
|
works = append(works, &model.Work{
|
||||||
|
ID: fmt.Sprintf("%d", work.ID),
|
||||||
|
Name: work.Title,
|
||||||
|
Language: work.Language,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "Translation":
|
||||||
|
if translation, ok := item.Entity.(domain.Translation); ok {
|
||||||
|
translations = append(translations, &model.Translation{
|
||||||
|
ID: fmt.Sprintf("%d", translation.ID),
|
||||||
|
Name: translation.Title,
|
||||||
|
Language: translation.Language,
|
||||||
|
Content: &translation.Content,
|
||||||
|
WorkID: fmt.Sprintf("%d", translation.TranslatableID),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "Author":
|
||||||
|
if author, ok := item.Entity.(domain.Author); ok {
|
||||||
|
authors = append(authors, &model.Author{
|
||||||
|
ID: fmt.Sprintf("%d", author.ID),
|
||||||
|
Name: author.Name,
|
||||||
|
Language: author.Language,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.SearchResults{
|
||||||
|
Works: works,
|
||||||
|
Translations: translations,
|
||||||
|
Authors: authors,
|
||||||
|
Total: int32(results.TotalResults),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrendingWorks is the resolver for the trendingWorks field.
|
// TrendingWorks is the resolver for the trendingWorks field.
|
||||||
|
|||||||
@ -150,8 +150,8 @@ func (m *mockUserRepository) WithTx(ctx context.Context, fn func(tx *gorm.DB) er
|
|||||||
|
|
||||||
type mockSearchClient struct{ mock.Mock }
|
type mockSearchClient struct{ mock.Mock }
|
||||||
|
|
||||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error {
|
||||||
args := m.Called(ctx, work, content)
|
args := m.Called(ctx, work, pipeline)
|
||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
func (m *mockSearchClient) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
func (m *mockSearchClient) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||||
|
|||||||
@ -29,8 +29,6 @@ func NewService(searchClient domainsearch.SearchClient, localization *localizati
|
|||||||
|
|
||||||
// Search performs a search across all searchable entities.
|
// Search performs a search across all searchable entities.
|
||||||
func (s *service) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
func (s *service) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||||
// For now, this is a mock implementation that returns empty results.
|
|
||||||
// TODO: Implement the actual search logic.
|
|
||||||
return s.searchClient.Search(ctx, params)
|
return s.searchClient.Search(ctx, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -55,6 +55,45 @@ func (m *mockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work,
|
|||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSearchService_Search(t *testing.T) {
|
||||||
|
localizationRepo := new(mockLocalizationRepository)
|
||||||
|
localizationService := localization.NewService(localizationRepo)
|
||||||
|
weaviateWrapper := new(mockWeaviateWrapper)
|
||||||
|
service := NewService(weaviateWrapper, localizationService)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
testQuery := "test query"
|
||||||
|
testFilters := domainsearch.SearchFilters{
|
||||||
|
Languages: []string{"en"},
|
||||||
|
Authors: []string{"1"},
|
||||||
|
Tags: []string{"test-tag"},
|
||||||
|
Categories: []string{"test-category"},
|
||||||
|
}
|
||||||
|
expectedResults := &domainsearch.SearchResults{
|
||||||
|
Results: []domainsearch.SearchResultItem{
|
||||||
|
{Type: "Work", Entity: domain.Work{Title: "Test Work"}, Score: 0.9},
|
||||||
|
{Type: "Author", Entity: domain.Author{Name: "Test Author"}, Score: 0.8},
|
||||||
|
},
|
||||||
|
TotalResults: 2,
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
params := domainsearch.SearchParams{
|
||||||
|
Query: testQuery,
|
||||||
|
Filters: testFilters,
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
weaviateWrapper.On("Search", ctx, params).Return(expectedResults, nil)
|
||||||
|
|
||||||
|
results, err := service.Search(ctx, params)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expectedResults, results)
|
||||||
|
weaviateWrapper.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestIndexService_IndexWork(t *testing.T) {
|
func TestIndexService_IndexWork(t *testing.T) {
|
||||||
localizationRepo := new(mockLocalizationRepository)
|
localizationRepo := new(mockLocalizationRepository)
|
||||||
localizationService := localization.NewService(localizationRepo)
|
localizationService := localization.NewService(localizationRepo)
|
||||||
|
|||||||
@ -50,6 +50,7 @@ type BaseModel struct {
|
|||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
Score float64 `gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TranslatableModel extends BaseModel with language support
|
// TranslatableModel extends BaseModel with language support
|
||||||
|
|||||||
@ -28,8 +28,9 @@ type Config struct {
|
|||||||
NLPMemoryCacheCap int `mapstructure:"NLP_MEMORY_CACHE_CAP"`
|
NLPMemoryCacheCap int `mapstructure:"NLP_MEMORY_CACHE_CAP"`
|
||||||
NLPRedisCacheTTLSeconds int `mapstructure:"NLP_REDIS_CACHE_TTL_SECONDS"`
|
NLPRedisCacheTTLSeconds int `mapstructure:"NLP_REDIS_CACHE_TTL_SECONDS"`
|
||||||
NLPUseLingua bool `mapstructure:"NLP_USE_LINGUA"`
|
NLPUseLingua bool `mapstructure:"NLP_USE_LINGUA"`
|
||||||
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
NLPUseTFIDF bool `mapstructure:"NLP_USE_TFIDF"`
|
||||||
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
BleveIndexPath string `mapstructure:"BLEVE_INDEX_PATH"`
|
||||||
|
SearchAlpha float64 `mapstructure:"SEARCH_ALPHA"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global config instance
|
// Global config instance
|
||||||
@ -62,6 +63,7 @@ func LoadConfig() (*Config, error) {
|
|||||||
v.SetDefault("NLP_USE_LINGUA", true)
|
v.SetDefault("NLP_USE_LINGUA", true)
|
||||||
v.SetDefault("NLP_USE_TFIDF", true)
|
v.SetDefault("NLP_USE_TFIDF", true)
|
||||||
v.SetDefault("BLEVE_INDEX_PATH", "./bleve_index")
|
v.SetDefault("BLEVE_INDEX_PATH", "./bleve_index")
|
||||||
|
v.SetDefault("SEARCH_ALPHA", 0.7)
|
||||||
|
|
||||||
v.AutomaticEnv()
|
v.AutomaticEnv()
|
||||||
|
|
||||||
|
|||||||
@ -16,12 +16,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type weaviateWrapper struct {
|
type weaviateWrapper struct {
|
||||||
client *weaviate.Client
|
client *weaviate.Client
|
||||||
|
host string
|
||||||
|
searchAlpha float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWeaviateWrapper creates a new WeaviateWrapper that implements the SearchClient interface.
|
// NewWeaviateWrapper creates a new WeaviateWrapper that implements the SearchClient interface.
|
||||||
func NewWeaviateWrapper(client *weaviate.Client) domainsearch.SearchClient {
|
func NewWeaviateWrapper(client *weaviate.Client, host string, searchAlpha float64) domainsearch.SearchClient {
|
||||||
return &weaviateWrapper{client: client}
|
return &weaviateWrapper{client: client, host: host, searchAlpha: searchAlpha}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hybridAlpha returns the alpha value for hybrid search, clamped between 0 and 1.
|
||||||
|
func (w *weaviateWrapper) hybridAlpha() float32 {
|
||||||
|
if w.searchAlpha < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if w.searchAlpha > 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return float32(w.searchAlpha)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search performs a multi-class search against the Weaviate instance.
|
// Search performs a multi-class search against the Weaviate instance.
|
||||||
@ -144,7 +157,9 @@ func (w *weaviateWrapper) addSearchArguments(searcher *graphql.GetBuilder, param
|
|||||||
nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(params.Concepts)
|
nearText := w.client.GraphQL().NearTextArgBuilder().WithConcepts(params.Concepts)
|
||||||
searcher.WithNearText(nearText)
|
searcher.WithNearText(nearText)
|
||||||
default:
|
default:
|
||||||
hybrid := w.client.GraphQL().HybridArgumentBuilder().WithQuery(params.Query)
|
hybrid := w.client.GraphQL().HybridArgumentBuilder().
|
||||||
|
WithQuery(params.Query).
|
||||||
|
WithAlpha(w.hybridAlpha())
|
||||||
searcher.WithHybrid(hybrid)
|
searcher.WithHybrid(hybrid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -59,7 +59,7 @@ func (s *WeaviateWrapperIntegrationTestSuite) SetupSuite() {
|
|||||||
require.NoError(s.T(), err)
|
require.NoError(s.T(), err)
|
||||||
s.client = client
|
s.client = client
|
||||||
|
|
||||||
s.wrapper = search.NewWeaviateWrapper(client)
|
s.wrapper = search.NewWeaviateWrapper(client, fmt.Sprintf("%s:%s", host, port.Port()), 0.7)
|
||||||
|
|
||||||
s.createTestSchema(ctx)
|
s.createTestSchema(ctx)
|
||||||
s.seedTestData(ctx)
|
s.seedTestData(ctx)
|
||||||
|
|||||||
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
26
internal/platform/search/weaviate_wrapper_mock.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package search
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"tercul/internal/domain"
|
||||||
|
domainsearch "tercul/internal/domain/search"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockWeaviateWrapper struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWeaviateWrapper) Search(ctx context.Context, params domainsearch.SearchParams) (*domainsearch.SearchResults, error) {
|
||||||
|
args := m.Called(ctx, params)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*domainsearch.SearchResults), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWeaviateWrapper) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
||||||
|
args := m.Called(ctx, work, content)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
51
internal/platform/search/weaviate_wrapper_test.go
Normal file
51
internal/platform/search/weaviate_wrapper_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package search
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"tercul/internal/domain"
|
||||||
|
domainsearch "tercul/internal/domain/search"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWeaviateWrapper_Search(t *testing.T) {
|
||||||
|
mockWrapper := new(MockWeaviateWrapper)
|
||||||
|
expectedResults := &domainsearch.SearchResults{
|
||||||
|
Results: []domainsearch.SearchResultItem{
|
||||||
|
{
|
||||||
|
Type: "Work",
|
||||||
|
Entity: domain.Work{
|
||||||
|
Title: "Work 1",
|
||||||
|
Description: "alpha beta",
|
||||||
|
TranslatableModel: domain.TranslatableModel{Language: "en"},
|
||||||
|
},
|
||||||
|
Score: 0.95,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
TotalResults: 1,
|
||||||
|
Limit: 1,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
params := domainsearch.SearchParams{
|
||||||
|
Query: "alpha",
|
||||||
|
Mode: domainsearch.SearchModeHybrid,
|
||||||
|
Filters: domainsearch.SearchFilters{
|
||||||
|
Languages: []string{"en"},
|
||||||
|
},
|
||||||
|
Limit: 1,
|
||||||
|
Offset: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockWrapper.On("Search", context.Background(), params).Return(expectedResults, nil)
|
||||||
|
|
||||||
|
results, err := mockWrapper.Search(context.Background(), params)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, results)
|
||||||
|
assert.Equal(t, 1, len(results.Results))
|
||||||
|
assert.Equal(t, "Work", results.Results[0].Type)
|
||||||
|
work := results.Results[0].Entity.(domain.Work)
|
||||||
|
assert.Equal(t, "Work 1", work.Title)
|
||||||
|
|
||||||
|
mockWrapper.AssertExpectations(t)
|
||||||
|
}
|
||||||
@ -41,13 +41,10 @@ import (
|
|||||||
type mockSearchClient struct{}
|
type mockSearchClient struct{}
|
||||||
|
|
||||||
func (m *mockSearchClient) Search(ctx context.Context, params search.SearchParams) (*search.SearchResults, error) {
|
func (m *mockSearchClient) Search(ctx context.Context, params search.SearchParams) (*search.SearchResults, error) {
|
||||||
return &search.SearchResults{
|
return &search.SearchResults{}, nil
|
||||||
Results: []search.SearchResultItem{},
|
|
||||||
TotalResults: 0,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, content string) error {
|
func (m *mockSearchClient) IndexWork(ctx context.Context, work *domain.Work, pipeline string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user