Files
LawBot/api/services/local_embeddings.py
T
2026-05-25 01:12:43 +03:00

58 lines
1.7 KiB
Python

from __future__ import annotations
from functools import lru_cache
import logging
from sentence_transformers import SentenceTransformer
from api.config import settings
# Module-level logger; records embedding-model loading and the size of each encode batch.
logger = logging.getLogger(__name__)
class LocalEmbeddingService:
    """Embed documents and queries with a locally loaded SentenceTransformer model.

    The model name and device are read from ``settings.embedding_model`` and
    ``settings.embedding_device``. Embeddings are requested with
    ``normalize_embeddings=True`` and returned as plain Python lists of floats.
    """

    def __init__(self, max_seq_length: int = 512) -> None:
        """Load the configured embedding model.

        Args:
            max_seq_length: Value assigned to the model's ``max_seq_length``
                (its input token limit). Defaults to 512, the value this class
                has always used.
        """
        logger.info(
            "Loading embedding model: model=%s device=%s",
            settings.embedding_model,
            settings.embedding_device,
        )
        self._model = SentenceTransformer(
            settings.embedding_model,
            device=settings.embedding_device,
        )
        self._model.max_seq_length = max_seq_length
        logger.info(
            "Embedding model loaded: model=%s max_seq_length=%s",
            settings.embedding_model,
            self._model.max_seq_length,
        )

    def encode_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed texts destined for the index, using the "search_document" prompt.

        Args:
            texts: Texts to embed.

        Returns:
            One embedding (list of floats) per input text; ``[]`` for empty input.
        """
        return self._encode(texts, prompt_name="search_document", label="document")

    def encode_queries(self, texts: list[str]) -> list[list[float]]:
        """Embed search queries, using the "search_query" prompt.

        Args:
            texts: Query strings to embed.

        Returns:
            One embedding (list of floats) per input text; ``[]`` for empty input.
        """
        return self._encode(texts, prompt_name="search_query", label="query")

    def _encode(
        self, texts: list[str], *, prompt_name: str, label: str
    ) -> list[list[float]]:
        """Shared encode path: run the model with the given prompt and listify the result."""
        if not texts:
            # Nothing to embed: skip the model call (and the batch log line).
            return []
        logger.info("Encoding %s batch: size=%s", label, len(texts))
        embeddings = self._model.encode(
            texts,
            prompt_name=prompt_name,
            normalize_embeddings=True,
            convert_to_numpy=True,
            # NOTE(review): queries also use the indexing batch size; confirm intended.
            batch_size=settings.index_batch_size,
            show_progress_bar=False,
        )
        return embeddings.tolist()
@lru_cache(maxsize=1)
def get_embedding_service() -> LocalEmbeddingService:
    """Return the shared LocalEmbeddingService, creating it on first use.

    The zero-argument function is memoized by ``lru_cache(maxsize=1)``, so
    later calls reuse the cached instance instead of reloading the model.
    Call ``get_embedding_service.cache_clear()`` to force a fresh instance.
    Per the ``lru_cache`` docs, concurrent first calls may each construct one.
    """
    service = LocalEmbeddingService()
    return service