first commit
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as functional
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class RuBertMiniFridaEmbedder:
|
||||
def __init__(self) -> None:
|
||||
torch.set_grad_enabled(False)
|
||||
self.device = "cpu"
|
||||
self.max_length = settings.embedding_max_length
|
||||
self.batch_size = settings.embedding_batch_size
|
||||
self.cache_dir = Path(settings.huggingface_cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
settings.embedding_model,
|
||||
cache_dir=str(self.cache_dir),
|
||||
)
|
||||
self.model = AutoModel.from_pretrained(
|
||||
settings.embedding_model,
|
||||
cache_dir=str(self.cache_dir),
|
||||
)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
@staticmethod
|
||||
def mean_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
masked_state = hidden_state * attention_mask.unsqueeze(-1).float()
|
||||
summed = torch.sum(masked_state, dim=1)
|
||||
counts = attention_mask.sum(dim=1, keepdim=True).float()
|
||||
return summed / counts
|
||||
|
||||
def _encode(self, texts: Iterable[str], prompt: str) -> list[list[float]]:
|
||||
prepared_texts = [f"{prompt}{text}" for text in texts]
|
||||
if not prepared_texts:
|
||||
return []
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
for start in range(0, len(prepared_texts), self.batch_size):
|
||||
batch = prepared_texts[start : start + self.batch_size]
|
||||
encoded = self.tokenizer(
|
||||
batch,
|
||||
max_length=self.max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
encoded = {key: value.to(self.device) for key, value in encoded.items()}
|
||||
outputs = self.model(**encoded)
|
||||
pooled = self.mean_pool(outputs.last_hidden_state, encoded["attention_mask"])
|
||||
normalized = functional.normalize(pooled, p=2, dim=1)
|
||||
embeddings.extend(normalized.cpu().tolist())
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
return self._encode(texts, prompt=settings.embedding_document_prefix)
|
||||
|
||||
def embed_queries(self, texts: Iterable[str]) -> list[list[float]]:
|
||||
return self._encode(texts, prompt=settings.embedding_query_prefix)
|
||||
Reference in New Issue
Block a user