first commit
This commit is contained in:
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from api.clients.chroma_store import ChromaVectorStore
|
||||
from api.config import settings
|
||||
from api.schemas import ClassificationResult
|
||||
from api.services.local_embeddings import LocalEmbeddingService
|
||||
from shared import ORM
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None:
    """Coerce a law-type filter argument into a list of strings or ``None``.

    Args:
        value: A single law type, a list of law types, or ``None``.

    Returns:
        A non-empty list of the usable law types, or ``None`` when no usable
        value was supplied (``None``, a blank string, an empty list, or a list
        containing only blank / non-string entries).
    """
    if value is None:
        return None
    if isinstance(value, str):
        # A blank string is "no filter", exactly like a blank entry in a list;
        # returning [""] would produce a filter that can never match.
        return [value] if value.strip() else None
    # Drop non-string and blank entries; surviving entries are kept unchanged.
    normalized = [item for item in value if isinstance(item, str) and item.strip()]
    return normalized or None
|
||||
|
||||
|
||||
class HybridRetrievalService:
    """Merge full-text and vector search hits into one ranked list of chunks.

    For every search query the service asks the ORM for full-text hits and the
    vector store for nearest-neighbour hits, accumulates a per-chunk score from
    both sources, and finally loads the best-scoring chunks through the ORM.
    """

    # Multiplier on the reciprocal-rank score of a full-text hit (a vector hit
    # uses a plain 1 / (rank + 1) for its rank component).
    _LEXICAL_RANK_WEIGHT = 1.2
    # How many candidates to load relative to ``top_k``. The final ORM lookup
    # is passed the law-type / jurisdiction filters again and missing rows are
    # skipped, so presumably this over-fetch compensates for filtered-out rows.
    _CANDIDATE_MULTIPLIER = 3

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        """Store the collaborators used by :meth:`retrieve`.

        Args:
            orm: Data-access object used for full-text search and chunk loading.
            embedder: Service whose ``encode_queries`` embeds the query text.
            vector_store: Store whose ``query`` returns nearest-neighbour hits.
        """
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    @staticmethod
    def _first_batch(result: dict, key: str) -> list:
        """Return the first inner list stored under ``key``, or ``[]``.

        The vector store result holds one inner list per query; only one query
        is sent at a time. A missing key, a ``None`` value, an empty outer list
        and a ``None`` inner list all yield ``[]`` instead of raising.
        """
        batches = result.get(key)
        if batches is None or len(batches) == 0:
            return []
        first = batches[0]
        return [] if first is None else list(first)

    @staticmethod
    def _vector_score(rank: int, distance: float | None) -> float:
        """Score one vector hit from its rank and its distance.

        The score is ``1 / (rank + 1)`` plus a similarity bonus of
        ``1 - distance`` clamped at zero. A missing distance earns no bonus.
        A distance of exactly ``0.0`` is a perfect match and earns the full
        bonus; it must not be confused with a missing value.
        """
        effective_distance = 1.0 if distance is None else float(distance)
        return 1.0 / (rank + 1) + max(0.0, 1.0 - effective_distance)

    async def retrieve(
        self,
        classification: ClassificationResult,
        fallback_law_types: list[str] | None,
        top_k: int,
    ) -> list[dict]:
        """Run hybrid retrieval for every query in ``classification``.

        Args:
            classification: Supplies ``search_queries``, ``filters`` (its
                ``"law_type"`` entry is used) and ``jurisdiction``.
            fallback_law_types: Law types used when the classification filters
                carry no ``"law_type"`` value.
            top_k: Maximum number of rows to return.

        Returns:
            Rows returned by the ORM, ordered by descending merged score, each
            with a ``"score"`` key set to the merged score rounded to 6 places.
        """
        queries = classification.search_queries or []
        law_types = normalize_law_types_arg(
            classification.filters.get("law_type") or fallback_law_types or None
        )
        logger.info(
            "Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s",
            queries,
            law_types,
            classification.jurisdiction,
            top_k,
        )

        merged_scores: dict[int, float] = {}

        for query in queries:
            # --- Lexical side: weighted reciprocal rank of full-text hits.
            lexical_hits = await self.orm.search_law_chunks_full_text(
                query=query,
                law_types=law_types,
                jurisdiction=classification.jurisdiction,
                limit=settings.fts_top_k,
            )
            logger.info("Full-text hits for query '%s': %s", query, len(lexical_hits))
            for rank, hit in enumerate(lexical_hits):
                chunk_key = hit["chunk_id"]
                merged_scores[chunk_key] = merged_scores.get(chunk_key, 0.0) + (
                    self._LEXICAL_RANK_WEIGHT / (rank + 1)
                )

            # --- Vector side: both calls are synchronous, so they are run in
            # the thread pool rather than directly in the coroutine.
            query_embedding = await run_in_threadpool(
                self.embedder.encode_queries,
                [query],
            )
            vector_hits = await run_in_threadpool(
                self.vector_store.query,
                query_embedding,
                settings.vector_top_k,
            )
            ids = self._first_batch(vector_hits, "ids")
            distances = self._first_batch(vector_hits, "distances")
            metadatas = self._first_batch(vector_hits, "metadatas")
            logger.info("Vector hits for query '%s': %s", query, len(ids))

            # Iterate over ids (not a zip) so a shorter or absent distances /
            # metadatas column does not silently discard every vector hit.
            # ``rank`` is the position in the raw result, including hits that
            # the law-type filter skips.
            for rank, chunk_id in enumerate(ids):
                distance = distances[rank] if rank < len(distances) else None
                metadata = (metadatas[rank] if rank < len(metadatas) else None) or {}
                if law_types and metadata.get("law_type") not in law_types:
                    continue
                chunk_key = int(chunk_id)
                merged_scores[chunk_key] = merged_scores.get(
                    chunk_key, 0.0
                ) + self._vector_score(rank, distance)

        # Best score first; ties keep first-seen order because sorted() is stable.
        ranked_ids = [
            chunk_id
            for chunk_id, _ in sorted(
                merged_scores.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        ][: max(top_k * self._CANDIDATE_MULTIPLIER, top_k)]

        rows = await self.orm.get_law_chunks_with_sources_by_ids(
            ranked_ids,
            law_types=law_types,
            jurisdiction=classification.jurisdiction,
        )
        by_id = {row["chunk_id"]: row for row in rows}

        results = []
        for chunk_id in ranked_ids:
            row = by_id.get(chunk_id)
            if row is None:
                # The ORM did not return this candidate; skip it.
                continue
            row["score"] = round(merged_scores.get(chunk_id, 0.0), 6)
            results.append(row)
            if len(results) >= top_k:
                break

        logger.info("Hybrid retrieval completed: returned=%s", len(results))
        return results
|
||||
Reference in New Issue
Block a user