Files
2026-05-25 01:12:43 +03:00

123 lines
4.0 KiB
Python

from __future__ import annotations
import logging
from fastapi.concurrency import run_in_threadpool
from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.schemas import ClassificationResult
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM
logger = logging.getLogger(__name__)
def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
return [value]
normalized = [item for item in value if isinstance(item, str) and item.strip()]
return normalized or None
class HybridRetrievalService:
    """Fuse full-text and vector search hits into one ranked list of law chunks.

    For each search query, full-text hits and vector hits add reciprocal-rank
    scores to a per-chunk-id accumulator. The best candidates are then loaded
    through the ORM and returned with their fused ``score``.
    """

    # A full-text hit at 0-based rank r adds LEXICAL_WEIGHT / (r + 1).
    LEXICAL_WEIGHT = 1.2
    # A vector hit at 0-based rank r adds VECTOR_RANK_WEIGHT / (r + 1), plus a
    # similarity bonus of max(0, 1 - distance).
    VECTOR_RANK_WEIGHT = 1.0
    # Candidates loaded per requested result. The surplus covers ids for which
    # the ORM returns no row (those are skipped when assembling results).
    CANDIDATE_MULTIPLIER = 3

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        """Store the collaborators used for lexical search, embedding and vector search."""
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    async def retrieve(
        self,
        classification: ClassificationResult,
        fallback_law_types: list[str] | None,
        top_k: int,
    ) -> list[dict]:
        """Return up to *top_k* law-chunk rows ranked by fused lexical + vector score.

        Args:
            classification: Supplies ``search_queries``, the optional
                ``filters["law_type"]`` and ``jurisdiction``.
            fallback_law_types: Law types used when the classification has no
                ``law_type`` filter.
            top_k: Maximum number of rows to return. Values ``<= 0`` return
                ``[]`` without querying anything.

        Returns:
            Row dicts from ``orm.get_law_chunks_with_sources_by_ids``, best
            first, each mutated in place to carry a ``score`` rounded to
            6 decimal places.
        """
        queries = classification.search_queries or []
        law_types = normalize_law_types_arg(
            classification.filters.get("law_type")
            or fallback_law_types
            or None
        )
        logger.info(
            "Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s",
            queries,
            law_types,
            classification.jurisdiction,
            top_k,
        )
        if top_k <= 0:
            # A non-positive top_k used to produce a negative slice bound
            # (e.g. [: -1]) and still return a row; nothing is requested here.
            logger.info("Hybrid retrieval completed: returned=%s", 0)
            return []

        merged_scores: dict[int, float] = {}
        for query in queries:
            await self._add_lexical_scores(
                query, law_types, classification.jurisdiction, merged_scores
            )
            await self._add_vector_scores(query, law_types, merged_scores)

        # sorted() is stable (also with reverse=True), so equal scores keep
        # their first-seen order.
        candidate_limit = max(top_k * self.CANDIDATE_MULTIPLIER, top_k)
        ranked_ids = sorted(
            merged_scores, key=merged_scores.__getitem__, reverse=True
        )[:candidate_limit]
        if not ranked_ids:
            # Skip the ORM round-trip for an empty id list.
            logger.info("Hybrid retrieval completed: returned=%s", 0)
            return []

        rows = await self.orm.get_law_chunks_with_sources_by_ids(
            ranked_ids,
            law_types=law_types,
            jurisdiction=classification.jurisdiction,
        )
        by_id = {row["chunk_id"]: row for row in rows}
        results = []
        for chunk_id in ranked_ids:
            row = by_id.get(chunk_id)
            if row is None:
                # Presumably excluded by the ORM's law_type/jurisdiction
                # arguments (vector hits are not filtered by jurisdiction here).
                continue
            row["score"] = round(merged_scores[chunk_id], 6)
            results.append(row)
            if len(results) >= top_k:
                break
        logger.info("Hybrid retrieval completed: returned=%s", len(results))
        return results

    async def _add_lexical_scores(
        self,
        query: str,
        law_types: list[str] | None,
        jurisdiction,
        scores: dict[int, float],
    ) -> None:
        """Add reciprocal-rank scores for the full-text hits of *query* into *scores*."""
        hits = await self.orm.search_law_chunks_full_text(
            query=query,
            law_types=law_types,
            jurisdiction=jurisdiction,
            limit=settings.fts_top_k,
        )
        logger.info("Full-text hits for query '%s': %s", query, len(hits))
        for rank, hit in enumerate(hits):
            chunk_id = hit["chunk_id"]
            scores[chunk_id] = scores.get(chunk_id, 0.0) + (
                self.LEXICAL_WEIGHT / (rank + 1)
            )

    async def _add_vector_scores(
        self,
        query: str,
        law_types: list[str] | None,
        scores: dict[int, float],
    ) -> None:
        """Embed *query*, search the vector store and add the hit scores into *scores*.

        Hits whose metadata ``law_type`` is not in *law_types* are skipped
        when a law-type filter is active. A missing or ``None`` distance or
        metadata entry no longer drops or crashes the whole batch: the hit
        keeps its rank score and gets no similarity bonus.
        """
        # Embedding and the vector query are synchronous; run them off the event loop.
        query_embedding = await run_in_threadpool(
            self.embedder.encode_queries,
            [query],
        )
        hits = await run_in_threadpool(
            self.vector_store.query,
            query_embedding,
            settings.vector_top_k,
        )
        ids = self._first_batch(hits, "ids")
        distances = self._first_batch(hits, "distances")
        metadatas = self._first_batch(hits, "metadatas")
        logger.info("Vector hits for query '%s': %s", query, len(ids))
        for rank, raw_id in enumerate(ids):
            distance = distances[rank] if rank < len(distances) else None
            metadata = (metadatas[rank] if rank < len(metadatas) else None) or {}
            if law_types and metadata.get("law_type") not in law_types:
                continue
            score = self.VECTOR_RANK_WEIGHT / (rank + 1)
            # Explicit None check: the previous `distance or 1.0` turned an
            # exact match (distance 0.0) into "no similarity".
            # NOTE(review): assumes 0 means identical (e.g. cosine distance);
            # confirm against the vector store's configured metric.
            if distance is not None:
                score += max(0.0, 1.0 - float(distance))
            chunk_id = int(raw_id)
            scores[chunk_id] = scores.get(chunk_id, 0.0) + score

    @staticmethod
    def _first_batch(hits: dict | None, key: str) -> list:
        """Return the result list for the single query embedding under *key*.

        The vector-store result holds one inner list per query embedding, and
        only one embedding is sent. A missing key, a ``None`` value or an empty
        outer list yields ``[]`` instead of raising.
        """
        if not hits:
            return []
        batches = hits.get(key)
        if batches is None or len(batches) == 0:
            return []
        first = batches[0]
        return [] if first is None else list(first)