first commit

This commit is contained in:
2026-05-25 01:12:43 +03:00
commit bfc22efe24
83 changed files with 8903 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
"""Service layer for the RAG API."""
+126
View File
@@ -0,0 +1,126 @@
from __future__ import annotations
import logging
from fastapi.concurrency import run_in_threadpool
from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM
logger = logging.getLogger(__name__)
def build_embedding_text(chunk: dict) -> str:
    """Build the labelled text that is embedded for one law chunk.

    The result has up to three lines: ``source: ...``, ``article: ...`` and
    ``text: ...``. A line is emitted only when its value is non-blank, so a
    chunk without an article title does not contribute a bare ``article:``
    label to the embedding input.

    Args:
        chunk: Chunk record. ``chunk_text`` must be present; ``source_title``
            and ``article_title`` may be missing or ``None``.

    Returns:
        The labelled lines joined with newlines.

    Raises:
        KeyError: If ``chunk`` has no ``chunk_text`` key.
    """
    labelled_values = (
        ("source", chunk.get("source_title") or ""),
        ("article", chunk.get("article_title") or ""),
        ("text", chunk["chunk_text"] or ""),
    )
    # Filter on the value, not on the labelled line: a labelled line always
    # contains its label and therefore is never blank.
    return "\n".join(
        f"{label}: {value}"
        for label, value in labelled_values
        if str(value).strip()
    )
class IndexingService:
    """Rebuild the vector index from the law chunks returned by the ORM."""

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        """Store the collaborators used to load, embed and upsert chunks."""
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    @staticmethod
    def _chunk_metadata(chunk: dict) -> dict:
        """Select the chunk fields that are stored as vector-store metadata."""
        return {
            "chunk_id": chunk["chunk_id"],
            "source_id": chunk["source_id"],
            "source_title": chunk["source_title"],
            "source_url": chunk["source_url"],
            "law_type": chunk["law_type"] or "",
            "jurisdiction": chunk["jurisdiction"],
            "article_number": chunk["article_number"] or "",
            "article_title": chunk["article_title"] or "",
            "version_hash": chunk["version_hash"],
        }

    async def get_indexable_chunks_count(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
    ) -> int:
        """Return how many active chunks match the given filters.

        The count is the length of the list returned by
        ``orm.list_chunks_for_indexing`` with ``active_only=True``.
        """
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        return len(chunks)

    async def rebuild(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
        reset_collection: bool = True,
        batch_size: int | None = None,
    ) -> dict:
        """Embed the matching chunks and upsert them into the vector store.

        Args:
            source_ids: Optional source filter passed to the ORM.
            law_types: Optional law-type filter passed to the ORM.
            reset_collection: When true, ``vector_store.reset_collection`` is
                called before any chunk is upserted.
            batch_size: Chunks per embed/upsert round; a falsy value selects
                ``settings.index_batch_size``.

        Returns:
            A dict with ``indexed_chunks``, ``indexed_sources`` and
            ``collection_name``.

        Raises:
            ValueError: If the resolved batch size is not positive.
        """
        # Validate before touching the collection: a bad batch size discovered
        # after the reset would leave the collection empty.
        effective_batch_size = batch_size or settings.index_batch_size
        if effective_batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {effective_batch_size!r}"
            )

        logger.info(
            "Index rebuild started: reset_collection=%s source_ids=%s law_types=%s",
            reset_collection,
            source_ids,
            law_types,
        )
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        logger.info("Loaded %s chunks from Postgres for indexing", len(chunks))

        if reset_collection:
            await run_in_threadpool(self.vector_store.reset_collection)

        total_chunks = len(chunks)
        # Ceiling division; invariant across the loop, so computed once.
        total_batches = max(
            1,
            (total_chunks + effective_batch_size - 1) // effective_batch_size,
        )
        indexed_chunks = 0
        indexed_sources = len({chunk["source_id"] for chunk in chunks})

        batch_starts = range(0, total_chunks, effective_batch_size)
        for batch_number, start in enumerate(batch_starts, start=1):
            batch = chunks[start : start + effective_batch_size]
            logger.info(
                "Indexing batch %s/%s with %s chunks",
                batch_number,
                total_batches,
                len(batch),
            )
            # Embedding and upsert are synchronous calls, so they run in the
            # threadpool instead of blocking the event loop.
            embeddings = await run_in_threadpool(
                self.embedder.encode_documents,
                [build_embedding_text(chunk) for chunk in batch],
            )
            await run_in_threadpool(
                self.vector_store.upsert,
                [str(chunk["chunk_id"]) for chunk in batch],
                [chunk["chunk_text"] for chunk in batch],
                embeddings,
                [self._chunk_metadata(chunk) for chunk in batch],
            )
            indexed_chunks += len(batch)
            logger.info(
                "Indexed %s/%s chunks into Chroma",
                indexed_chunks,
                total_chunks,
            )

        result = {
            "indexed_chunks": indexed_chunks,
            "indexed_sources": indexed_sources,
            "collection_name": self.vector_store.collection_name,
        }
        logger.info("Index rebuild completed: %s", result)
        return result
