Files
LawBot/api/routers/rag.py
T
2026-05-25 01:12:43 +03:00

186 lines
6.3 KiB
Python

import logging
from fastapi import APIRouter, HTTPException
from api.deps import get_ai_service, get_orm, get_retrieval_service
from api.schemas import (
AnswerRequest,
AnswerResponse,
ClassificationResult,
RetrievedChunk,
SearchRequest,
SearchResponse,
)
from api.services.legal_ai import build_fallback_title, infer_law_types
# Logger named after this module's import path.
logger = logging.getLogger(__name__)

# Every route declared below gets the "/api/v1/rag" path prefix and the "rag" tag.
router = APIRouter(tags=["rag"], prefix="/api/v1/rag")
@router.post("/search", response_model=SearchResponse)
async def search(payload: SearchRequest) -> SearchResponse:
    """Classify a question and return the law chunks retrieved for it.

    This endpoint does not generate an answer and does not touch the ORM.
    It only runs the classification and retrieval stages.

    Args:
        payload: The search request (question, optional category/region,
            user type, history, explicit law types and ``top_k``).

    Returns:
        A ``SearchResponse`` containing the classification, the search
        queries it produced and the retrieved chunks.

    Raises:
        HTTPException: 502 when classification or retrieval raises
            ``RuntimeError``; the exception text becomes the detail.
    """
    question = payload.question
    logger.info(
        "RAG search request: category=%s region=%s top_k=%s question_length=%s",
        payload.category,
        payload.region,
        payload.top_k,
        len(question),
    )
    classifier = get_ai_service()
    retriever = get_retrieval_service()
    try:
        classification = await classifier.classify(
            question=question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
        )
        # Explicit law types from the request win; otherwise derive them from the category.
        law_types = payload.law_types or infer_law_types(payload.category)
        chunks = await retriever.retrieve(
            classification=classification,
            fallback_law_types=law_types,
            top_k=payload.top_k,
        )
    except RuntimeError as err:
        logger.exception("RAG search failed")
        raise HTTPException(status_code=502, detail=str(err)) from err

    queries = classification.search_queries
    retrieved = [RetrievedChunk(**raw_chunk) for raw_chunk in chunks]
    return SearchResponse(
        classification=classification,
        generated_queries=queries,
        retrieved_chunks=retrieved,
    )
async def _ensure_history_target(orm, payload: AnswerRequest) -> None:
    """Check that an answer for *payload* can be stored in history.

    Runs before any AI call, so a request that could never be saved is
    rejected without spending classification, retrieval or generation
    work on it.

    Args:
        orm: The object returned by ``get_orm()``.
        payload: The answer request; only called when ``save_history`` is set.

    Raises:
        HTTPException: 400 when ``user_id`` is missing; 404 when the user
            does not exist, or when ``consultation_id`` is given but no
            consultation with that id is found for the user.
    """
    if payload.user_id is None:
        raise HTTPException(
            status_code=400,
            detail="user_id is required when save_history=true",
        )
    user = await orm.get_user(payload.user_id)
    if user is None:
        raise HTTPException(
            status_code=404,
            detail="User was not found. Start the bot first so the profile is created.",
        )
    if payload.consultation_id is not None:
        consultation = await orm.get_consultation(
            consultation_id=payload.consultation_id,
            user_id=payload.user_id,
        )
        if consultation is None:
            raise HTTPException(
                status_code=404,
                detail="Consultation was not found for this user.",
            )


