"""RAG HTTP endpoints: retrieve supporting law chunks and generate answers."""

import logging

from fastapi import APIRouter, HTTPException

from api.deps import get_ai_service, get_orm, get_retrieval_service
from api.schemas import (
    AnswerRequest,
    AnswerResponse,
    ClassificationResult,
    RetrievedChunk,
    SearchRequest,
    SearchResponse,
)
from api.services.legal_ai import build_fallback_title, infer_law_types

router = APIRouter(prefix="/api/v1/rag", tags=["rag"])
logger = logging.getLogger(__name__)


async def _classify_and_retrieve(ai_service, retrieval, payload, failure_log_message):
    """Classify the question in *payload* and retrieve law chunks for it.

    Shared by the ``search`` and ``answer`` endpoints.

    Args:
        ai_service: Object returned by ``get_ai_service()``; its ``classify``
            coroutine is awaited.
        retrieval: Object returned by ``get_retrieval_service()``; its
            ``retrieve`` coroutine is awaited.
        payload: The incoming ``SearchRequest`` or ``AnswerRequest``. Only the
            ``question``, ``category``, ``region``, ``user_type``, ``history``,
            ``law_types`` and ``top_k`` attributes are read.
        failure_log_message: Message logged (with traceback) when either stage
            raises ``RuntimeError``.

    Returns:
        A ``(classification, chunks)`` tuple, exactly as returned by the
        classifier and the retriever.

    Raises:
        HTTPException: 502 if classification or retrieval raises RuntimeError.
    """
    try:
        classification = await ai_service.classify(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
        )
        chunks = await retrieval.retrieve(
            classification=classification,
            # Explicit law types from the request win; otherwise derive them
            # from the requested category.
            fallback_law_types=payload.law_types or infer_law_types(payload.category),
            top_k=payload.top_k,
        )
    except RuntimeError as exc:
        logger.exception(failure_log_message)
        raise HTTPException(status_code=502, detail=str(exc)) from exc
    return classification, chunks


async def _validate_history_target(orm, payload):
    """Ensure an ``AnswerRequest`` with ``save_history=true`` can be persisted.

    Runs before any AI work, so doomed requests fail fast instead of first
    paying for classification, retrieval and answer generation.

    Raises:
        HTTPException: 400 if ``user_id`` is missing; 404 if the user does not
            exist, or if ``consultation_id`` is given but
            ``orm.get_consultation`` finds no such consultation for this user.
    """
    if payload.user_id is None:
        raise HTTPException(
            status_code=400,
            detail="user_id is required when save_history=true",
        )

    user = await orm.get_user(payload.user_id)
    if user is None:
        raise HTTPException(
            status_code=404,
            detail="User was not found. Start the bot first so the profile is created.",
        )

    if payload.consultation_id is not None:
        consultation = await orm.get_consultation(
            consultation_id=payload.consultation_id,
            user_id=payload.user_id,
        )
        if consultation is None:
            raise HTTPException(
                status_code=404,
                detail="Consultation was not found for this user.",
            )


async def _save_exchange(orm, ai_service, payload, classification, chunks, answer_text):
    """Persist one question/answer exchange and return the ids it created.

    Creates a new consultation when the request did not reference one. The
    title comes from the AI service, falling back to
    ``build_fallback_title`` if that raises RuntimeError. The function then
    stores the user message, the RAG query record and the assistant message.

    Returns:
        ``(consultation_id, user_message_id, assistant_message_id)``.
    """
    # NOTE(review): these writes are issued as separate ORM calls; nothing
    # visible here makes them atomic, so a failure midway may leave a partial
    # exchange. Confirm whether the ORM wraps them in a transaction.
    consultation_id = payload.consultation_id

    if consultation_id is None:
        category = payload.category or classification.legal_domain
        try:
            consultation_title = await ai_service.generate_consultation_title(
                question=payload.question,
                category=category,
                answer=answer_text,
            )
        except RuntimeError:
            # A missing title must not fail an otherwise successful answer.
            logger.exception("Consultation title generation failed, using fallback title")
            consultation_title = build_fallback_title(payload.question)

        consultation = await orm.create_consultation(
            user_id=payload.user_id,
            category=category,
            title=consultation_title,
            region=payload.region or classification.region,
        )
        consultation_id = consultation.id

    user_message = await orm.create_message(
        consultation_id=consultation_id,
        role="user",
        content=payload.question,
    )
    await orm.create_rag_query(
        consultation_id=consultation_id,
        user_message_id=user_message.id,
        generated_queries=classification.search_queries,
        retrieved_chunks=chunks,
    )
    assistant_message = await orm.create_message(
        consultation_id=consultation_id,
        role="assistant",
        content=answer_text,
        sources_json=chunks,
    )
    return consultation_id, user_message.id, assistant_message.id


@router.post("/search", response_model=SearchResponse)
async def search(payload: SearchRequest) -> SearchResponse:
    """Classify the question and return the retrieved law chunks, without generating an answer."""
    logger.info(
        "RAG search request: category=%s region=%s top_k=%s question_length=%s",
        payload.category,
        payload.region,
        payload.top_k,
        len(payload.question),
    )
    ai_service = get_ai_service()
    retrieval = get_retrieval_service()

    classification, chunks = await _classify_and_retrieve(
        ai_service, retrieval, payload, "RAG search failed"
    )

    return SearchResponse(
        classification=classification,
        generated_queries=classification.search_queries,
        retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
    )


@router.post("/answer", response_model=AnswerResponse)
async def answer(payload: AnswerRequest) -> AnswerResponse:
    """Retrieve law chunks, generate an answer and optionally save it to history.

    Error responses:
        400/404 when ``save_history`` is set but the user or consultation is
        invalid (checked before any AI work); 404 when no chunks were
        retrieved; 502 when the AI or retrieval services raise RuntimeError.
    """
    logger.info(
        "RAG answer request: user_id=%s consultation_id=%s save_history=%s category=%s region=%s top_k=%s question_length=%s",
        payload.user_id,
        payload.consultation_id,
        payload.save_history,
        payload.category,
        payload.region,
        payload.top_k,
        len(payload.question),
    )
    ai_service = get_ai_service()
    retrieval = get_retrieval_service()
    orm = get_orm()

    # Fail fast: validate the persistence target before spending LLM calls.
    if payload.save_history:
        await _validate_history_target(orm, payload)

    classification, chunks = await _classify_and_retrieve(
        ai_service,
        retrieval,
        payload,
        "RAG answer failed on classification/retrieval stage",
    )

    if not chunks:
        logger.warning("RAG answer request returned no reliable chunks")
        raise HTTPException(
            status_code=404,
            detail="No reliable law chunks were found for this question.",
        )

    try:
        answer_text = await ai_service.answer(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
            sources=chunks,
        )
    except RuntimeError as exc:
        logger.exception("RAG answer failed on generation stage")
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    consultation_id = payload.consultation_id
    user_message_id = None
    assistant_message_id = None

    if payload.save_history:
        consultation_id, user_message_id, assistant_message_id = await _save_exchange(
            orm, ai_service, payload, classification, chunks, answer_text
        )

    return AnswerResponse(
        classification=classification,
        generated_queries=classification.search_queries,
        retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
        answer=answer_text,
        consultation_id=consultation_id,
        user_message_id=user_message_id,
        assistant_message_id=assistant_message_id,
    )