turash/bugulma/backend/internal/service/ollama_client.go

256 lines
6.6 KiB
Go

package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// OllamaClient talks to an Ollama server over HTTP. It keeps the endpoint,
// the model name, optional basic-auth credentials and the *http.Client used
// to issue requests.
type OllamaClient struct {
	baseURL, model     string        // server root URL and model name sent with each request
	httpClient         *http.Client  // client used to perform the API calls
	timeout            time.Duration // configured timeout, also set on httpClient.Timeout
	username, password string        // basic-auth credentials; applied only when both are non-empty
}
// OllamaClientConfig bundles the settings accepted by
// NewOllamaClientWithConfig. An empty BaseURL or Model and a zero Timeout are
// replaced with the values from DefaultOllamaClientConfig.
type OllamaClientConfig struct {
	BaseURL, Model     string        // server root URL and model name
	Timeout            time.Duration // timeout applied to the HTTP client
	MaxRetries         int           // NOTE(review): not read by the constructors visible here - confirm where it is consumed
	Username, Password string        // optional basic-auth credentials
}
// DefaultOllamaClientConfig returns the baseline settings: a server at
// http://localhost:11434, the qwen2.5:7b model, a two-minute timeout, no
// retries and no credentials.
func DefaultOllamaClientConfig() OllamaClientConfig {
	var cfg OllamaClientConfig
	cfg.BaseURL = "http://localhost:11434"
	cfg.Model = "qwen2.5:7b"
	cfg.Timeout = 2 * time.Minute
	cfg.MaxRetries = 0 // no retries by default
	return cfg
}
// OllamaGenerateRequest is the JSON body that Translate posts to the Ollama
// /api/generate endpoint.
type OllamaGenerateRequest struct {
// Model is the model name; Translate fills it from the client's model.
Model string `json:"model"`
// Prompt is the full instruction text sent to the model.
Prompt string `json:"prompt"`
// Stream is the "stream" flag; Translate always sends false and decodes
// exactly one JSON object from the reply.
Stream bool `json:"stream"`
}
// OllamaGenerateResponse is the JSON body Translate decodes from a 200 OK
// reply of the /api/generate endpoint.
type OllamaGenerateResponse struct {
// Response is the generated text; Translate trims it and returns it.
Response string `json:"response"`
// Done is decoded from the "done" key; Translate does not inspect it.
Done bool `json:"done"`
// Error, when non-empty, makes Translate fail with the contained message.
Error string `json:"error,omitempty"`
}
// NewOllamaClient builds a client from the default configuration. A non-empty
// baseURL or model overrides the corresponding default; empty arguments leave
// the default in place.
func NewOllamaClient(baseURL, model string) *OllamaClient {
	cfg := DefaultOllamaClientConfig()
	if len(baseURL) > 0 {
		cfg.BaseURL = baseURL
	}
	if len(model) > 0 {
		cfg.Model = model
	}
	return NewOllamaClientWithConfig(cfg)
}
// NewOllamaClientWithAuth builds a client from the default configuration and
// attaches the given basic-auth credentials. A non-empty baseURL or model
// overrides the corresponding default.
func NewOllamaClientWithAuth(baseURL, model, username, password string) *OllamaClient {
	cfg := DefaultOllamaClientConfig()
	cfg.Username, cfg.Password = username, password
	if len(baseURL) > 0 {
		cfg.BaseURL = baseURL
	}
	if len(model) > 0 {
		cfg.Model = model
	}
	return NewOllamaClientWithConfig(cfg)
}
// NewOllamaClientWithConfig creates a new Ollama client from config.
//
// An empty BaseURL or Model and a zero or negative Timeout are replaced with
// the corresponding value from DefaultOllamaClientConfig. Username and
// Password are copied as given. config.MaxRetries is not stored on the
// client.
func NewOllamaClientWithConfig(config OllamaClientConfig) *OllamaClient {
	defaults := DefaultOllamaClientConfig()

	if config.BaseURL == "" {
		config.BaseURL = defaults.BaseURL
	}
	if config.Model == "" {
		config.Model = defaults.Model
	}
	// A negative duration is not a meaningful timeout for http.Client, so it
	// is treated like "unset" here, exactly as SetTimeout does.
	if config.Timeout <= 0 {
		config.Timeout = defaults.Timeout
	}

	return &OllamaClient{
		baseURL:  config.BaseURL,
		model:    config.Model,
		timeout:  config.Timeout,
		username: config.Username,
		password: config.Password,
		httpClient: &http.Client{
			Timeout: config.Timeout,
		},
	}
}
// Bounds applied to the text handed to Translate.
const (
	// MinTextLength is the minimum text length (1 character).
	MinTextLength = 1
	// MaxTextLength is the maximum text length allowed for translation
	// (50 KiB).
	MaxTextLength = 50 << 10
)
// Translate translates Russian text into the language selected by
// targetLocale ("en" for English, "tt" for Tatar) by posting a single
// non-streaming prompt to the Ollama /api/generate endpoint.
//
// ctx is attached to the HTTP request, so cancelling it or letting its
// deadline expire aborts the call.
//
// On success the model output is returned with surrounding whitespace
// removed. An error is returned when the input fails validation, the request
// cannot be built or sent, the server answers with a status other than
// 200 OK, the reply cannot be decoded, the reply carries an error message,
// or the reply contains no text once whitespace is trimmed.
func (c *OllamaClient) Translate(ctx context.Context, text, targetLocale string) (string, error) {
	// Upper bound on how much of a non-200 body is echoed into the error.
	const maxErrorBodyBytes = 4 << 10

	if err := c.validateInput(text, targetLocale); err != nil {
		return "", err
	}

	// Map the locale code to the language name used in the prompt.
	var targetLanguage string
	switch targetLocale {
	case "en":
		targetLanguage = "English"
	case "tt":
		targetLanguage = "Tatar"
	default:
		return "", fmt.Errorf("unsupported target locale: %s. Supported: en, tt", targetLocale)
	}

	prompt := fmt.Sprintf(`Translate the following Russian text to %s. Return only the translation, without any explanations or additional text.
Russian text: %s
Translation:`, targetLanguage, text)

	jsonData, err := json.Marshal(OllamaGenerateRequest{
		Model:  c.model,
		Prompt: prompt,
		Stream: false,
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Strip trailing slashes so a base URL such as "http://host:11434/" does
	// not produce a "//api/generate" path.
	apiURL := strings.TrimRight(c.baseURL, "/") + "/api/generate"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(jsonData))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Basic auth is sent only when both credentials are configured.
	if c.username != "" && c.password != "" {
		req.SetBasicAuth(c.username, c.password)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		// Surface the context error directly when the caller cancelled the
		// request or its deadline passed.
		if ctx.Err() != nil {
			return "", fmt.Errorf("request cancelled or timed out: %w", ctx.Err())
		}
		return "", fmt.Errorf("failed to make request to Ollama: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the status code is the primary signal, so a failed or
		// truncated body read is deliberately ignored. The read is capped so
		// an oversized error page cannot inflate the error message.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return "", fmt.Errorf("Ollama API returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var ollamaResp OllamaGenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if ollamaResp.Error != "" {
		return "", fmt.Errorf("Ollama API error: %s", ollamaResp.Error)
	}

	// Trim before the emptiness check so a whitespace-only reply is reported
	// as an error instead of being returned as an empty translation.
	translation := strings.TrimSpace(ollamaResp.Response)
	if translation == "" {
		return "", fmt.Errorf("empty response from Ollama")
	}
	return translation, nil
}
// SetModel replaces the model name used for subsequent translations. An empty
// name is rejected with an error and leaves the current model untouched.
func (c *OllamaClient) SetModel(model string) error {
	if len(model) > 0 {
		c.model = model
		return nil
	}
	return fmt.Errorf("model cannot be empty")
}
// SetBaseURL changes the base URL for the Ollama API.
//
// The value must be non-empty, must parse as a URL, must use the http or
// https scheme and must name a host. On any validation failure an error is
// returned and the current base URL is left unchanged.
func (c *OllamaClient) SetBaseURL(baseURL string) error {
	if baseURL == "" {
		return fmt.Errorf("base URL cannot be empty")
	}

	parsed, err := url.Parse(baseURL)
	if err != nil {
		return fmt.Errorf("invalid base URL format: %w", err)
	}

	// url.Parse alone accepts almost any string (a bare word such as "ollama"
	// parses as a relative path), so also require the parts an HTTP request
	// actually needs. url.Parse lowercases the scheme.
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return fmt.Errorf("invalid base URL format: scheme must be http or https, got %q", parsed.Scheme)
	}
	if parsed.Host == "" {
		return fmt.Errorf("invalid base URL format: missing host in %q", baseURL)
	}

	c.baseURL = baseURL
	return nil
}
// SetTimeout updates the timeout stored on the client and the one set on its
// HTTP client. A zero or negative value selects the timeout from
// DefaultOllamaClientConfig instead.
func (c *OllamaClient) SetTimeout(timeout time.Duration) {
	effective := timeout
	if !(effective > 0) {
		effective = DefaultOllamaClientConfig().Timeout
	}
	c.timeout = effective
	c.httpClient.Timeout = effective
}
// validateInput checks the arguments of Translate before any request is
// built.
//
// It returns an error when text is empty, when text is longer than
// MaxTextLength bytes, when text has fewer than MinTextLength characters
// (runes), or when targetLocale is neither "en" nor "tt". It returns nil
// otherwise.
func (c *OllamaClient) validateInput(text, targetLocale string) error {
	if text == "" {
		return fmt.Errorf("text cannot be empty")
	}

	// Apply the cheap byte-length cap first so oversized input is rejected
	// before any per-character work is done. A text cannot be both under the
	// rune minimum and over the byte maximum with the current bounds, so the
	// reported error is the same as when the checks run in the other order.
	if len(text) > MaxTextLength {
		return fmt.Errorf("text is too long (maximum %d bytes, got %d)", MaxTextLength, len(text))
	}

	// Count runes (not bytes) by ranging over the string, which avoids
	// allocating a []rune copy; stop as soon as the minimum is reached.
	runeCount := 0
	for range text {
		runeCount++
		if runeCount >= MinTextLength {
			break
		}
	}
	if runeCount < MinTextLength {
		return fmt.Errorf("text is too short (minimum %d characters)", MinTextLength)
	}

	switch targetLocale {
	case "en", "tt":
		return nil
	default:
		return fmt.Errorf("unsupported target locale: %s. Supported: en, tt", targetLocale)
	}
}