from __future__ import annotations import logging from fastapi.concurrency import run_in_threadpool from api.clients.chroma_store import ChromaVectorStore from api.config import settings from api.schemas import ClassificationResult from api.services.local_embeddings import LocalEmbeddingService from shared import ORM logger = logging.getLogger(__name__) def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None: if value is None: return None if isinstance(value, str): return [value] normalized = [item for item in value if isinstance(item, str) and item.strip()] return normalized or None class HybridRetrievalService: def __init__( self, orm: ORM, embedder: LocalEmbeddingService, vector_store: ChromaVectorStore, ) -> None: self.orm = orm self.embedder = embedder self.vector_store = vector_store async def retrieve( self, classification: ClassificationResult, fallback_law_types: list[str] | None, top_k: int, ) -> list[dict]: queries = classification.search_queries or [] law_types = normalize_law_types_arg( classification.filters.get("law_type") or fallback_law_types or None ) logger.info( "Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s", queries, law_types, classification.jurisdiction, top_k, ) merged_scores: dict[int, float] = {} for query in queries: lexical_hits = await self.orm.search_law_chunks_full_text( query=query, law_types=law_types, jurisdiction=classification.jurisdiction, limit=settings.fts_top_k, ) logger.info("Full-text hits for query '%s': %s", query, len(lexical_hits)) for rank, hit in enumerate(lexical_hits): merged_scores[hit["chunk_id"]] = merged_scores.get(hit["chunk_id"], 0.0) + ( 1.2 / (rank + 1) ) query_embedding = await run_in_threadpool( self.embedder.encode_queries, [query], ) vector_hits = await run_in_threadpool( self.vector_store.query, query_embedding, settings.vector_top_k, ) ids = vector_hits.get("ids", [[]])[0] distances = vector_hits.get("distances", [[]])[0] metadatas = vector_hits.get("metadatas", [[]])[0] logger.info("Vector hits for query '%s': %s", query, len(ids)) for rank, (chunk_id, distance, metadata) in enumerate( zip(ids, distances, metadatas) ): if law_types and metadata.get("law_type") not in law_types: continue score = 1.0 / (rank + 1) score += max(0.0, 1.0 - float(distance or 1.0)) merged_scores[int(chunk_id)] = merged_scores.get(int(chunk_id), 0.0) + score ranked_ids = [ chunk_id for chunk_id, _ in sorted( merged_scores.items(), key=lambda item: item[1], reverse=True, ) ][: max(top_k * 3, top_k)] rows = await self.orm.get_law_chunks_with_sources_by_ids( ranked_ids, law_types=law_types, jurisdiction=classification.jurisdiction, ) by_id = {row["chunk_id"]: row for row in rows} results = [] for chunk_id in ranked_ids: row = by_id.get(chunk_id) if row is None: continue row["score"] = round(merged_scores.get(chunk_id, 0.0), 6) results.append(row) if len(results) >= top_k: break logger.info("Hybrid retrieval completed: returned=%s", len(results)) return results