async def _save_history(orm, ai_service, payload: AnswerRequest, classification, chunks, answer_text):
    """Store the question, the RAG query and the answer.

    If ``payload.consultation_id`` is None, a new consultation is created
    first. Its title comes from the AI service, or from
    ``build_fallback_title`` when title generation raises ``RuntimeError``.

    Returns:
        A ``(consultation_id, user_message_id, assistant_message_id)`` tuple.
    """
    # The request's category wins; otherwise use the classifier's legal domain.
    category = payload.category or classification.legal_domain
    consultation_id = payload.consultation_id
    if consultation_id is None:
        try:
            consultation_title = await ai_service.generate_consultation_title(
                question=payload.question,
                category=category,
                answer=answer_text,
            )
        except RuntimeError:
            # A missing title must not fail the request; fall back to a title built from the question.
            logger.exception("Consultation title generation failed, using fallback title")
            consultation_title = build_fallback_title(payload.question)
        consultation = await orm.create_consultation(
            user_id=payload.user_id,
            category=category,
            title=consultation_title,
            region=payload.region or classification.region,
        )
        consultation_id = consultation.id

    # NOTE(review): these writes are not wrapped in a transaction here; a failure
    # part-way may leave a consultation without its assistant message — confirm
    # whether the ORM layer provides one.
    user_message = await orm.create_message(
        consultation_id=consultation_id,
        role="user",
        content=payload.question,
    )
    await orm.create_rag_query(
        consultation_id=consultation_id,
        user_message_id=user_message.id,
        generated_queries=classification.search_queries,
        retrieved_chunks=chunks,
    )
    assistant_message = await orm.create_message(
        consultation_id=consultation_id,
        role="assistant",
        content=answer_text,
        sources_json=chunks,
    )
    return consultation_id, user_message.id, assistant_message.id


@router.post("/answer", response_model=AnswerResponse)
async def answer(payload: AnswerRequest) -> AnswerResponse:
    """Answer a legal question from retrieved law chunks, optionally saving history.

    Pipeline:
    1. Classify the question.
    2. Retrieve law chunks for it.
    3. Generate an answer grounded in those chunks.
    4. If ``save_history`` is set, store the exchange under a new or
       existing consultation.

    Args:
        payload: The answer request.

    Returns:
        An ``AnswerResponse``. When history is not saved,
        ``user_message_id`` and ``assistant_message_id`` are None, and
        ``consultation_id`` echoes the request value.

    Raises:
        HTTPException:
            - 400 when ``save_history`` is set without ``user_id``.
            - 404 when the user or consultation is not found, or when
              retrieval returns no chunks.
            - 502 when classification, retrieval or generation raises
              ``RuntimeError``.
    """
    logger.info(
        "RAG answer request: user_id=%s consultation_id=%s save_history=%s category=%s region=%s top_k=%s question_length=%s",
        payload.user_id,
        payload.consultation_id,
        payload.save_history,
        payload.category,
        payload.region,
        payload.top_k,
        len(payload.question),
    )
    ai_service = get_ai_service()
    retrieval = get_retrieval_service()
    orm = get_orm()

    # Validate the persistence target up front; previously these checks ran only
    # after all model calls had already been paid for.
    if payload.save_history:
        await _ensure_history_target(orm, payload)

    try:
        classification = await ai_service.classify(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
        )
        chunks = await retrieval.retrieve(
            classification=classification,
            fallback_law_types=payload.law_types or infer_law_types(payload.category),
            top_k=payload.top_k,
        )
    except RuntimeError as exc:
        logger.exception("RAG answer failed on classification/retrieval stage")
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    # Refuse to generate an answer with no supporting sources.
    if not chunks:
        logger.warning("RAG answer request returned no reliable chunks")
        raise HTTPException(
            status_code=404,
            detail="No reliable law chunks were found for this question.",
        )

    try:
        answer_text = await ai_service.answer(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
            sources=chunks,
        )
    except RuntimeError as exc:
        logger.exception("RAG answer failed on generation stage")
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    consultation_id = payload.consultation_id
    user_message_id = None
    assistant_message_id = None
    if payload.save_history:
        consultation_id, user_message_id, assistant_message_id = await _save_history(
            orm,
            ai_service,
            payload,
            classification,
            chunks,
            answer_text,
        )

    return AnswerResponse(
        classification=classification,
        generated_queries=classification.search_queries,
        retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
        answer=answer_text,
        consultation_id=consultation_id,
        user_message_id=user_message_id,
        assistant_message_id=assistant_message_id,
    )