+724
View File
@@ -0,0 +1,724 @@
from __future__ import annotations
import json
import logging
import re
from openai import AsyncOpenAI
from api.prompts.rag_prompts import (
ANSWER_PROMPT,
CLASSIFIER_PROMPT,
CONSULTATION_TITLE_PROMPT,
FOLLOW_UP_ANSWER_PROMPT,
)
from api.schemas import ClassificationResult, StructuredInitialAnswer
logger = logging.getLogger(__name__)
# Maps a lowercase Russian keyword (or word stem) to the law-type codes used
# as retrieval filters. infer_law_types() and infer_primary_law_type() test
# each key as a substring of the lowercased input, in insertion order, and
# stop at the first hit, so the order of the entries is significant.
CATEGORY_MAP = {
    "работа": ["labor"],
    "труд": ["labor"],
    "защита прав потребителей": ["consumer", "civil"],
    "потребител": ["consumer", "civil"],
    "жилье": ["housing", "civil", "mortgage"],
    "аренда": ["housing", "civil"],
    "семья": ["family"],
    "долги": ["civil", "enforcement"],
    "займы": ["civil"],
    "договоры": ["civil"],
    "договор": ["civil"],
    "суд": ["procedural"],
    "процесс": ["procedural"],
    "административ": ["administrative"],
    "уголов": ["criminal"],
    "краж": ["criminal"],
    "мошеннич": ["criminal"],
}
# Maps an alias (English code, Russian phrase or word stem) to a canonical
# law-type code. normalize_law_type_values() checks the aliases in insertion
# order as substrings of the lowercased value and stops at the first alias
# that matches, so the order of the entries is significant.
LAW_TYPE_ALIASES = {
    "labor": "labor",
    "труд": "labor",
    "трудовое право": "labor",
    "criminal": "criminal",
    "уголов": "criminal",
    "civil": "civil",
    "граждан": "civil",
    "договор": "civil",
    "consumer": "consumer",
    "защита прав потребителей": "consumer",
    "потребител": "consumer",
    "housing": "housing",
    "жилищ": "housing",
    "аренда": "housing",
    "family": "family",
    "семейн": "family",
    "procedural": "procedural",
    "процесс": "procedural",
    "суд": "procedural",
    "administrative": "administrative",
    "административ": "administrative",
    "enforcement": "enforcement",
    "исполнительн": "enforcement",
    "mortgage": "mortgage",
    "ипотек": "mortgage",
}
# ``response_format`` payload sent with the initial-answer request: a strict
# JSON schema with one string field and three string-array fields, all
# required and with no additional properties allowed.
INITIAL_ANSWER_RESPONSE_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "lawbot_initial_answer",
        "strict": True,
        "schema": {
            "type": "object",
            "additionalProperties": False,
            "properties": {
                "short_conclusion": {"type": "string"},
                # The three list sections share one shape: array of strings.
                **{
                    list_field: {"type": "array", "items": {"type": "string"}}
                    for list_field in ("legal_points", "action_steps", "risks")
                },
            },
            "required": [
                "short_conclusion",
                "legal_points",
                "action_steps",
                "risks",
            ],
        },
    },
}
# ``response_format`` payload sent with the classification request: a strict
# JSON schema describing the classifier output.
CLASSIFIER_RESPONSE_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "lawbot_classifier",
        "strict": True,
        "schema": {
            "type": "object",
            "additionalProperties": False,
            "properties": {
                "legal_domain": {"type": "string"},
                "issue_type": {"type": "string"},
                "jurisdiction": {"type": "string"},
                "region": {
                    "type": ["string", "null"],
                },
                "needs_clarification": {"type": "boolean"},
                "clarification_questions": {
                    "type": "array",
                    "items": {"type": "string"},
                },
                "search_queries": {
                    "type": "array",
                    "items": {"type": "string"},
                },
                "filters": {
                    "type": "object",
                    "additionalProperties": False,
                    "properties": {
                        "law_type": {
                            "type": ["array", "null"],
                            "items": {"type": "string"},
                        },
                    },
                    # Strict mode requires every object, including nested
                    # ones, to list all of its properties as required; an
                    # optional value is expressed through the "null" type.
                    "required": ["law_type"],
                },
            },
            "required": [
                "legal_domain",
                "issue_type",
                "jurisdiction",
                "region",
                "needs_clarification",
                "clarification_questions",
                "search_queries",
                "filters",
            ],
        },
    },
}
def extract_json(content: str, purpose: str = "response") -> dict:
    """Parse a JSON object out of an LLM reply.

    The whole reply is parsed first. If that fails, or yields something other
    than an object (a list, string, number, ...), the span from the first
    ``{`` to the last ``}`` is parsed instead, which recovers an object that
    is wrapped in prose or in an array.

    Args:
        content: Raw text returned by the model.
        purpose: Short label used in log and error messages.

    Returns:
        The decoded JSON object; always a ``dict``.

    Raises:
        RuntimeError: If no JSON object can be extracted from ``content``.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    # Callers use ``.get`` on the result, so only an object is acceptable.
    if isinstance(parsed, dict):
        return parsed

    match = re.search(r"\{.*\}", content, re.S)
    if not match:
        logger.error("LLM %s returned non-JSON content: %s", purpose, content)
        raise RuntimeError(f"LLM {purpose} returned invalid JSON.")
    try:
        # A JSON text that starts with "{" can only decode to an object.
        return json.loads(match.group(0))
    except json.JSONDecodeError as exc:
        logger.error("LLM %s returned malformed JSON fragment: %s", purpose, content)
        raise RuntimeError(f"LLM {purpose} returned malformed JSON.") from exc
def looks_like_llm_refusal(content: str) -> bool:
    """Tell whether an LLM reply contains a known refusal phrase.

    The reply is lowercased and its whitespace runs are collapsed to single
    spaces before the English and Russian markers are searched for.
    """
    refusal_markers = (
        "i cannot assist",
        "i can't assist",
        "i cannot help",
        "i'm sorry, but i cannot",
        "не могу помочь с этим",
        "не могу помочь в этом",
        "не могу содействовать",
        "не могу помочь с запросом",
        "не могу ответить на этот запрос",
    )
    haystack = " ".join(content.lower().split())
    for marker in refusal_markers:
        if marker in haystack:
            return True
    return False
def infer_law_types(category: str | None) -> list[str] | None:
if not category:
return None
normalized = category.lower().strip()
for key, law_types in CATEGORY_MAP.items():
if key in normalized:
return law_types
return None
def normalize_law_type_values(value) -> list[str] | None:
if value is None:
return None
raw_values = value if isinstance(value, list) else [value]
normalized_values: list[str] = []
for raw_value in raw_values:
if not isinstance(raw_value, str):
continue
raw_normalized = raw_value.strip().lower()
for alias, code in LAW_TYPE_ALIASES.items():
if alias in raw_normalized:
if code not in normalized_values:
normalized_values.append(code)
break
return normalized_values or None
def extract_message_content(completion, purpose: str) -> str:
    """Return the text of the first choice of a chat completion.

    Args:
        completion: Completion object; read through ``getattr`` so missing
            attributes are treated as ``None``.
        purpose: Short label used in the log messages.

    Returns:
        ``completion.choices[0].message.content``.

    Raises:
        RuntimeError: If there are no choices, the first choice has no
            message, or the message content is ``None``.
    """
    model_name = getattr(completion, "model", None)
    completion_id = getattr(completion, "id", None)

    choices = getattr(completion, "choices", None)
    if not choices:
        logger.error(
            "LLM %s returned empty choices: model=%s id=%s usage=%s raw=%s",
            purpose,
            model_name,
            completion_id,
            getattr(completion, "usage", None),
            completion,
        )
        raise RuntimeError(
            "LLM provider returned an empty response. Check OPENROUTER model name and provider response."
        )

    head = choices[0]
    message = getattr(head, "message", None)
    if message is None:
        logger.error(
            "LLM %s returned choice without message: model=%s id=%s choice=%s",
            purpose,
            model_name,
            completion_id,
            head,
        )
        raise RuntimeError("LLM provider returned a malformed response without message.")

    content = getattr(message, "content", None)
    if content is not None:
        return content

    logger.error(
        "LLM %s returned empty message content: model=%s id=%s finish_reason=%s message=%s",
        purpose,
        model_name,
        completion_id,
        getattr(head, "finish_reason", None),
        message,
    )
    raise RuntimeError("LLM provider returned an empty message content.")
def build_fallback_title(question: str, limit: int = 70) -> str:
    """Derive a short consultation title from free text.

    Whitespace is collapsed and trailing punctuation is removed. Text longer
    than ``limit`` is cut to ``limit - 1`` characters and an ellipsis is
    appended, so the reader can see the title was shortened and the result
    is never longer than ``limit`` for ``limit >= 1``.

    Args:
        question: Source text for the title.
        limit: Maximum title length in characters.

    Returns:
        The title, or a generic Russian placeholder when ``question`` is
        blank.
    """
    title = " ".join(question.split())
    if not title:
        return "Юридическая консультация"
    title = title.rstrip(" .,!?:;")
    if len(title) <= limit:
        return title
    # One character of the budget is reserved for the ellipsis.
    trimmed = title[: limit - 1].rstrip(" .,!?:;")
    return f"{trimmed}…"
def infer_primary_law_type(category: str | None, question: str) -> str:
    """Pick one law-type code for a question.

    The category is consulted first through ``infer_law_types``; if it gives
    nothing, the lowercased question is scanned for ``CATEGORY_MAP`` keys in
    insertion order. The first code of the matching entry is returned, or
    ``"other"`` when neither source matches.
    """
    from_category = infer_law_types(category)
    if from_category:
        return from_category[0]

    lowered_question = question.lower()
    matched = next(
        (
            law_types
            for key, law_types in CATEGORY_MAP.items()
            if key in lowered_question
        ),
        None,
    )
    return "other" if matched is None else matched[0]
def sanitize_answer_text(answer: str) -> str:
    """Remove retrieval jargon from an answer before it is shown to the user.

    Internal terms (``SOURCES``, ``source``, ``chunk``, ``retrieval``) and a
    few Russian phrases that refer to the retrieval context are replaced by
    user-facing wording. Runs of spaces or tabs are then collapsed to one
    space. Newlines are left alone so paragraph breaks in the answer survive.

    Args:
        answer: Raw answer text produced by the model.

    Returns:
        The cleaned text without leading or trailing whitespace.
    """
    sanitized = answer.strip()
    # Applied in order: the later Russian patterns also match text produced
    # by the earlier replacements.
    replacements = (
        (r"(?i)\bSOURCES\b", "нормах закона"),
        (r"(?i)\bsource\b", "нормах закона"),
        (r"(?i)\bchunk(?:s)?\b", "нормах закона"),
        (r"(?i)\bretrieval\b", "поиске норм"),
        (
            r"(?i)в ваших нормах закона",
            "в найденных нормах закона",
        ),
        (
            r"(?i)на основании этих источников",
            "по найденным нормам закона",
        ),
        (
            r"(?i)по этим источникам",
            "по найденным нормам закона",
        ),
        (
            r"(?i)в базе нет",
            "прямого ответа в найденных нормах нет",
        ),
        (
            r"(?i)в контексте нет",
            "в найденных нормах прямо не указано",
        ),
    )
    for pattern, replacement in replacements:
        sanitized = re.sub(pattern, replacement, sanitized)
    # Horizontal whitespace only: "\s{2,}" would also turn a blank line
    # between paragraphs into a single space.
    sanitized = re.sub(r"[^\S\n]{2,}", " ", sanitized)
    return sanitized.strip()
def format_numbered_lines(items: list[str]) -> str:
    """Render items as a ``1. ...`` numbered list, one item per line.

    Empty and whitespace-only items are dropped before numbering; whitespace
    inside each remaining item is collapsed to single spaces.
    """
    cleaned: list[str] = []
    for raw in items:
        if not raw or not raw.strip():
            continue
        cleaned.append(" ".join(raw.split()))

    numbered = []
    for position, text in enumerate(cleaned, start=1):
        numbered.append(f"{position}. {text}")
    return "\n".join(numbered)
def build_sources_section(sources: list[dict], max_lines: int = 5) -> list[str]:
    """Build the de-duplicated citation lines shown under an answer.

    Each source yields one of ``"<title>, ст. <number>. <article title>"``,
    ``"<title>, ст. <number>"`` or ``"<title>"``, depending on which fields
    are present. Sources without a title, and sources repeating an earlier
    (title, article number, article title) triple, are skipped.

    Args:
        sources: Source records with optional ``source_title``,
            ``article_number`` and ``article_title`` keys.
        max_lines: Maximum number of lines to return.

    Returns:
        The citation lines in the order of ``sources``.
    """
    lines: list[str] = []
    seen: set[tuple[str, str, str]] = set()
    for source in sources:
        if len(lines) >= max_lines:
            break
        title = str(source.get("source_title") or "").strip()
        article_number = str(source.get("article_number") or "").strip()
        article_title = str(source.get("article_title") or "").strip()
        key = (title, article_number, article_title)
        if not title or key in seen:
            continue
        seen.add(key)
        if article_number and article_title:
            # Separate number and heading; plain concatenation produced
            # unreadable text such as "ст. 81Расторжение ...".
            lines.append(f"{title}, ст. {article_number}. {article_title}")
        elif article_number:
            lines.append(f"{title}, ст. {article_number}")
        else:
            lines.append(title)
    return lines
def render_structured_initial_answer(
    payload: StructuredInitialAnswer,
    sources: list[dict],
) -> str:
    """Render a structured answer as the user-facing text.

    The text has six sections, each a heading line followed by its body and
    separated from the next by a blank line: conclusion, legal points, action
    steps, risks, found sources and a fixed disclaimer. An empty list section
    is replaced by a single default sentence.
    """
    citation_lines = build_sources_section(sources) or [
        "Подходящие нормы закона по этому вопросу автоматически не выделились."
    ]
    sections = (
        ("⚖️ Краткий вывод", payload.short_conclusion.strip()),
        (
            "📌 Что говорит закон",
            format_numbered_lines(
                payload.legal_points
                or ["В найденных нормах прямой ответ на вопрос не раскрыт."]
            ),
        ),
        (
            "✅ Что можно сделать",
            format_numbered_lines(
                payload.action_steps
                or ["Уточните обстоятельства и проверьте формулировку вопроса."]
            ),
        ),
        (
            "⚠️ Риски и ограничения",
            format_numbered_lines(
                payload.risks
                or ["Ответ зависит от деталей ситуации и содержания применимых норм."]
            ),
        ),
        ("📚 Найденные источники", format_numbered_lines(citation_lines)),
        (
            "❗ Важно",
            "Ответ носит информационный характер и не заменяет консультацию юриста.",
        ),
    )
    # Heading and body share a paragraph; paragraphs are split by a blank line.
    rendered = "\n\n".join(f"{heading}\n{body}" for heading, body in sections)
    return rendered.strip()
def first_sentence(text: str, limit: int = 220) -> str:
    """Return the first sentence of ``text``, shortened to ``limit`` characters.

    Whitespace is collapsed, a leading list number such as ``"3. "`` is
    removed, and whitespace before punctuation is dropped. The sentence ends
    at the first ``.``, ``!`` or ``?`` that is followed by whitespace. A
    sentence longer than ``limit`` is cut to ``limit - 1`` characters and an
    ellipsis is appended so the cut is visible.

    Args:
        text: Source text, for example a law chunk.
        limit: Maximum length of the returned sentence.

    Returns:
        The sentence, or ``""`` when ``text`` is blank.
    """
    normalized = " ".join(text.split())
    normalized = re.sub(r"^\d+\s*\.\s*", "", normalized)
    normalized = re.sub(r"\s+([,.;:!?])", r"\1", normalized)
    if not normalized:
        return ""
    sentence = re.split(r"(?<=[.!?])\s+", normalized, maxsplit=1)[0].strip()
    if len(sentence) <= limit:
        return sentence
    # One character of the budget is reserved for the ellipsis.
    trimmed = sentence[: limit - 1].rstrip(" ,;:")
    return f"{trimmed}…"
def build_structured_answer_fallback(
    *,
    question: str,
    category: str | None,
    sources: list[dict],
) -> StructuredInitialAnswer:
    """Build a template answer from the retrieved sources without the LLM.

    Legal points are taken from the first three sources (article heading plus
    the first sentence of the chunk text). The conclusion, action steps and
    risks come from one of two fixed templates: a criminal-law one, chosen
    when the category mentions it or any source has ``law_type == "criminal"``,
    and a general one otherwise. ``question`` is accepted but not used.
    """

    def _field(source: dict, key: str) -> str:
        return str(source.get(key) or "").strip()

    legal_points: list[str] = []
    for source in sources[:3]:
        number = _field(source, "article_number")
        title = _field(source, "article_title")
        summary = first_sentence(_field(source, "chunk_text"))
        has_heading = bool(number and title)
        if has_heading and summary:
            legal_points.append(f"Статья {number} {title}: {summary}")
        elif has_heading:
            legal_points.append(f"Статья {number} {title}.")
        elif summary:
            legal_points.append(summary)
    if not legal_points:
        legal_points.append(
            "В найденных нормах есть общие ориентиры, но прямой ответ зависит от деталей ситуации."
        )

    is_criminal = "уголов" in (category or "").lower()
    if not is_criminal:
        is_criminal = any(
            str(source.get("law_type") or "") == "criminal" for source in sources
        )

    if not is_criminal:
        return StructuredInitialAnswer(
            short_conclusion=(
                "По найденным нормам можно дать только общий ориентир; "
                "точный вывод зависит от фактических обстоятельств вопроса."
            ),
            legal_points=legal_points,
            action_steps=[
                "Уточните ключевые обстоятельства и формулировку вопроса.",
                "Соберите документы и доказательства, которые относятся к ситуации.",
                "При необходимости получите очную консультацию профильного юриста.",
            ],
            risks=[
                "Ответ может измениться, если появятся новые существенные детали.",
                "Без полного набора обстоятельств правовая оценка будет предварительной.",
            ],
        )

    return StructuredInitialAnswer(
        short_conclusion=(
            "По найденным нормам возможна уголовная ответственность, "
            "но точная квалификация и последствия зависят от обстоятельств дела."
        ),
        legal_points=legal_points,
        action_steps=[
            "Как можно быстрее обратитесь за очной помощью адвоката по уголовным делам.",
            "Соберите и сохраните документы, повестки, протоколы и другие материалы, которые у вас уже есть.",
            "Подготовьте точную хронологию событий, потому что для оценки важны обстоятельства и формулировка обвинения.",
        ],
        risks=[
            "Точная статья и возможное наказание зависят от обстоятельств, мотива, последствий и процессуального статуса.",
            "Без изучения материалов дела нельзя надёжно оценить квалификацию и линию защиты.",
        ],
    )
def build_classification_fallback(
    *,
    question: str,
    category: str | None,
    region: str | None,
) -> ClassificationResult:
    """Build a keyword-based classification used when the LLM one is unavailable.

    The law type comes from ``infer_primary_law_type``. It becomes both the
    legal domain and, unless it is ``"other"``, the only ``law_type`` filter.
    The question itself is the single search query, and no clarification is
    requested.
    """
    law_type = infer_primary_law_type(category, question)
    if law_type == "other":
        filters = {}
    else:
        filters = {"law_type": [law_type]}

    return ClassificationResult(
        legal_domain=law_type,
        issue_type="general_question",
        jurisdiction="RU",
        region=region,
        needs_clarification=False,
        clarification_questions=[],
        search_queries=[question],
        filters=filters,
    )
class LegalAIService:
    """LLM-backed classification, answer generation and consultation titling."""

    def __init__(self, client: AsyncOpenAI, llm_model: str) -> None:
        """Store the chat-completions client and the model name to request."""
        self.client = client
        self.llm_model = llm_model

    @staticmethod
    def _format_history(history: list[dict[str, str]] | None) -> str:
        """Render the last six history items as ``role: content`` lines."""
        lines = [
            f"{item.get('role', 'user')}: {item.get('content', '')}"
            for item in (history or [])[-6:]
        ]
        return "\n".join(lines) if lines else "нет"

    @staticmethod
    def _structure_initial_answer(
        completion,
        *,
        question: str,
        category: str | None,
        sources: list[dict],
    ) -> StructuredInitialAnswer:
        """Turn an initial-answer completion into a structured answer.

        Any unusable reply (empty response, refusal, invalid JSON, schema
        mismatch) results in the template-based fallback answer.
        """
        try:
            raw_answer = extract_message_content(completion, "answer").strip()
            if looks_like_llm_refusal(raw_answer):
                logger.warning(
                    "LLM returned refusal for initial answer, using structured fallback: category=%s question=%s",
                    category,
                    question,
                )
            else:
                payload = extract_json(raw_answer, "answer")
                return StructuredInitialAnswer.model_validate(payload)
        except (RuntimeError, ValueError) as exc:
            # RuntimeError: empty provider response or unparsable JSON;
            # ValueError: payload rejected by model validation.
            logger.warning(
                "LLM initial answer schema response was invalid, using structured fallback: category=%s question=%s error=%s",
                category,
                question,
                exc,
            )
        return build_structured_answer_fallback(
            question=question,
            category=category,
            sources=sources,
        )

    async def classify(
        self,
        question: str,
        category: str | None,
        region: str | None,
        user_type: str | None = None,
        history: list[dict[str, str]] | None = None,
    ) -> ClassificationResult:
        """Classify a question and produce search queries and filters.

        A failed request, an empty provider response or a reply that is not a
        JSON object all lead to ``build_classification_fallback``.

        Args:
            question: The user's question.
            category: Category chosen by the user, if any.
            region: User's region, if any.
            user_type: User type hint, if any.
            history: Earlier consultation messages; the last six are sent.

        Returns:
            The classification built from the LLM reply or from the fallback.
        """
        logger.info(
            "LLM classification started: category=%s region=%s user_type=%s question_length=%s history_items=%s",
            category,
            region,
            user_type,
            len(question),
            len(history or []),
        )
        category_hint = category or "не указана"
        region_hint = region or "не указан"
        user_type_hint = user_type or "не указан"
        history_text = self._format_history(history)
        user_prompt = (
            f"Категория пользователя: {category_hint}\n"
            f"Регион: {region_hint}\n"
            f"Тип пользователя: {user_type_hint}\n"
            f"История консультации:\n{history_text}\n"
            f"Вопрос: {question}\n"
        )
        try:
            completion = await self.client.chat.completions.create(
                model=self.llm_model,
                temperature=0,
                response_format=CLASSIFIER_RESPONSE_FORMAT,
                messages=[
                    {"role": "system", "content": CLASSIFIER_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
            )
        except Exception as exc:
            logger.warning(
                "LLM classification request with schema failed, using heuristic fallback: category=%s question=%s error=%s",
                category,
                question,
                exc,
            )
            return build_classification_fallback(
                question=question,
                category=category,
                region=region,
            )

        try:
            # extract_message_content raises RuntimeError on an empty provider
            # response; it is handled like an invalid payload.
            content = extract_message_content(completion, "classification") or "{}"
            payload = extract_json(content, "classification")
            if not isinstance(payload, dict):
                raise RuntimeError("LLM classification returned a non-object JSON value.")
        except RuntimeError:
            logger.warning(
                "LLM classification schema response was invalid, using heuristic fallback: category=%s question=%s",
                category,
                question,
            )
            return build_classification_fallback(
                question=question,
                category=category,
                region=region,
            )

        search_queries = payload.get("search_queries") or [question]
        raw_filters = payload.get("filters")
        filters = dict(raw_filters) if isinstance(raw_filters, dict) else {}

        # Replace free-form law types with canonical codes, or drop the key
        # when nothing could be recognised.
        normalized_law_types = normalize_law_type_values(filters.get("law_type"))
        if normalized_law_types:
            filters["law_type"] = normalized_law_types
        else:
            filters.pop("law_type", None)

        fallback_law_types = infer_law_types(category)
        if fallback_law_types and not filters.get("law_type"):
            filters["law_type"] = fallback_law_types

        # "or" instead of a .get default: a key present with a null value
        # must fall back to the default as well.
        result = ClassificationResult(
            legal_domain=payload.get("legal_domain") or "other",
            issue_type=payload.get("issue_type") or "general_question",
            jurisdiction=payload.get("jurisdiction") or "RU",
            region=payload.get("region") or region,
            needs_clarification=bool(payload.get("needs_clarification", False)),
            clarification_questions=payload.get("clarification_questions") or [],
            search_queries=search_queries,
            filters=filters,
        )
        logger.info(
            "LLM classification completed: legal_domain=%s issue_type=%s queries=%s needs_clarification=%s",
            result.legal_domain,
            result.issue_type,
            result.search_queries,
            result.needs_clarification,
        )
        return result

    async def answer(
        self,
        question: str,
        category: str | None,
        region: str | None,
        user_type: str | None,
        history: list[dict[str, str]] | None,
        sources: list[dict],
    ) -> str:
        """Generate the answer text for a question from the retrieved sources.

        Without history the model is asked for a schema-constrained reply that
        is rendered into fixed sections; any failure there produces a
        template-based answer. With history the reply is free text that is
        sanitized, and request or response errors propagate to the caller.

        Args:
            question: The user's question.
            category: Category chosen by the user, if any.
            region: User's region, if any.
            user_type: User type hint, if any.
            history: Earlier consultation messages; the last six are sent.
            sources: Retrieved source records, serialized into the prompt.

        Returns:
            The answer text.

        Raises:
            Exception: For follow-up answers, whatever the client raises.
            RuntimeError: For follow-up answers, when the provider response
                has no usable message content.
        """
        logger.info(
            "LLM answer generation started: category=%s region=%s user_type=%s sources=%s question_length=%s history_items=%s",
            category,
            region,
            user_type,
            len(sources),
            len(question),
            len(history or []),
        )
        serialized_sources = json.dumps(sources, ensure_ascii=False, indent=2)
        history_text = self._format_history(history)
        has_consultation_history = bool(history)
        answer_prompt = FOLLOW_UP_ANSWER_PROMPT if has_consultation_history else ANSWER_PROMPT
        user_prompt = (
            f"Категория: {category or 'не указана'}\n"
            f"Регион: {region or 'не указан'}\n"
            f"Тип пользователя: {user_type or 'не указан'}\n"
            f"История консультации:\n{history_text}\n"
            f"Вопрос пользователя: {question}\n\n"
            f"SOURCES:\n{serialized_sources}"
        )

        # Only the initial answer is schema-constrained.
        request_options = {}
        if not has_consultation_history:
            request_options["response_format"] = INITIAL_ANSWER_RESPONSE_FORMAT

        try:
            completion = await self.client.chat.completions.create(
                model=self.llm_model,
                temperature=0.2,
                messages=[
                    {"role": "system", "content": answer_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                **request_options,
            )
        except Exception as exc:
            if has_consultation_history:
                raise
            logger.warning(
                "LLM initial answer request with schema failed, using structured fallback: category=%s question=%s error=%s",
                category,
                question,
                exc,
            )
            structured_answer = build_structured_answer_fallback(
                question=question,
                category=category,
                sources=sources,
            )
            rendered = render_structured_initial_answer(structured_answer, sources)
            logger.info(
                "LLM answer generation completed via fallback: answer_length=%s",
                len(rendered),
            )
            return rendered

        if has_consultation_history:
            raw_answer = extract_message_content(completion, "answer").strip()
            rendered = sanitize_answer_text(raw_answer)
        else:
            structured_answer = self._structure_initial_answer(
                completion,
                question=question,
                category=category,
                sources=sources,
            )
            rendered = render_structured_initial_answer(structured_answer, sources)
        logger.info("LLM answer generation completed: answer_length=%s", len(rendered))
        return rendered

    async def generate_consultation_title(
        self,
        *,
        question: str,
        category: str | None,
        answer: str,
    ) -> str:
        """Ask the LLM for a short consultation title.

        The reply is whitespace-collapsed, stripped of surrounding quotes and
        passed through ``build_fallback_title`` with a 70-character limit.
        Errors from the request or from an empty response are not handled
        here and propagate to the caller.

        Args:
            question: The user's question.
            category: Category chosen by the user, if any.
            answer: Answer text; only its first 1500 characters are sent.

        Returns:
            The title.
        """
        logger.info(
            "LLM consultation title generation started: category=%s question_length=%s answer_length=%s",
            category,
            len(question),
            len(answer),
        )
        user_prompt = (
            f"Категория: {category or 'не указана'}\n"
            f"Вопрос пользователя: {question}\n"
            f"Краткое содержание ответа:\n{answer[:1500]}"
        )
        completion = await self.client.chat.completions.create(
            model=self.llm_model,
            temperature=0,
            messages=[
                {"role": "system", "content": CONSULTATION_TITLE_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
        )
        content = extract_message_content(completion, "consultation_title")
        title = " ".join(content.strip().split()).strip("\"' ")
        title = build_fallback_title(title, limit=70)
        logger.info("LLM consultation title generation completed: title=%s", title)
        return title
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from functools import lru_cache
import logging
from sentence_transformers import SentenceTransformer
from api.config import settings
logger = logging.getLogger(__name__)
class LocalEmbeddingService:
    """Wraps a locally loaded ``SentenceTransformer`` for documents and queries."""

    def __init__(self) -> None:
        """Load the model named in settings and cap its sequence length at 512."""
        model_name = settings.embedding_model
        device = settings.embedding_device
        logger.info(
            "Loading embedding model: model=%s device=%s",
            model_name,
            device,
        )
        model = SentenceTransformer(model_name, device=device)
        model.max_seq_length = 512
        self._model = model
        logger.info(
            "Embedding model loaded: model=%s max_seq_length=%s",
            model_name,
            model.max_seq_length,
        )

    def _encode(self, texts: list[str], prompt_name: str) -> list[list[float]]:
        """Encode texts under the given prompt name into normalized vectors."""
        vectors = self._model.encode(
            texts,
            prompt_name=prompt_name,
            normalize_embeddings=True,
            convert_to_numpy=True,
            batch_size=settings.index_batch_size,
            show_progress_bar=False,
        )
        return vectors.tolist()

    def encode_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed texts that will be stored, using the ``search_document`` prompt."""
        logger.info("Encoding document batch: size=%s", len(texts))
        return self._encode(texts, "search_document")

    def encode_queries(self, texts: list[str]) -> list[list[float]]:
        """Embed search queries, using the ``search_query`` prompt."""
        logger.info("Encoding query batch: size=%s", len(texts))
        return self._encode(texts, "search_query")
