first commit
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from api.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalEmbeddingService:
|
||||
def __init__(self) -> None:
|
||||
logger.info(
|
||||
"Loading embedding model: model=%s device=%s",
|
||||
settings.embedding_model,
|
||||
settings.embedding_device,
|
||||
)
|
||||
self._model = SentenceTransformer(
|
||||
settings.embedding_model,
|
||||
device=settings.embedding_device,
|
||||
)
|
||||
self._model.max_seq_length = 512
|
||||
logger.info(
|
||||
"Embedding model loaded: model=%s max_seq_length=%s",
|
||||
settings.embedding_model,
|
||||
self._model.max_seq_length,
|
||||
)
|
||||
|
||||
def encode_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
logger.info("Encoding document batch: size=%s", len(texts))
|
||||
return self._model.encode(
|
||||
texts,
|
||||
prompt_name="search_document",
|
||||
normalize_embeddings=True,
|
||||
convert_to_numpy=True,
|
||||
batch_size=settings.index_batch_size,
|
||||
show_progress_bar=False,
|
||||
).tolist()
|
||||
|
||||
def encode_queries(self, texts: list[str]) -> list[list[float]]:
|
||||
logger.info("Encoding query batch: size=%s", len(texts))
|
||||
return self._model.encode(
|
||||
texts,
|
||||
prompt_name="search_query",
|
||||
normalize_embeddings=True,
|
||||
convert_to_numpy=True,
|
||||
batch_size=settings.index_batch_size,
|
||||
show_progress_bar=False,
|
||||
).tolist()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_embedding_service() -> LocalEmbeddingService:
|
||||
return LocalEmbeddingService()
|
||||
Reference in New Issue
Block a user