first commit

This commit is contained in:
2026-05-25 01:12:43 +03:00
commit bfc22efe24
83 changed files with 8903 additions and 0 deletions
+35
View File
@@ -0,0 +1,35 @@
# OpenRouter endpoint and API key for LLM calls
OPENAI_BASE_URL=https://openrouter.ai/api/v1
OPENAI_API_KEY=sk-or-v1-replace_me
# Main LLM used for classification and final answer generation
LLM_MODEL=openai/gpt-4.1-mini
# Local embedding model for Chroma indexing/search
EMBEDDING_MODEL=deepvk/USER2-small
EMBEDDING_DEVICE=cpu
# Chroma connection settings
CHROMA_HOST=chromadb
CHROMA_PORT=8000
CHROMA_COLLECTION=law_chunks
CHROMA_SSL=false
# Retrieval settings
RAG_TOP_K=5
FTS_TOP_K=20
VECTOR_TOP_K=20
INDEX_BATCH_SIZE=16
LLM_TIMEOUT_SECONDS=90
# FastAPI bind settings
API_HOST=0.0.0.0
API_PORT=8080
LOG_LEVEL=INFO
# Background indexing on startup
AUTO_INDEX_ON_STARTUP=true
AUTO_INDEX_ONLY_IF_EMPTY=true
AUTO_INDEX_RESET_COLLECTION=false
AUTO_INDEX_RETRY_DELAY_SECONDS=15
AUTO_INDEX_MAX_ATTEMPTS=20
+16
View File
@@ -0,0 +1,16 @@
FROM python:3.12-slim
WORKDIR /app
ENV PYTHONPATH=/app
ENV HF_HOME=/root/.cache/huggingface
ENV CUDA_VISIBLE_DEVICES=""
COPY api/requirements.txt /app/api-requirements.txt
RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu torch==2.7.0 && \
pip install --no-cache-dir -r /app/api-requirements.txt
COPY . /app
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8080"]
+1
View File
@@ -0,0 +1 @@
"""FastAPI RAG service package."""
+1
View File
@@ -0,0 +1 @@
"""API clients."""
+88
View File
@@ -0,0 +1,88 @@
from __future__ import annotations
import logging
from typing import Any
import chromadb
from chromadb.config import Settings as ChromaSettings
from api.config import settings
logger = logging.getLogger(__name__)
class ChromaVectorStore:
def __init__(self) -> None:
self._client = chromadb.HttpClient(
host=settings.chroma_host,
port=settings.chroma_port,
ssl=settings.chroma_ssl,
settings=ChromaSettings(anonymized_telemetry=False),
)
self._collection_name = settings.chroma_collection
logger.info(
"Chroma client configured: host=%s port=%s collection=%s ssl=%s",
settings.chroma_host,
settings.chroma_port,
self._collection_name,
settings.chroma_ssl,
)
@property
def collection_name(self) -> str:
return self._collection_name
def get_collection(self):
return self._client.get_or_create_collection(
name=self._collection_name,
metadata={"hnsw:space": "cosine"},
)
def count(self) -> int:
return int(self.get_collection().count())
def reset_collection(self) -> None:
logger.warning("Resetting Chroma collection: %s", self._collection_name)
try:
self._client.delete_collection(self._collection_name)
except Exception:
logger.exception(
"Could not delete Chroma collection before reset, continuing with create_collection"
)
self.get_collection()
def upsert(
self,
ids: list[str],
documents: list[str],
embeddings: list[list[float]],
metadatas: list[dict[str, Any]],
) -> None:
logger.info(
"Upserting %s embeddings into Chroma collection=%s",
len(ids),
self._collection_name,
)
self.get_collection().upsert(
ids=ids,
documents=documents,
embeddings=embeddings,
metadatas=metadatas,
)
def query(
self,
query_embeddings: list[list[float]],
n_results: int,
) -> dict[str, Any]:
logger.info(
"Running vector query against collection=%s n_results=%s",
self._collection_name,
n_results,
)
return self.get_collection().query(
query_embeddings=query_embeddings,
n_results=n_results,
include=["distances", "metadatas", "documents"],
)
+13
View File
@@ -0,0 +1,13 @@
from __future__ import annotations
from openai import AsyncOpenAI
from api.config import settings
def build_openai_client() -> AsyncOpenAI:
return AsyncOpenAI(
api_key=settings.openai_api_key,
base_url=settings.openai_base_url,
timeout=settings.llm_timeout_seconds,
)
+50
View File
@@ -0,0 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass
from decouple import config
def _to_bool(value: str) -> bool:
return value.strip().lower() in {"1", "true", "yes", "on"}
@dataclass(frozen=True)
class Settings:
app_host: str = config("API_HOST", default="0.0.0.0")
app_port: int = config("API_PORT", cast=int, default=8080)
openai_base_url: str = config(
"OPENAI_BASE_URL", default="https://openrouter.ai/api/v1"
)
openai_api_key: str = config("OPENAI_API_KEY", default="")
llm_model: str = config("LLM_MODEL", default="openai/gpt-4.1-mini")
embedding_model: str = config(
"EMBEDDING_MODEL", default="deepvk/USER2-small"
)
embedding_device: str = config("EMBEDDING_DEVICE", default="cpu")
chroma_host: str = config("CHROMA_HOST", default="chromadb")
chroma_port: int = config("CHROMA_PORT", cast=int, default=8000)
chroma_collection: str = config("CHROMA_COLLECTION", default="law_chunks")
chroma_ssl: bool = _to_bool(config("CHROMA_SSL", default="false"))
rag_top_k: int = config("RAG_TOP_K", cast=int, default=5)
fts_top_k: int = config("FTS_TOP_K", cast=int, default=20)
vector_top_k: int = config("VECTOR_TOP_K", cast=int, default=20)
index_batch_size: int = config("INDEX_BATCH_SIZE", cast=int, default=32)
llm_timeout_seconds: int = config("LLM_TIMEOUT_SECONDS", cast=int, default=90)
log_level: str = config("LOG_LEVEL", default="INFO")
auto_index_on_startup: bool = _to_bool(config("AUTO_INDEX_ON_STARTUP", default="true"))
auto_index_only_if_empty: bool = _to_bool(
config("AUTO_INDEX_ONLY_IF_EMPTY", default="true")
)
auto_index_reset_collection: bool = _to_bool(
config("AUTO_INDEX_RESET_COLLECTION", default="false")
)
auto_index_retry_delay_seconds: int = config(
"AUTO_INDEX_RETRY_DELAY_SECONDS", cast=int, default=15
)
auto_index_max_attempts: int = config(
"AUTO_INDEX_MAX_ATTEMPTS", cast=int, default=20
)
settings = Settings()
+45
View File
@@ -0,0 +1,45 @@
from __future__ import annotations
from functools import lru_cache
from api.clients.chroma_store import ChromaVectorStore
from api.clients.openrouter_client import build_openai_client
from api.config import settings
from api.services.indexing import IndexingService
from api.services.legal_ai import LegalAIService
from api.services.local_embeddings import get_embedding_service
from api.services.retrieval import HybridRetrievalService
from shared import ORM
@lru_cache(maxsize=1)
def get_orm() -> ORM:
return ORM()
@lru_cache(maxsize=1)
def get_vector_store() -> ChromaVectorStore:
return ChromaVectorStore()
@lru_cache(maxsize=1)
def get_ai_service() -> LegalAIService:
return LegalAIService(build_openai_client(), settings.llm_model)
@lru_cache(maxsize=1)
def get_retrieval_service() -> HybridRetrievalService:
return HybridRetrievalService(
orm=get_orm(),
embedder=get_embedding_service(),
vector_store=get_vector_store(),
)
@lru_cache(maxsize=1)
def get_indexing_service() -> IndexingService:
return IndexingService(
orm=get_orm(),
embedder=get_embedding_service(),
vector_store=get_vector_store(),
)
+24
View File
@@ -0,0 +1,24 @@
from __future__ import annotations
import logging
from api.config import settings
def configure_logging() -> None:
logging.basicConfig(
level=getattr(logging, settings.log_level.upper(), logging.INFO),
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("chromadb").setLevel(logging.INFO)
for logger_name in (
"chromadb.telemetry",
"chromadb.telemetry.product",
"chromadb.telemetry.product.posthog",
):
noisy_logger = logging.getLogger(logger_name)
noisy_logger.setLevel(logging.CRITICAL)
noisy_logger.propagate = False
noisy_logger.disabled = True
+148
View File
@@ -0,0 +1,148 @@
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import logging
from time import perf_counter
from fastapi import FastAPI, Request
from api.config import settings
from api.deps import get_indexing_service, get_orm, get_vector_store
from api.logging import configure_logging
from api.routers.health import router as health_router
from api.routers.indexing import router as indexing_router
from api.routers.rag import router as rag_router
configure_logging()
logger = logging.getLogger(__name__)
async def run_startup_indexing_task() -> None:
if not settings.auto_index_on_startup:
logger.info("Startup auto-index is disabled")
return
indexing_service = get_indexing_service()
vector_store = get_vector_store()
for attempt in range(1, settings.auto_index_max_attempts + 1):
try:
current_count = vector_store.count()
db_chunk_count = await indexing_service.get_indexable_chunks_count()
should_reset_collection = settings.auto_index_reset_collection
logger.info(
"Startup auto-index check: attempt=%s/%s chroma_count=%s postgres_chunks=%s",
attempt,
settings.auto_index_max_attempts,
current_count,
db_chunk_count,
)
if db_chunk_count == 0:
logger.warning(
"No indexable chunks found in Postgres yet, retrying in %ss",
settings.auto_index_retry_delay_seconds,
)
await asyncio.sleep(settings.auto_index_retry_delay_seconds)
continue
if settings.auto_index_only_if_empty and current_count == db_chunk_count and current_count > 0:
logger.info(
"Skipping startup auto-index because Chroma is already in sync with Postgres: %s items",
current_count,
)
return
if current_count not in {0, db_chunk_count}:
should_reset_collection = True
logger.warning(
"Chroma/Postgres count mismatch detected, forcing collection reset: chroma_count=%s postgres_chunks=%s",
current_count,
db_chunk_count,
)
result = await indexing_service.rebuild(
reset_collection=should_reset_collection,
)
logger.info("Startup auto-index completed successfully: %s", result)
return
except asyncio.CancelledError:
logger.info("Startup auto-index task cancelled")
raise
except Exception:
logger.exception(
"Startup auto-index attempt %s/%s failed",
attempt,
settings.auto_index_max_attempts,
)
await asyncio.sleep(settings.auto_index_retry_delay_seconds)
logger.error(
"Startup auto-index exhausted all %s attempts",
settings.auto_index_max_attempts,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
orm = get_orm()
startup_index_task: asyncio.Task | None = None
logger.info("API startup initiated")
await orm.init_schema()
logger.info("Database schema is ready")
startup_index_task = asyncio.create_task(run_startup_indexing_task())
app.state.startup_index_task = startup_index_task
yield
if startup_index_task is not None and not startup_index_task.done():
startup_index_task.cancel()
try:
await startup_index_task
except asyncio.CancelledError:
pass
await orm.close()
logger.info("API shutdown completed")
app = FastAPI(title="LawBot RAG API", version="0.1.0", lifespan=lifespan)
@app.middleware("http")
async def log_requests(request: Request, call_next):
started_at = perf_counter()
logger.info(
"HTTP request started: method=%s path=%s client=%s",
request.method,
request.url.path,
request.client.host if request.client else "unknown",
)
try:
response = await call_next(request)
except Exception:
duration_ms = round((perf_counter() - started_at) * 1000, 2)
logger.exception(
"HTTP request failed: method=%s path=%s duration_ms=%s",
request.method,
request.url.path,
duration_ms,
)
raise
duration_ms = round((perf_counter() - started_at) * 1000, 2)
logger.info(
"HTTP request completed: method=%s path=%s status=%s duration_ms=%s",
request.method,
request.url.path,
response.status_code,
duration_ms,
)
return response
app.include_router(health_router)
app.include_router(indexing_router)
app.include_router(rag_router)
+1
View File
@@ -0,0 +1 @@
"""Prompt templates for the RAG API."""
+91
View File
@@ -0,0 +1,91 @@
CLASSIFIER_PROMPT = """Ты классификатор юридических вопросов по законам РФ.
Верни только JSON без markdown.
Поля:
- legal_domain
- issue_type
- jurisdiction
- region
- needs_clarification
- clarification_questions
- search_queries
- filters
Правила:
1. jurisdiction всегда RU.
2. Если данных недостаточно, needs_clarification = true.
3. search_queries должны быть пригодны для поиска по базе законов.
4. Не придумывай статьи.
5. Не давай юридический ответ на этом этапе.
6. filters.law_type заполняй только реальными доменами права, если уверен.
"""
ANSWER_PROMPT = """Ты юридический ИИ-консультант по законам РФ.
Твоя задача — подготовить структурированный ответ пользователю простым языком только на основании переданных норм закона.
Жесткие правила:
1. Используй только переданные фрагменты законов.
2. Не придумывай статьи, номера законов, судебную практику и сроки.
3. Если источников недостаточно, прямо скажи об этом.
4. Не обещай победу в суде.
5. Не выдавай себя за адвоката.
6. Не помогай обходить закон.
7. В конце добавь дисклеймер.
8. Не используй markdown, символы **, __, #, списки через `-` и другое markdown-оформление.
9. Пиши обычным текстом. Для акцентов используй короткие заголовки и нумерованные пункты.
10. Никогда не используй слова SOURCES, source, chunk, retrieval, база, векторный поиск, фрагменты, контекст.
11. Нельзя писать фразы вроде: "в ваших SOURCES", "по этим источникам", "на основании этих источников", "в базе нет", "в контексте нет".
12. Если данных не хватает, говори только по-человечески, например: "По тем нормам, которые удалось найти, прямого ответа на этот нюанс нет" или "В найденных нормах этот частный вопрос прямо не раскрыт".
13. Верни только JSON без markdown и без пояснений.
JSON schema:
{
"short_conclusion": "краткий вывод в 1-3 предложениях",
"legal_points": ["ключевая норма 1", "ключевая норма 2"],
"action_steps": ["практический шаг 1", "практический шаг 2"],
"risks": ["риск или ограничение 1", "риск или ограничение 2"]
}
"""
FOLLOW_UP_ANSWER_PROMPT = """Ты юридический ИИ-консультант по законам РФ.
Твоя задача — продолжить уже начатую консультацию и ответить пользователю простым, живым и естественным языком только на основании переданных норм закона.
Жесткие правила:
1. Используй только переданные нормы закона.
2. Учитывай историю консультации и отвечай именно на последний вопрос пользователя.
3. Не придумывай статьи, номера законов, судебную практику и сроки.
4. Если источников недостаточно, прямо скажи об этом.
5. Не обещай победу в суде.
6. Не выдавай себя за адвоката.
7. Не помогай обходить закон.
8. Не используй markdown и символы **, __, #, списки через `-`.
9. Не используй жесткий шаблон с разделами, если вопрос этого не требует.
10. Если уместно, можешь дать короткий пошаговый план.
11. В конце кратко укажи, на какие нормы ты опираешься, и добавь дисклеймер.
12. Никогда не используй слова SOURCES, source, chunk, retrieval, база, векторный поиск, фрагменты, контекст.
13. Нельзя писать фразы вроде: "в ваших SOURCES", "по этим источникам", "на основании этих источников", "в базе нет", "в контексте нет".
14. Если данных не хватает, формулируй это только естественным языком без упоминания внутренней кухни системы.
Формат ответа:
- свободный, разговорный, но деловой и понятный;
- без лишней воды;
- можно использовать короткие абзацы и списки;
- если вопрос уточняющий, отвечай прямо на него, а не повторяй всю предыдущую структуру заново.
"""
CONSULTATION_TITLE_PROMPT = """Ты помогаешь придумать короткий заголовок для юридической консультации.
Правила:
1. Верни только сам заголовок без кавычек, markdown и пояснений.
2. Заголовок должен быть коротким: 3-8 слов, максимум 70 символов.
3. Заголовок должен ясно отражать суть проблемы пользователя.
4. Не используй даты, обращения, вводные слова и канцеляризмы.
5. Не ставь точку в конце.
6. Пиши по-русски.
"""
+9
View File
@@ -0,0 +1,9 @@
fastapi==0.115.9
uvicorn==0.34.2
openai==1.82.0
chromadb==1.0.12
python-decouple==3.8
SQLAlchemy==2.0.38
asyncpg==0.30.0
pydantic==2.10.6
sentence-transformers==5.1.2
+1
View File
@@ -0,0 +1 @@
"""FastAPI routers."""
+15
View File
@@ -0,0 +1,15 @@
import logging
from fastapi import APIRouter
from api.schemas import HealthResponse
router = APIRouter(tags=["health"])
logger = logging.getLogger(__name__)
@router.get("/health", response_model=HealthResponse)
async def healthcheck() -> HealthResponse:
logger.debug("Healthcheck request")
return HealthResponse(status="ok")
+19
View File
@@ -0,0 +1,19 @@
from fastapi import APIRouter
from api.deps import get_indexing_service
from api.schemas import IndexRequest, IndexResponse
router = APIRouter(prefix="/api/v1/index", tags=["index"])
@router.post("/rebuild", response_model=IndexResponse)
async def rebuild_index(payload: IndexRequest) -> IndexResponse:
service = get_indexing_service()
result = await service.rebuild(
source_ids=payload.source_ids,
law_types=payload.law_types,
reset_collection=payload.reset_collection,
batch_size=payload.batch_size,
)
return IndexResponse(**result)
+185
View File
@@ -0,0 +1,185 @@
import logging
from fastapi import APIRouter, HTTPException
from api.deps import get_ai_service, get_orm, get_retrieval_service
from api.schemas import (
AnswerRequest,
AnswerResponse,
ClassificationResult,
RetrievedChunk,
SearchRequest,
SearchResponse,
)
from api.services.legal_ai import build_fallback_title, infer_law_types
router = APIRouter(prefix="/api/v1/rag", tags=["rag"])
logger = logging.getLogger(__name__)
@router.post("/search", response_model=SearchResponse)
async def search(payload: SearchRequest) -> SearchResponse:
logger.info(
"RAG search request: category=%s region=%s top_k=%s question_length=%s",
payload.category,
payload.region,
payload.top_k,
len(payload.question),
)
ai_service = get_ai_service()
retrieval = get_retrieval_service()
try:
classification = await ai_service.classify(
question=payload.question,
category=payload.category,
region=payload.region,
user_type=payload.user_type,
history=payload.history,
)
chunks = await retrieval.retrieve(
classification=classification,
fallback_law_types=payload.law_types or infer_law_types(payload.category),
top_k=payload.top_k,
)
except RuntimeError as exc:
logger.exception("RAG search failed")
raise HTTPException(status_code=502, detail=str(exc)) from exc
return SearchResponse(
classification=classification,
generated_queries=classification.search_queries,
retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
)
@router.post("/answer", response_model=AnswerResponse)
async def answer(payload: AnswerRequest) -> AnswerResponse:
logger.info(
"RAG answer request: user_id=%s consultation_id=%s save_history=%s category=%s region=%s top_k=%s question_length=%s",
payload.user_id,
payload.consultation_id,
payload.save_history,
payload.category,
payload.region,
payload.top_k,
len(payload.question),
)
ai_service = get_ai_service()
retrieval = get_retrieval_service()
orm = get_orm()
try:
classification = await ai_service.classify(
question=payload.question,
category=payload.category,
region=payload.region,
user_type=payload.user_type,
history=payload.history,
)
chunks = await retrieval.retrieve(
classification=classification,
fallback_law_types=payload.law_types or infer_law_types(payload.category),
top_k=payload.top_k,
)
except RuntimeError as exc:
logger.exception("RAG answer failed on classification/retrieval stage")
raise HTTPException(status_code=502, detail=str(exc)) from exc
if not chunks:
logger.warning("RAG answer request returned no reliable chunks")
raise HTTPException(
status_code=404,
detail="No reliable law chunks were found for this question.",
)
try:
answer_text = await ai_service.answer(
question=payload.question,
category=payload.category,
region=payload.region,
user_type=payload.user_type,
history=payload.history,
sources=chunks,
)
except RuntimeError as exc:
logger.exception("RAG answer failed on generation stage")
raise HTTPException(status_code=502, detail=str(exc)) from exc
consultation_id = payload.consultation_id
user_message_id = None
assistant_message_id = None
if payload.save_history:
if payload.user_id is None:
raise HTTPException(
status_code=400,
detail="user_id is required when save_history=true",
)
user = await orm.get_user(payload.user_id)
if user is None:
raise HTTPException(
status_code=404,
detail="User was not found. Start the bot first so the profile is created.",
)
if consultation_id is not None:
consultation = await orm.get_consultation(
consultation_id=consultation_id,
user_id=payload.user_id,
)
if consultation is None:
raise HTTPException(
status_code=404,
detail="Consultation was not found for this user.",
)
if consultation_id is None:
try:
consultation_title = await ai_service.generate_consultation_title(
question=payload.question,
category=payload.category or classification.legal_domain,
answer=answer_text,
)
except RuntimeError:
logger.exception("Consultation title generation failed, using fallback title")
consultation_title = build_fallback_title(payload.question)
consultation = await orm.create_consultation(
user_id=payload.user_id,
category=payload.category or classification.legal_domain,
title=consultation_title,
region=payload.region or classification.region,
)
consultation_id = consultation.id
user_message = await orm.create_message(
consultation_id=consultation_id,
role="user",
content=payload.question,
)
user_message_id = user_message.id
await orm.create_rag_query(
consultation_id=consultation_id,
user_message_id=user_message_id,
generated_queries=classification.search_queries,
retrieved_chunks=chunks,
)
assistant_message = await orm.create_message(
consultation_id=consultation_id,
role="assistant",
content=answer_text,
sources_json=chunks,
)
assistant_message_id = assistant_message.id
return AnswerResponse(
classification=classification,
generated_queries=classification.search_queries,
retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
answer=answer_text,
consultation_id=consultation_id,
user_message_id=user_message_id,
assistant_message_id=assistant_message_id,
)
+92
View File
@@ -0,0 +1,92 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
class HealthResponse(BaseModel):
status: str
class IndexRequest(BaseModel):
source_ids: list[int] | None = None
law_types: list[str] | None = None
reset_collection: bool = True
batch_size: int | None = None
class IndexResponse(BaseModel):
indexed_chunks: int
indexed_sources: int
collection_name: str
class SearchRequest(BaseModel):
question: str
category: str | None = None
region: str | None = None
user_type: str | None = None
history: list[dict[str, str]] = Field(default_factory=list)
law_types: list[str] | None = None
top_k: int = 5
class RetrievedChunk(BaseModel):
chunk_id: int
source_id: int
source_title: str
source_url: str | None = None
law_type: str | None = None
article_number: str | None = None
article_title: str | None = None
chunk_text: str
metadata: dict[str, Any] = Field(default_factory=dict)
score: float
class ClassificationResult(BaseModel):
legal_domain: str
issue_type: str
jurisdiction: str = "RU"
region: str | None = None
needs_clarification: bool = False
clarification_questions: list[str] = Field(default_factory=list)
search_queries: list[str] = Field(default_factory=list)
filters: dict[str, Any] = Field(default_factory=dict)
class StructuredInitialAnswer(BaseModel):
short_conclusion: str
legal_points: list[str] = Field(default_factory=list)
action_steps: list[str] = Field(default_factory=list)
risks: list[str] = Field(default_factory=list)
class SearchResponse(BaseModel):
classification: ClassificationResult
generated_queries: list[str]
retrieved_chunks: list[RetrievedChunk]
class AnswerRequest(BaseModel):
user_id: int | None = None
consultation_id: int | None = None
save_history: bool = False
question: str
category: str | None = None
region: str | None = None
user_type: str | None = None
history: list[dict[str, str]] = Field(default_factory=list)
law_types: list[str] | None = None
top_k: int = 5
class AnswerResponse(BaseModel):
classification: ClassificationResult
generated_queries: list[str]
retrieved_chunks: list[RetrievedChunk]
answer: str
consultation_id: int | None = None
user_message_id: int | None = None
assistant_message_id: int | None = None
+1
View File
@@ -0,0 +1 @@
"""Service layer for the RAG API."""
+126
View File
@@ -0,0 +1,126 @@
from __future__ import annotations
import logging
from fastapi.concurrency import run_in_threadpool
from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM
logger = logging.getLogger(__name__)
def build_embedding_text(chunk: dict) -> str:
article_title = chunk.get("article_title") or ""
source_title = chunk.get("source_title") or ""
return "\n".join(
part
for part in [
f"source: {source_title}",
f"article: {article_title}",
f"text: {chunk['chunk_text']}",
]
if part.strip()
)
class IndexingService:
def __init__(
self,
orm: ORM,
embedder: LocalEmbeddingService,
vector_store: ChromaVectorStore,
) -> None:
self.orm = orm
self.embedder = embedder
self.vector_store = vector_store
async def get_indexable_chunks_count(
self,
source_ids: list[int] | None = None,
law_types: list[str] | None = None,
) -> int:
chunks = await self.orm.list_chunks_for_indexing(
source_ids=source_ids,
law_types=law_types,
active_only=True,
)
return len(chunks)
async def rebuild(
self,
source_ids: list[int] | None = None,
law_types: list[str] | None = None,
reset_collection: bool = True,
batch_size: int | None = None,
) -> dict:
logger.info(
"Index rebuild started: reset_collection=%s source_ids=%s law_types=%s",
reset_collection,
source_ids,
law_types,
)
chunks = await self.orm.list_chunks_for_indexing(
source_ids=source_ids,
law_types=law_types,
active_only=True,
)
logger.info("Loaded %s chunks from Postgres for indexing", len(chunks))
if reset_collection:
await run_in_threadpool(self.vector_store.reset_collection)
batch_size = batch_size or settings.index_batch_size
indexed_chunks = 0
indexed_sources = len({chunk["source_id"] for chunk in chunks})
for start in range(0, len(chunks), batch_size):
batch = chunks[start : start + batch_size]
batch_number = (start // batch_size) + 1
total_batches = max(1, (len(chunks) + batch_size - 1) // batch_size)
logger.info(
"Indexing batch %s/%s with %s chunks",
batch_number,
total_batches,
len(batch),
)
embeddings = await run_in_threadpool(
self.embedder.encode_documents,
[build_embedding_text(chunk) for chunk in batch],
)
await run_in_threadpool(
self.vector_store.upsert,
[str(chunk["chunk_id"]) for chunk in batch],
[chunk["chunk_text"] for chunk in batch],
embeddings,
[
{
"chunk_id": chunk["chunk_id"],
"source_id": chunk["source_id"],
"source_title": chunk["source_title"],
"source_url": chunk["source_url"],
"law_type": chunk["law_type"] or "",
"jurisdiction": chunk["jurisdiction"],
"article_number": chunk["article_number"] or "",
"article_title": chunk["article_title"] or "",
"version_hash": chunk["version_hash"],
}
for chunk in batch
],
)
indexed_chunks += len(batch)
logger.info(
"Indexed %s/%s chunks into Chroma",
indexed_chunks,
len(chunks),
)
result = {
"indexed_chunks": indexed_chunks,
"indexed_sources": indexed_sources,
"collection_name": self.vector_store.collection_name,
}
logger.info("Index rebuild completed: %s", result)
return result
+724
View File
@@ -0,0 +1,724 @@
from __future__ import annotations
import json
import logging
import re
from openai import AsyncOpenAI
from api.prompts.rag_prompts import (
ANSWER_PROMPT,
CLASSIFIER_PROMPT,
CONSULTATION_TITLE_PROMPT,
FOLLOW_UP_ANSWER_PROMPT,
)
from api.schemas import ClassificationResult, StructuredInitialAnswer
logger = logging.getLogger(__name__)
CATEGORY_MAP = {
"работа": ["labor"],
"труд": ["labor"],
"защита прав потребителей": ["consumer", "civil"],
"потребител": ["consumer", "civil"],
"жилье": ["housing", "civil", "mortgage"],
"аренда": ["housing", "civil"],
"семья": ["family"],
"долги": ["civil", "enforcement"],
"займы": ["civil"],
"договоры": ["civil"],
"договор": ["civil"],
"суд": ["procedural"],
"процесс": ["procedural"],
"административ": ["administrative"],
"уголов": ["criminal"],
"краж": ["criminal"],
"мошеннич": ["criminal"],
}
LAW_TYPE_ALIASES = {
"labor": "labor",
"труд": "labor",
"трудовое право": "labor",
"criminal": "criminal",
"уголов": "criminal",
"civil": "civil",
"граждан": "civil",
"договор": "civil",
"consumer": "consumer",
"защита прав потребителей": "consumer",
"потребител": "consumer",
"housing": "housing",
"жилищ": "housing",
"аренда": "housing",
"family": "family",
"семейн": "family",
"procedural": "procedural",
"процесс": "procedural",
"суд": "procedural",
"administrative": "administrative",
"административ": "administrative",
"enforcement": "enforcement",
"исполнительн": "enforcement",
"mortgage": "mortgage",
"ипотек": "mortgage",
}
INITIAL_ANSWER_RESPONSE_FORMAT = {
"type": "json_schema",
"json_schema": {
"name": "lawbot_initial_answer",
"strict": True,
"schema": {
"type": "object",
"additionalProperties": False,
"properties": {
"short_conclusion": {"type": "string"},
"legal_points": {
"type": "array",
"items": {"type": "string"},
},
"action_steps": {
"type": "array",
"items": {"type": "string"},
},
"risks": {
"type": "array",
"items": {"type": "string"},
},
},
"required": [
"short_conclusion",
"legal_points",
"action_steps",
"risks",
],
},
},
}
CLASSIFIER_RESPONSE_FORMAT = {
"type": "json_schema",
"json_schema": {
"name": "lawbot_classifier",
"strict": True,
"schema": {
"type": "object",
"additionalProperties": False,
"properties": {
"legal_domain": {"type": "string"},
"issue_type": {"type": "string"},
"jurisdiction": {"type": "string"},
"region": {
"type": ["string", "null"],
},
"needs_clarification": {"type": "boolean"},
"clarification_questions": {
"type": "array",
"items": {"type": "string"},
},
"search_queries": {
"type": "array",
"items": {"type": "string"},
},
"filters": {
"type": "object",
"additionalProperties": False,
"properties": {
"law_type": {
"type": ["array", "null"],
"items": {"type": "string"},
},
},
},
},
"required": [
"legal_domain",
"issue_type",
"jurisdiction",
"region",
"needs_clarification",
"clarification_questions",
"search_queries",
"filters",
],
},
},
}
def extract_json(content: str, purpose: str = "response") -> dict:
try:
return json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, re.S)
if not match:
logger.error("LLM %s returned non-JSON content: %s", purpose, content)
raise RuntimeError(f"LLM {purpose} returned invalid JSON.")
try:
return json.loads(match.group(0))
except json.JSONDecodeError as exc:
logger.error("LLM %s returned malformed JSON fragment: %s", purpose, content)
raise RuntimeError(f"LLM {purpose} returned malformed JSON.") from exc
def looks_like_llm_refusal(content: str) -> bool:
normalized = " ".join(content.lower().split())
refusal_markers = (
"i cannot assist",
"i can't assist",
"i cannot help",
"i'm sorry, but i cannot",
"не могу помочь с этим",
"не могу помочь в этом",
"не могу содействовать",
"не могу помочь с запросом",
"не могу ответить на этот запрос",
)
return any(marker in normalized for marker in refusal_markers)
def infer_law_types(category: str | None) -> list[str] | None:
if not category:
return None
normalized = category.lower().strip()
for key, law_types in CATEGORY_MAP.items():
if key in normalized:
return law_types
return None
def normalize_law_type_values(value) -> list[str] | None:
if value is None:
return None
raw_values = value if isinstance(value, list) else [value]
normalized_values: list[str] = []
for raw_value in raw_values:
if not isinstance(raw_value, str):
continue
raw_normalized = raw_value.strip().lower()
for alias, code in LAW_TYPE_ALIASES.items():
if alias in raw_normalized:
if code not in normalized_values:
normalized_values.append(code)
break
return normalized_values or None
def extract_message_content(completion, purpose: str) -> str:
choices = getattr(completion, "choices", None)
if not choices:
logger.error(
"LLM %s returned empty choices: model=%s id=%s usage=%s raw=%s",
purpose,
getattr(completion, "model", None),
getattr(completion, "id", None),
getattr(completion, "usage", None),
completion,
)
raise RuntimeError(
"LLM provider returned an empty response. Check OPENROUTER model name and provider response."
)
first_choice = choices[0]
message = getattr(first_choice, "message", None)
if message is None:
logger.error(
"LLM %s returned choice without message: model=%s id=%s choice=%s",
purpose,
getattr(completion, "model", None),
getattr(completion, "id", None),
first_choice,
)
raise RuntimeError("LLM provider returned a malformed response without message.")
content = getattr(message, "content", None)
if content is None:
logger.error(
"LLM %s returned empty message content: model=%s id=%s finish_reason=%s message=%s",
purpose,
getattr(completion, "model", None),
getattr(completion, "id", None),
getattr(first_choice, "finish_reason", None),
message,
)
raise RuntimeError("LLM provider returned an empty message content.")
return content
def build_fallback_title(question: str, limit: int = 70) -> str:
title = " ".join(question.strip().split())
if not title:
return "Юридическая консультация"
title = title.rstrip(" .,!?:;")
if len(title) <= limit:
return title
trimmed = title[: limit - 1].rstrip(" .,!?:;")
return f"{trimmed}"
def infer_primary_law_type(category: str | None, question: str) -> str:
inferred = infer_law_types(category)
if inferred:
return inferred[0]
normalized_question = question.lower()
for key, law_types in CATEGORY_MAP.items():
if key in normalized_question:
return law_types[0]
return "other"
def sanitize_answer_text(answer: str) -> str:
sanitized = answer.strip()
replacements = (
(r"(?i)\bSOURCES\b", "нормах закона"),
(r"(?i)\bsource\b", "нормах закона"),
(r"(?i)\bchunk(?:s)?\b", "нормах закона"),
(r"(?i)\bretrieval\b", "поиске норм"),
(
r"(?i)в ваших нормах закона",
"в найденных нормах закона",
),
(
r"(?i)на основании этих источников",
"по найденным нормам закона",
),
(
r"(?i)по этим источникам",
"по найденным нормам закона",
),
(
r"(?i)в базе нет",
"прямого ответа в найденных нормах нет",
),
(
r"(?i)в контексте нет",
"в найденных нормах прямо не указано",
),
)
for pattern, replacement in replacements:
sanitized = re.sub(pattern, replacement, sanitized)
sanitized = re.sub(r"\s{2,}", " ", sanitized)
return sanitized.strip()
def format_numbered_lines(items: list[str]) -> str:
normalized = [" ".join(item.strip().split()) for item in items if item and item.strip()]
return "\n".join(f"{index}. {item}" for index, item in enumerate(normalized, start=1))
def build_sources_section(sources: list[dict]) -> list[str]:
lines: list[str] = []
seen: set[tuple[str, str, str]] = set()
for source in sources:
title = str(source.get("source_title") or "").strip()
article_number = str(source.get("article_number") or "").strip()
article_title = str(source.get("article_title") or "").strip()
key = (title, article_number, article_title)
if not title or key in seen:
continue
seen.add(key)
if article_number and article_title:
lines.append(f"{title}, ст. {article_number}{article_title}")
elif article_number:
lines.append(f"{title}, ст. {article_number}")
else:
lines.append(title)
if len(lines) >= 5:
break
return lines
def render_structured_initial_answer(
payload: StructuredInitialAnswer,
sources: list[dict],
) -> str:
legal_points = payload.legal_points or ["В найденных нормах прямой ответ на вопрос не раскрыт."]
action_steps = payload.action_steps or ["Уточните обстоятельства и проверьте формулировку вопроса."]
risks = payload.risks or ["Ответ зависит от деталей ситуации и содержания применимых норм."]
source_lines = build_sources_section(sources)
if not source_lines:
source_lines = ["Подходящие нормы закона по этому вопросу автоматически не выделились."]
parts = [
"⚖️ Краткий вывод",
payload.short_conclusion.strip(),
"",
"📌 Что говорит закон",
format_numbered_lines(legal_points),
"",
"✅ Что можно сделать",
format_numbered_lines(action_steps),
"",
"⚠️ Риски и ограничения",
format_numbered_lines(risks),
"",
"📚 Найденные источники",
format_numbered_lines(source_lines),
"",
"❗ Важно",
"Ответ носит информационный характер и не заменяет консультацию юриста.",
]
return "\n".join(parts).strip()
def first_sentence(text: str, limit: int = 220) -> str:
normalized = " ".join(text.split())
normalized = re.sub(r"^\d+\s*\.\s*", "", normalized)
normalized = re.sub(r"\s+([,.;:!?])", r"\1", normalized)
if not normalized:
return ""
match = re.split(r"(?<=[.!?])\s+", normalized, maxsplit=1)
sentence = match[0].strip()
if len(sentence) <= limit:
return sentence
trimmed = sentence[: limit - 1].rstrip(" ,;:")
return f"{trimmed}"
def build_structured_answer_fallback(
*,
question: str,
category: str | None,
sources: list[dict],
) -> StructuredInitialAnswer:
legal_points: list[str] = []
for source in sources[:3]:
article_number = str(source.get("article_number") or "").strip()
article_title = str(source.get("article_title") or "").strip()
chunk_text = str(source.get("chunk_text") or "").strip()
summary = first_sentence(chunk_text)
if article_number and article_title and summary:
legal_points.append(f"Статья {article_number} {article_title}: {summary}")
elif article_number and article_title:
legal_points.append(f"Статья {article_number} {article_title}.")
elif summary:
legal_points.append(summary)
if not legal_points:
legal_points.append("В найденных нормах есть общие ориентиры, но прямой ответ зависит от деталей ситуации.")
category_hint = (category or "").lower()
is_criminal = "уголов" in category_hint or any(
str(source.get("law_type") or "") == "criminal" for source in sources
)
if is_criminal:
short_conclusion = (
"По найденным нормам возможна уголовная ответственность, "
"но точная квалификация и последствия зависят от обстоятельств дела."
)
action_steps = [
"Как можно быстрее обратитесь за очной помощью адвоката по уголовным делам.",
"Соберите и сохраните документы, повестки, протоколы и другие материалы, которые у вас уже есть.",
"Подготовьте точную хронологию событий, потому что для оценки важны обстоятельства и формулировка обвинения.",
]
risks = [
"Точная статья и возможное наказание зависят от обстоятельств, мотива, последствий и процессуального статуса.",
"Без изучения материалов дела нельзя надёжно оценить квалификацию и линию защиты.",
]
else:
short_conclusion = (
"По найденным нормам можно дать только общий ориентир; "
"точный вывод зависит от фактических обстоятельств вопроса."
)
action_steps = [
"Уточните ключевые обстоятельства и формулировку вопроса.",
"Соберите документы и доказательства, которые относятся к ситуации.",
"При необходимости получите очную консультацию профильного юриста.",
]
risks = [
"Ответ может измениться, если появятся новые существенные детали.",
"Без полного набора обстоятельств правовая оценка будет предварительной.",
]
return StructuredInitialAnswer(
short_conclusion=short_conclusion,
legal_points=legal_points,
action_steps=action_steps,
risks=risks,
)
def build_classification_fallback(
*,
question: str,
category: str | None,
region: str | None,
) -> ClassificationResult:
primary_law_type = infer_primary_law_type(category, question)
filters = {"law_type": [primary_law_type]} if primary_law_type != "other" else {}
return ClassificationResult(
legal_domain=primary_law_type,
issue_type="general_question",
jurisdiction="RU",
region=region,
needs_clarification=False,
clarification_questions=[],
search_queries=[question],
filters=filters,
)
class LegalAIService:
def __init__(self, client: AsyncOpenAI, llm_model: str) -> None:
self.client = client
self.llm_model = llm_model
async def classify(
self,
question: str,
category: str | None,
region: str | None,
user_type: str | None = None,
history: list[dict[str, str]] | None = None,
) -> ClassificationResult:
logger.info(
"LLM classification started: category=%s region=%s user_type=%s question_length=%s history_items=%s",
category,
region,
user_type,
len(question),
len(history or []),
)
category_hint = category or "не указана"
region_hint = region or "не указан"
user_type_hint = user_type or "не указан"
history_lines = []
for item in (history or [])[-6:]:
role = item.get("role", "user")
content = item.get("content", "")
history_lines.append(f"{role}: {content}")
history_text = "\n".join(history_lines) if history_lines else "нет"
user_prompt = (
f"Категория пользователя: {category_hint}\n"
f"Регион: {region_hint}\n"
f"Тип пользователя: {user_type_hint}\n"
f"История консультации:\n{history_text}\n"
f"Вопрос: {question}\n"
)
try:
completion = await self.client.chat.completions.create(
model=self.llm_model,
temperature=0,
response_format=CLASSIFIER_RESPONSE_FORMAT,
messages=[
{"role": "system", "content": CLASSIFIER_PROMPT},
{"role": "user", "content": user_prompt},
],
)
except Exception as exc:
logger.warning(
"LLM classification request with schema failed, using heuristic fallback: category=%s question=%s error=%s",
category,
question,
exc,
)
return build_classification_fallback(
question=question,
category=category,
region=region,
)
content = extract_message_content(completion, "classification") or "{}"
try:
payload = extract_json(content, "classification")
except RuntimeError:
logger.warning(
"LLM classification schema response was invalid, using heuristic fallback: category=%s question=%s",
category,
question,
)
return build_classification_fallback(
question=question,
category=category,
region=region,
)
search_queries = payload.get("search_queries") or [question]
filters = payload.get("filters") or {}
normalized_law_types = normalize_law_type_values(filters.get("law_type"))
if "law_type" in filters:
if normalized_law_types:
filters["law_type"] = normalized_law_types
else:
filters.pop("law_type", None)
fallback_law_types = infer_law_types(category)
if fallback_law_types and not filters.get("law_type"):
filters["law_type"] = fallback_law_types
result = ClassificationResult(
legal_domain=payload.get("legal_domain", "other"),
issue_type=payload.get("issue_type", "general_question"),
jurisdiction=payload.get("jurisdiction", "RU"),
region=payload.get("region") or region,
needs_clarification=bool(payload.get("needs_clarification", False)),
clarification_questions=payload.get("clarification_questions", []),
search_queries=search_queries,
filters=filters,
)
logger.info(
"LLM classification completed: legal_domain=%s issue_type=%s queries=%s needs_clarification=%s",
result.legal_domain,
result.issue_type,
result.search_queries,
result.needs_clarification,
)
return result
async def answer(
self,
question: str,
category: str | None,
region: str | None,
user_type: str | None,
history: list[dict[str, str]] | None,
sources: list[dict],
) -> str:
logger.info(
"LLM answer generation started: category=%s region=%s user_type=%s sources=%s question_length=%s history_items=%s",
category,
region,
user_type,
len(sources),
len(question),
len(history or []),
)
serialized_sources = json.dumps(sources, ensure_ascii=False, indent=2)
history_lines = []
for item in (history or [])[-6:]:
role = item.get("role", "user")
content = item.get("content", "")
history_lines.append(f"{role}: {content}")
history_text = "\n".join(history_lines) if history_lines else "нет"
has_consultation_history = bool(history)
answer_prompt = FOLLOW_UP_ANSWER_PROMPT if has_consultation_history else ANSWER_PROMPT
user_prompt = (
f"Категория: {category or 'не указана'}\n"
f"Регион: {region or 'не указан'}\n"
f"Тип пользователя: {user_type or 'не указан'}\n"
f"История консультации:\n{history_text}\n"
f"Вопрос пользователя: {question}\n\n"
f"SOURCES:\n{serialized_sources}"
)
try:
if has_consultation_history:
completion = await self.client.chat.completions.create(
model=self.llm_model,
temperature=0.2,
messages=[
{"role": "system", "content": answer_prompt},
{"role": "user", "content": user_prompt},
],
)
else:
completion = await self.client.chat.completions.create(
model=self.llm_model,
temperature=0.2,
response_format=INITIAL_ANSWER_RESPONSE_FORMAT,
messages=[
{"role": "system", "content": answer_prompt},
{"role": "user", "content": user_prompt},
],
)
except Exception as exc:
if has_consultation_history:
raise
logger.warning(
"LLM initial answer request with schema failed, using structured fallback: category=%s question=%s error=%s",
category,
question,
exc,
)
structured_answer = build_structured_answer_fallback(
question=question,
category=category,
sources=sources,
)
answer = render_structured_initial_answer(structured_answer, sources)
logger.info("LLM answer generation completed via fallback: answer_length=%s", len(answer))
return answer
raw_answer = extract_message_content(completion, "answer").strip()
if has_consultation_history:
answer = sanitize_answer_text(raw_answer)
else:
if looks_like_llm_refusal(raw_answer):
logger.warning(
"LLM returned refusal for initial answer, using structured fallback: category=%s question=%s",
category,
question,
)
structured_answer = build_structured_answer_fallback(
question=question,
category=category,
sources=sources,
)
else:
try:
payload = extract_json(raw_answer, "answer")
structured_answer = StructuredInitialAnswer.model_validate(payload)
except (RuntimeError, ValueError) as exc:
logger.warning(
"LLM initial answer schema response was invalid, using structured fallback: category=%s question=%s error=%s",
category,
question,
exc,
)
structured_answer = build_structured_answer_fallback(
question=question,
category=category,
sources=sources,
)
answer = render_structured_initial_answer(structured_answer, sources)
logger.info("LLM answer generation completed: answer_length=%s", len(answer))
return answer
async def generate_consultation_title(
self,
*,
question: str,
category: str | None,
answer: str,
) -> str:
logger.info(
"LLM consultation title generation started: category=%s question_length=%s answer_length=%s",
category,
len(question),
len(answer),
)
user_prompt = (
f"Категория: {category or 'не указана'}\n"
f"Вопрос пользователя: {question}\n"
f"Краткое содержание ответа:\n{answer[:1500]}"
)
completion = await self.client.chat.completions.create(
model=self.llm_model,
temperature=0,
messages=[
{"role": "system", "content": CONSULTATION_TITLE_PROMPT},
{"role": "user", "content": user_prompt},
],
)
content = extract_message_content(completion, "consultation_title")
title = " ".join(content.strip().split()).strip("\"' ")
title = build_fallback_title(title, limit=70)
logger.info("LLM consultation title generation completed: title=%s", title)
return title
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
from functools import lru_cache
import logging
from sentence_transformers import SentenceTransformer
from api.config import settings
logger = logging.getLogger(__name__)
class LocalEmbeddingService:
def __init__(self) -> None:
logger.info(
"Loading embedding model: model=%s device=%s",
settings.embedding_model,
settings.embedding_device,
)
self._model = SentenceTransformer(
settings.embedding_model,
device=settings.embedding_device,
)
self._model.max_seq_length = 512
logger.info(
"Embedding model loaded: model=%s max_seq_length=%s",
settings.embedding_model,
self._model.max_seq_length,
)
def encode_documents(self, texts: list[str]) -> list[list[float]]:
logger.info("Encoding document batch: size=%s", len(texts))
return self._model.encode(
texts,
prompt_name="search_document",
normalize_embeddings=True,
convert_to_numpy=True,
batch_size=settings.index_batch_size,
show_progress_bar=False,
).tolist()
def encode_queries(self, texts: list[str]) -> list[list[float]]:
logger.info("Encoding query batch: size=%s", len(texts))
return self._model.encode(
texts,
prompt_name="search_query",
normalize_embeddings=True,
convert_to_numpy=True,
batch_size=settings.index_batch_size,
show_progress_bar=False,
).tolist()
@lru_cache(maxsize=1)
def get_embedding_service() -> LocalEmbeddingService:
return LocalEmbeddingService()
+122
View File
@@ -0,0 +1,122 @@
from __future__ import annotations
import logging
from fastapi.concurrency import run_in_threadpool
from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.schemas import ClassificationResult
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM
logger = logging.getLogger(__name__)
def normalize_law_types_arg(value: list[str] | str | None) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
return [value]
normalized = [item for item in value if isinstance(item, str) and item.strip()]
return normalized or None
class HybridRetrievalService:
def __init__(
self,
orm: ORM,
embedder: LocalEmbeddingService,
vector_store: ChromaVectorStore,
) -> None:
self.orm = orm
self.embedder = embedder
self.vector_store = vector_store
async def retrieve(
self,
classification: ClassificationResult,
fallback_law_types: list[str] | None,
top_k: int,
) -> list[dict]:
queries = classification.search_queries or []
law_types = normalize_law_types_arg(
classification.filters.get("law_type")
or fallback_law_types
or None
)
logger.info(
"Hybrid retrieval started: queries=%s law_types=%s jurisdiction=%s top_k=%s",
queries,
law_types,
classification.jurisdiction,
top_k,
)
merged_scores: dict[int, float] = {}
for query in queries:
lexical_hits = await self.orm.search_law_chunks_full_text(
query=query,
law_types=law_types,
jurisdiction=classification.jurisdiction,
limit=settings.fts_top_k,
)
logger.info("Full-text hits for query '%s': %s", query, len(lexical_hits))
for rank, hit in enumerate(lexical_hits):
merged_scores[hit["chunk_id"]] = merged_scores.get(hit["chunk_id"], 0.0) + (
1.2 / (rank + 1)
)
query_embedding = await run_in_threadpool(
self.embedder.encode_queries,
[query],
)
vector_hits = await run_in_threadpool(
self.vector_store.query,
query_embedding,
settings.vector_top_k,
)
ids = vector_hits.get("ids", [[]])[0]
distances = vector_hits.get("distances", [[]])[0]
metadatas = vector_hits.get("metadatas", [[]])[0]
logger.info("Vector hits for query '%s': %s", query, len(ids))
for rank, (chunk_id, distance, metadata) in enumerate(
zip(ids, distances, metadatas)
):
if law_types and metadata.get("law_type") not in law_types:
continue
score = 1.0 / (rank + 1)
score += max(0.0, 1.0 - float(distance or 1.0))
merged_scores[int(chunk_id)] = merged_scores.get(int(chunk_id), 0.0) + score
ranked_ids = [
chunk_id
for chunk_id, _ in sorted(
merged_scores.items(),
key=lambda item: item[1],
reverse=True,
)
][: max(top_k * 3, top_k)]
rows = await self.orm.get_law_chunks_with_sources_by_ids(
ranked_ids,
law_types=law_types,
jurisdiction=classification.jurisdiction,
)
by_id = {row["chunk_id"]: row for row in rows}
results = []
for chunk_id in ranked_ids:
row = by_id.get(chunk_id)
if row is None:
continue
row["score"] = round(merged_scores.get(chunk_id, 0.0), 6)
results.append(row)
if len(results) >= top_k:
break
logger.info("Hybrid retrieval completed: returned=%s", len(results))
return results