58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
from __future__ import annotations
|
|
|
|
from functools import lru_cache
|
|
import logging
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
from api.config import settings
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LocalEmbeddingService:
    """Embed documents and queries with a locally loaded ``SentenceTransformer``.

    The model name and device are read from ``settings.embedding_model`` and
    ``settings.embedding_device``; every batch is encoded with
    ``settings.index_batch_size`` and returned as plain Python lists.
    """

    # Value assigned to the model's ``max_seq_length`` right after loading.
    _MAX_SEQ_LENGTH = 512

    def __init__(self) -> None:
        """Load the configured model and set its maximum sequence length."""
        logger.info(
            "Loading embedding model: model=%s device=%s",
            settings.embedding_model,
            settings.embedding_device,
        )
        self._model = SentenceTransformer(
            settings.embedding_model,
            device=settings.embedding_device,
        )
        self._model.max_seq_length = self._MAX_SEQ_LENGTH
        logger.info(
            "Embedding model loaded: model=%s max_seq_length=%s",
            settings.embedding_model,
            self._model.max_seq_length,
        )

    def _encode(self, texts: list[str], prompt_name: str) -> list[list[float]]:
        """Encode *texts* with the given model prompt and return nested lists.

        Args:
            texts: Strings to embed.
            prompt_name: Name of the model prompt to apply to each text.

        Returns:
            One embedding (a list of floats) per input text, requested from
            the model with ``normalize_embeddings=True``. An empty input
            yields an empty list without invoking the model.
        """
        # ``len`` rather than truthiness: the callers already rely on
        # ``len(texts)``, so this adds no new requirement on the argument.
        if len(texts) == 0:
            return []
        vectors = self._model.encode(
            texts,
            prompt_name=prompt_name,
            normalize_embeddings=True,
            convert_to_numpy=True,
            batch_size=settings.index_batch_size,
            show_progress_bar=False,
        )
        return vectors.tolist()

    def encode_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* using the ``search_document`` prompt.

        Args:
            texts: Document strings to embed.

        Returns:
            One embedding per document, in input order as returned by the
            model.
        """
        logger.info("Encoding document batch: size=%s", len(texts))
        return self._encode(texts, "search_document")

    def encode_queries(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* using the ``search_query`` prompt.

        Args:
            texts: Query strings to embed.

        Returns:
            One embedding per query, in input order as returned by the model.
        """
        logger.info("Encoding query batch: size=%s", len(texts))
        return self._encode(texts, "search_query")
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_embedding_service() -> LocalEmbeddingService:
    """Return the shared embedding service, constructing it on the first call.

    The single-slot ``lru_cache`` memoizes the instance, so later calls
    reuse it instead of constructing another one.
    """
    service = LocalEmbeddingService()
    return service
|