Files
2026-05-12 23:37:04 +03:00

66 lines
2.5 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Iterable
import torch
import torch.nn.functional as functional
from transformers import AutoModel, AutoTokenizer
from .config import settings
class RuBertMiniFridaEmbedder:
    """Text embedder built on the Hugging Face model named by ``settings.embedding_model``.

    Texts are prefixed with a document or query prompt, tokenized in batches,
    run through the model on the CPU, mean-pooled over the attended tokens and
    L2-normalized. Results are returned as plain Python lists of floats.
    """

    def __init__(self) -> None:
        """Load the tokenizer and model from the configured cache directory.

        Reads ``embedding_max_length``, ``embedding_batch_size``,
        ``huggingface_cache_dir`` and ``embedding_model`` from ``settings``.
        The cache directory is created if it does not exist.
        """
        # Kept for backward compatibility with callers that rely on this side
        # effect. Grad mode is thread-local, so ``_encode`` additionally guards
        # its own forward pass instead of depending on this call.
        torch.set_grad_enabled(False)
        self.device = "cpu"
        self.max_length = settings.embedding_max_length
        self.batch_size = settings.embedding_batch_size
        self.cache_dir = Path(settings.huggingface_cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = AutoTokenizer.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model = AutoModel.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def mean_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average token vectors over the positions selected by the mask.

        Args:
            hidden_state: Token-level model output; expected to be
                ``(batch, seq, dim)`` since pooling sums over dimension 1.
            attention_mask: Per-token mask, expected to be ``(batch, seq)``;
                zero entries are excluded from the average.

        Returns:
            One pooled vector per batch row. A row whose mask is entirely zero
            yields a zero vector rather than NaN.
        """
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * mask, dim=1)
        # Clamp so that a row with no attended tokens does not divide by zero
        # (0 / 0 would produce NaN and poison the normalized embedding).
        counts = attention_mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        return summed / counts

    def _encode(self, texts: Iterable[str], prompt: str) -> list[list[float]]:
        """Embed ``texts`` after prepending ``prompt`` to each of them.

        Args:
            texts: Strings to embed; consumed once.
            prompt: Prefix concatenated in front of every text.

        Returns:
            One L2-normalized embedding per input text, in input order.
            An empty input yields an empty list.
        """
        prepared_texts = [f"{prompt}{text}" for text in texts]
        if not prepared_texts:
            return []

        embeddings: list[list[float]] = []
        # Grad mode is thread-local: the switch flipped in ``__init__`` does not
        # apply when this method runs on another thread, so disable it here.
        with torch.no_grad():
            for start in range(0, len(prepared_texts), self.batch_size):
                batch = prepared_texts[start : start + self.batch_size]
                encoded = self.tokenizer(
                    batch,
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self.device) for key, value in encoded.items()}
                outputs = self.model(**encoded)
                pooled = self.mean_pool(outputs.last_hidden_state, encoded["attention_mask"])
                normalized = functional.normalize(pooled, p=2, dim=1)
                embeddings.extend(normalized.cpu().tolist())
        return embeddings

    def embed_documents(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed ``texts`` using ``settings.embedding_document_prefix`` as the prompt."""
        return self._encode(texts, prompt=settings.embedding_document_prefix)

    def embed_queries(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed ``texts`` using ``settings.embedding_query_prefix`` as the prompt."""
        return self._encode(texts, prompt=settings.embedding_query_prefix)