123 lines
4.0 KiB
Python
123 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from fastapi.concurrency import run_in_threadpool
|
|
|
|
from api.clients.chroma_store import ChromaVectorStore
|
|
from api.config import settings
|
|
from api.schemas import ClassificationResult
|
|
from api.services.local_embeddings import LocalEmbeddingService
|
|
from shared import ORM
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None:
    """Coerce a law-type filter argument into a list of strings, or ``None``.

    Args:
        value: ``None``, a single law-type string, or an iterable of
            candidate law types.

    Returns:
        A non-empty list of law-type strings, or ``None`` when no usable
        value was supplied. Entries that are not strings, or that are empty
        or whitespace-only, are dropped. Kept entries are returned unchanged
        (they are not stripped).
    """
    if value is None:
        return None

    # A bare string is a single law type. Wrap it so it passes through the
    # same blank-entry filter as list input: a blank string must yield
    # ``None`` rather than a ``[""]`` filter that can never match anything.
    candidates = [value] if isinstance(value, str) else value

    normalized = [
        item for item in candidates if isinstance(item, str) and item.strip()
    ]
    return normalized or None
|
|
|
|
|
|
class HybridRetrievalService:
    """Retrieve law chunks by merging full-text and vector search scores.

    For every search query the service asks the ORM for full-text hits and
    the vector store for nearest neighbours, accumulates a per-chunk score
    from both sources, and finally loads the best-scoring chunks (with their
    sources) through the ORM.
    """

    # Multiplier applied to the reciprocal rank of a full-text hit. Vector
    # hits use a reciprocal-rank weight of 1.0 plus a distance-based bonus.
    _LEXICAL_WEIGHT = 1.2

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        """Store the collaborators used by :meth:`retrieve`.

        Args:
            orm: Data-access object providing the full-text search and the
                chunk lookup by id.
            embedder: Service whose ``encode_queries`` turns query strings
                into embeddings.
            vector_store: Store whose ``query`` returns nearest neighbours
                for an embedding.
        """
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    @staticmethod
    def _first_result_list(payload: dict, key: str) -> list:
        """Return the first per-query list stored under ``key``, or ``[]``."""
        # The vector store answers one list per query; only one query is sent
        # at a time, so only index 0 matters. A missing key, a ``None`` value
        # or an empty outer list all degrade to "no hits" instead of raising.
        batches = payload.get(key) or [[]]
        return batches[0] or []

    @staticmethod
    def _vector_hit_score(rank: int, distance: float | None) -> float:
        """Score one vector hit: reciprocal rank plus a closeness bonus."""
        # Only a *missing* distance falls back to 1.0 (no bonus). A distance
        # of exactly 0.0 is the best possible match and must keep its full
        # bonus, so a truthiness test (``distance or 1.0``) is wrong here.
        effective_distance = 1.0 if distance is None else float(distance)
        return 1.0 / (rank + 1) + max(0.0, 1.0 - effective_distance)

    async def retrieve(
        self,
        classification: ClassificationResult,
        fallback_law_types: list[str] | None,
        top_k: int,
    ) -> list[dict]:
        """Run hybrid retrieval for a classified request.

        Args:
            classification: Supplies ``search_queries``, ``filters`` (its
                ``"law_type"`` entry is read) and ``jurisdiction``.
            fallback_law_types: Law types used when ``classification.filters``
                holds no truthy ``"law_type"`` value.
            top_k: Maximum number of rows to return.

        Returns:
            Rows from ``orm.get_law_chunks_with_sources_by_ids`` ordered by
            descending merged score, each with a ``"score"`` key set to the
            merged score rounded to six decimals. At most ``top_k`` rows.
        """
        queries = classification.search_queries or []
        law_types = normalize_law_types_arg(
            classification.filters.get("law_type") or fallback_law_types or None
        )
        jurisdiction = classification.jurisdiction
        logger.info(
            "Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s",
            queries,
            law_types,
            jurisdiction,
            top_k,
        )

        # chunk id -> score accumulated over all queries and both sources.
        merged_scores: dict[int, float] = {}

        for query in queries:
            # --- Full-text side: weighted reciprocal rank per hit. ---
            lexical_hits = await self.orm.search_law_chunks_full_text(
                query=query,
                law_types=law_types,
                jurisdiction=jurisdiction,
                limit=settings.fts_top_k,
            )
            logger.info("Full-text hits for query '%s': %s", query, len(lexical_hits))
            for rank, hit in enumerate(lexical_hits):
                chunk_key = hit["chunk_id"]
                merged_scores[chunk_key] = merged_scores.get(chunk_key, 0.0) + (
                    self._LEXICAL_WEIGHT / (rank + 1)
                )

            # --- Vector side: both calls are synchronous, so they run in the
            # threadpool to keep the event loop free. ---
            query_embedding = await run_in_threadpool(
                self.embedder.encode_queries,
                [query],
            )
            vector_hits = await run_in_threadpool(
                self.vector_store.query,
                query_embedding,
                settings.vector_top_k,
            )
            ids = self._first_result_list(vector_hits, "ids")
            distances = self._first_result_list(vector_hits, "distances")
            metadatas = self._first_result_list(vector_hits, "metadatas")
            logger.info("Vector hits for query '%s': %s", query, len(ids))

            # The rank is taken before the law-type filter, so a filtered-out
            # hit still consumes its rank position.
            for rank, (chunk_id, distance, metadata) in enumerate(
                zip(ids, distances, metadatas)
            ):
                # A hit without metadata is treated as having no law type.
                hit_law_type = (metadata or {}).get("law_type")
                if law_types and hit_law_type not in law_types:
                    continue
                chunk_key = int(chunk_id)
                merged_scores[chunk_key] = merged_scores.get(
                    chunk_key, 0.0
                ) + self._vector_hit_score(rank, distance)

        # Look up more candidates than requested; ids the ORM does not return
        # are skipped below (presumably the reason for the 3x margin).
        candidate_limit = max(top_k * 3, top_k)
        ranked_ids = [
            chunk_id
            for chunk_id, _ in sorted(
                merged_scores.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        ][:candidate_limit]

        if not ranked_ids:
            # Nothing scored: skip the ORM round trip entirely.
            logger.info("Hybrid retrieval completed: returned=%s", 0)
            return []

        rows = await self.orm.get_law_chunks_with_sources_by_ids(
            ranked_ids,
            law_types=law_types,
            jurisdiction=jurisdiction,
        )
        by_id = {row["chunk_id"]: row for row in rows}

        results: list[dict] = []
        for chunk_id in ranked_ids:
            row = by_id.get(chunk_id)
            if row is None:
                continue
            row["score"] = round(merged_scores.get(chunk_id, 0.0), 6)
            results.append(row)
            if len(results) >= top_k:
                break

        logger.info("Hybrid retrieval completed: returned=%s", len(results))
        return results
|