from __future__ import annotations from pathlib import Path from typing import Iterable import torch import torch.nn.functional as functional from transformers import AutoModel, AutoTokenizer from .config import settings class RuBertMiniFridaEmbedder: def __init__(self) -> None: torch.set_grad_enabled(False) self.device = "cpu" self.max_length = settings.embedding_max_length self.batch_size = settings.embedding_batch_size self.cache_dir = Path(settings.huggingface_cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self.tokenizer = AutoTokenizer.from_pretrained( settings.embedding_model, cache_dir=str(self.cache_dir), ) self.model = AutoModel.from_pretrained( settings.embedding_model, cache_dir=str(self.cache_dir), ) self.model.to(self.device) self.model.eval() @staticmethod def mean_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: masked_state = hidden_state * attention_mask.unsqueeze(-1).float() summed = torch.sum(masked_state, dim=1) counts = attention_mask.sum(dim=1, keepdim=True).float() return summed / counts def _encode(self, texts: Iterable[str], prompt: str) -> list[list[float]]: prepared_texts = [f"{prompt}{text}" for text in texts] if not prepared_texts: return [] embeddings: list[list[float]] = [] for start in range(0, len(prepared_texts), self.batch_size): batch = prepared_texts[start : start + self.batch_size] encoded = self.tokenizer( batch, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt", ) encoded = {key: value.to(self.device) for key, value in encoded.items()} outputs = self.model(**encoded) pooled = self.mean_pool(outputs.last_hidden_state, encoded["attention_mask"]) normalized = functional.normalize(pooled, p=2, dim=1) embeddings.extend(normalized.cpu().tolist()) return embeddings def embed_documents(self, texts: Iterable[str]) -> list[list[float]]: return self._encode(texts, prompt=settings.embedding_document_prefix) def embed_queries(self, texts: Iterable[str]) -> list[list[float]]: return self._encode(texts, prompt=settings.embedding_query_prefix)