66 lines
2.5 KiB
Python
66 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import torch
|
|
import torch.nn.functional as functional
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
from .config import settings
|
|
|
|
|
|
class RuBertMiniFridaEmbedder:
    """Text embedder backed by a Hugging Face encoder model.

    Loads the tokenizer and model named by ``settings.embedding_model``
    (cached under ``settings.huggingface_cache_dir``), mean-pools the last
    hidden state over non-padding tokens and L2-normalizes the result.
    Documents and queries are prefixed with different prompts taken from
    ``settings`` before encoding.
    """

    def __init__(self, device: str = "cpu") -> None:
        """Load the tokenizer and model and move the model to *device*.

        Args:
            device: Torch device the model runs on. Defaults to ``"cpu"``.

        Raises:
            ValueError: If ``settings.embedding_batch_size`` is not positive.
        """
        self.device = device
        self.max_length = settings.embedding_max_length
        self.batch_size = settings.embedding_batch_size
        # A zero batch size would crash range() in _encode and a negative one
        # would make it silently return no embeddings; fail fast instead.
        if self.batch_size < 1:
            raise ValueError(f"embedding_batch_size must be >= 1, got {self.batch_size}")

        self.cache_dir = Path(settings.huggingface_cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.tokenizer = AutoTokenizer.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model = AutoModel.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def mean_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average token vectors over the positions where *attention_mask* is 1.

        Args:
            hidden_state: Token embeddings, expected shape (batch, seq, dim).
            attention_mask: 1 for real tokens and 0 for padding, expected
                shape (batch, seq).

        Returns:
            Pooled embeddings of shape (batch, dim). A row whose mask is all
            zeros yields a zero vector instead of NaN.
        """
        # Cast to the hidden state's dtype so half-precision inputs are not
        # silently promoted to float32.
        mask = attention_mask.unsqueeze(-1).to(hidden_state.dtype)
        summed = (hidden_state * mask).sum(dim=1)
        # Clamp the integer token count to >= 1 so an all-padding row divides
        # by 1 (giving 0) rather than by 0 (giving NaN). This is exact for
        # every non-empty row, unlike an epsilon clamp that underflows in fp16.
        counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1).to(hidden_state.dtype)
        return summed / counts

    def _encode(self, texts: Iterable[str], prompt: str) -> list[list[float]]:
        """Embed *texts* after prefixing each one with *prompt*.

        Texts are processed in chunks of ``self.batch_size``. Each chunk is
        tokenized (padded, truncated to ``self.max_length``), run through the
        model with autograd disabled, mean-pooled and L2-normalized.

        Note: a bare ``str`` is itself an iterable of characters, so passing
        one embeds every character separately. Wrap single texts in a list.

        Returns:
            One embedding (list of floats) per input text, in input order, or
            an empty list when *texts* is empty.
        """
        prepared_texts = [f"{prompt}{text}" for text in texts]
        if not prepared_texts:
            return []

        embeddings: list[list[float]] = []
        # PyTorch grad mode is thread-local. Scoping inference mode to this
        # call (rather than toggling it once at construction) covers callers
        # on any thread and leaves other code's autograd state untouched.
        with torch.inference_mode():
            for start in range(0, len(prepared_texts), self.batch_size):
                batch = prepared_texts[start : start + self.batch_size]
                encoded = self.tokenizer(
                    batch,
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self.device) for key, value in encoded.items()}
                outputs = self.model(**encoded)
                pooled = self.mean_pool(outputs.last_hidden_state, encoded["attention_mask"])
                normalized = functional.normalize(pooled, p=2, dim=1)
                embeddings.extend(normalized.cpu().tolist())
        return embeddings

    def embed_documents(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed documents using ``settings.embedding_document_prefix`` as the prompt."""
        return self._encode(texts, prompt=settings.embedding_document_prefix)

    def embed_queries(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed queries using ``settings.embedding_query_prefix`` as the prompt."""
        return self._encode(texts, prompt=settings.embedding_query_prefix)
|