@lru_cache(maxsize=1)
def get_embedding_service() -> LocalEmbeddingService:
    """Return the shared ``LocalEmbeddingService``, creating it on first call.

    ``lru_cache`` memoizes this argument-less call, so the constructor (which
    loads the embedding model) runs once and later calls return the same
    instance.
    """
    return LocalEmbeddingService()
+122
View File
@@ -0,0 +1,122 @@
from __future__ import annotations
import logging
from fastapi.concurrency import run_in_threadpool
from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.schemas import ClassificationResult
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM
logger = logging.getLogger(__name__)
def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
return [value]
normalized = [item for item in value if isinstance(item, str) and item.strip()]
return normalized or None
class HybridRetrievalService:
    """Combine full-text and vector search results into one ranked list."""

    # Weight of a full-text hit relative to the rank part of a vector hit.
    _LEXICAL_WEIGHT = 1.2

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        """Store the collaborators used for search and row loading."""
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    @staticmethod
    def _first_query_results(vector_hits: dict, key: str) -> list:
        """Return the result list of the single query, or ``[]`` when absent."""
        per_query = vector_hits.get(key) or [[]]
        return per_query[0] or []

    @staticmethod
    def _vector_score(rank: int, distance) -> float:
        """Score a vector hit from its rank and its distance.

        The score is ``1 / (rank + 1)`` plus ``max(0, 1 - distance)``. A
        missing distance adds no bonus. A distance of exactly ``0.0`` is a
        real value and earns the full bonus, so it must not be tested for
        truthiness.
        """
        similarity_bonus = 0.0
        if distance is not None:
            similarity_bonus = max(0.0, 1.0 - float(distance))
        return 1.0 / (rank + 1) + similarity_bonus

    async def retrieve(
        self,
        classification: ClassificationResult,
        fallback_law_types: list[str] | None,
        top_k: int,
    ) -> list[dict]:
        """Run every search query through both searches and merge the scores.

        For each query, full-text hits add ``1.2 / (rank + 1)`` to their
        chunk's score, and vector hits that pass the law-type filter add
        their vector score. The best ``top_k * 3`` chunk ids are loaded
        through the ORM and the first ``top_k`` rows it returns are kept in
        score order.

        Args:
            classification: Supplies the search queries, the ``law_type``
                filter and the jurisdiction.
            fallback_law_types: Law types used when the classification has
                no ``law_type`` filter.
            top_k: Maximum number of rows to return.

        Returns:
            Rows from the ORM, each with an added ``score`` key.
        """
        queries = classification.search_queries or []
        law_types = normalize_law_types_arg(
            classification.filters.get("law_type")
            or fallback_law_types
            or None
        )
        logger.info(
            "Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s",
            queries,
            law_types,
            classification.jurisdiction,
            top_k,
        )

        merged_scores: dict[int, float] = {}
        for query in queries:
            lexical_hits = await self.orm.search_law_chunks_full_text(
                query=query,
                law_types=law_types,
                jurisdiction=classification.jurisdiction,
                limit=settings.fts_top_k,
            )
            logger.info("Full-text hits for query '%s': %s", query, len(lexical_hits))
            for rank, hit in enumerate(lexical_hits):
                chunk_id = hit["chunk_id"]
                merged_scores[chunk_id] = merged_scores.get(chunk_id, 0.0) + (
                    self._LEXICAL_WEIGHT / (rank + 1)
                )

            query_embedding = await run_in_threadpool(
                self.embedder.encode_queries,
                [query],
            )
            vector_hits = await run_in_threadpool(
                self.vector_store.query,
                query_embedding,
                settings.vector_top_k,
            )
            ids = self._first_query_results(vector_hits, "ids")
            # A result key may be present with a None value; pad so that the
            # ids are still scored by rank.
            distances = self._first_query_results(vector_hits, "distances") or (
                [None] * len(ids)
            )
            metadatas = self._first_query_results(vector_hits, "metadatas") or (
                [None] * len(ids)
            )
            logger.info("Vector hits for query '%s': %s", query, len(ids))
            for rank, (raw_id, distance, metadata) in enumerate(
                zip(ids, distances, metadatas)
            ):
                # The law-type filter is applied to the vector hits here,
                # after the query; rank still counts filtered-out hits.
                if law_types and (metadata or {}).get("law_type") not in law_types:
                    continue
                chunk_id = int(raw_id)
                merged_scores[chunk_id] = merged_scores.get(
                    chunk_id, 0.0
                ) + self._vector_score(rank, distance)

        # Over-fetch candidates: the ORM lookup below filters again and may
        # drop some of them.
        ranked_ids = [
            chunk_id
            for chunk_id, _ in sorted(
                merged_scores.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        ][: max(top_k * 3, top_k)]

        if not ranked_ids:
            logger.info("Hybrid retrieval completed: returned=%s", 0)
            return []

        rows = await self.orm.get_law_chunks_with_sources_by_ids(
            ranked_ids,
            law_types=law_types,
            jurisdiction=classification.jurisdiction,
        )
        by_id = {row["chunk_id"]: row for row in rows}

        results = []
        for chunk_id in ranked_ids:
            if len(results) >= top_k:
                break
            row = by_id.get(chunk_id)
            if row is None:
                continue
            row["score"] = round(merged_scores.get(chunk_id, 0.0), 6)
            results.append(row)
        logger.info("Hybrid retrieval completed: returned=%s", len(results))
        return results