from __future__ import annotations from functools import lru_cache import logging from sentence_transformers import SentenceTransformer from api.config import settings logger = logging.getLogger(__name__) class LocalEmbeddingService: def __init__(self) -> None: logger.info( "Loading embedding model: model=%s device=%s", settings.embedding_model, settings.embedding_device, ) self._model = SentenceTransformer( settings.embedding_model, device=settings.embedding_device, ) self._model.max_seq_length = 512 logger.info( "Embedding model loaded: model=%s max_seq_length=%s", settings.embedding_model, self._model.max_seq_length, ) def encode_documents(self, texts: list[str]) -> list[list[float]]: logger.info("Encoding document batch: size=%s", len(texts)) return self._model.encode( texts, prompt_name="search_document", normalize_embeddings=True, convert_to_numpy=True, batch_size=settings.index_batch_size, show_progress_bar=False, ).tolist() def encode_queries(self, texts: list[str]) -> list[list[float]]: logger.info("Encoding query batch: size=%s", len(texts)) return self._model.encode( texts, prompt_name="search_query", normalize_embeddings=True, convert_to_numpy=True, batch_size=settings.index_batch_size, show_progress_bar=False, ).tolist() @lru_cache(maxsize=1) def get_embedding_service() -> LocalEmbeddingService: return LocalEmbeddingService()