186 lines
6.3 KiB
Python
186 lines
6.3 KiB
Python
import logging
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
|
from api.deps import get_ai_service, get_orm, get_retrieval_service
|
|
from api.schemas import (
|
|
AnswerRequest,
|
|
AnswerResponse,
|
|
ClassificationResult,
|
|
RetrievedChunk,
|
|
SearchRequest,
|
|
SearchResponse,
|
|
)
|
|
from api.services.legal_ai import build_fallback_title, infer_law_types
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that groups the RAG endpoints under the /api/v1/rag prefix.
router = APIRouter(prefix="/api/v1/rag", tags=["rag"])
|
|
|
|
|
|
@router.post("/search", response_model=SearchResponse)
async def search(payload: SearchRequest) -> SearchResponse:
    """Classify the question and return the chunks retrieved for it.

    The question is first classified by the AI service; the resulting
    classification is then handed to the retrieval service together with
    the requested ``top_k`` and the law types to fall back on (the ones
    from the payload, or those inferred from the category when the payload
    carries none).

    Raises:
        HTTPException: 502 when classification or retrieval raises
            ``RuntimeError``; the error text is used as the detail.
    """
    logger.info(
        "RAG search request: category=%s region=%s top_k=%s question_length=%s",
        payload.category,
        payload.region,
        payload.top_k,
        len(payload.question),
    )
    classifier = get_ai_service()
    retriever = get_retrieval_service()

    try:
        classification = await classifier.classify(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
        )
        # Explicit law types from the request win over the inferred ones.
        law_types = payload.law_types or infer_law_types(payload.category)
        found_chunks = await retriever.retrieve(
            classification=classification,
            fallback_law_types=law_types,
            top_k=payload.top_k,
        )
    except RuntimeError as err:
        logger.exception("RAG search failed")
        raise HTTPException(status_code=502, detail=str(err)) from err

    return SearchResponse(
        classification=classification,
        generated_queries=classification.search_queries,
        retrieved_chunks=[RetrievedChunk(**item) for item in found_chunks],
    )
|
|
|
|
|
|
async def _ensure_history_target(orm, payload: AnswerRequest) -> None:
    """Check that the exchange described by ``payload`` can be persisted.

    Raises:
        HTTPException: 400 when ``user_id`` is missing; 404 when the user
            is not found, or when ``consultation_id`` is given but no such
            consultation is found for that user.
    """
    if payload.user_id is None:
        raise HTTPException(
            status_code=400,
            detail="user_id is required when save_history=true",
        )

    user = await orm.get_user(payload.user_id)
    if user is None:
        raise HTTPException(
            status_code=404,
            detail="User was not found. Start the bot first so the profile is created.",
        )

    if payload.consultation_id is None:
        # A new consultation will be created at persistence time.
        return

    consultation = await orm.get_consultation(
        consultation_id=payload.consultation_id,
        user_id=payload.user_id,
    )
    if consultation is None:
        raise HTTPException(
            status_code=404,
            detail="Consultation was not found for this user.",
        )


async def _persist_exchange(orm, ai_service, payload, classification, chunks, answer_text):
    """Store the question, the RAG query and the answer.

    A consultation is created first when the payload does not reference
    one; its title comes from the AI service, falling back to
    ``build_fallback_title`` when title generation raises ``RuntimeError``.

    Returns:
        A ``(consultation_id, user_message_id, assistant_message_id)`` tuple.
    """
    consultation_id = payload.consultation_id

    if consultation_id is None:
        category = payload.category or classification.legal_domain
        try:
            consultation_title = await ai_service.generate_consultation_title(
                question=payload.question,
                category=category,
                answer=answer_text,
            )
        except RuntimeError:
            # Title generation is best-effort; never fail the request on it.
            logger.exception("Consultation title generation failed, using fallback title")
            consultation_title = build_fallback_title(payload.question)

        consultation = await orm.create_consultation(
            user_id=payload.user_id,
            category=category,
            title=consultation_title,
            region=payload.region or classification.region,
        )
        consultation_id = consultation.id

    user_message = await orm.create_message(
        consultation_id=consultation_id,
        role="user",
        content=payload.question,
    )

    await orm.create_rag_query(
        consultation_id=consultation_id,
        user_message_id=user_message.id,
        generated_queries=classification.search_queries,
        retrieved_chunks=chunks,
    )

    assistant_message = await orm.create_message(
        consultation_id=consultation_id,
        role="assistant",
        content=answer_text,
        sources_json=chunks,
    )

    return consultation_id, user_message.id, assistant_message.id


@router.post("/answer", response_model=AnswerResponse)
async def answer(payload: AnswerRequest) -> AnswerResponse:
    """Answer a question from retrieved law chunks, optionally saving it.

    Steps: classify the question, retrieve chunks, generate the answer and,
    when ``payload.save_history`` is set, persist the user message, the RAG
    query and the assistant message (creating a consultation if none is
    referenced).

    Raises:
        HTTPException:
            400 - ``save_history`` is set but ``user_id`` is missing.
            404 - the user or the referenced consultation is not found, or
                  retrieval returned no chunks.
            502 - classification, retrieval or answer generation raised
                  ``RuntimeError``.
    """
    logger.info(
        "RAG answer request: user_id=%s consultation_id=%s save_history=%s category=%s region=%s top_k=%s question_length=%s",
        payload.user_id,
        payload.consultation_id,
        payload.save_history,
        payload.category,
        payload.region,
        payload.top_k,
        len(payload.question),
    )
    ai_service = get_ai_service()
    retrieval = get_retrieval_service()
    orm = get_orm()

    # Fail fast: reject requests that cannot be persisted before running
    # classification, retrieval and generation for them.
    if payload.save_history:
        await _ensure_history_target(orm, payload)

    try:
        classification = await ai_service.classify(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
        )
        chunks = await retrieval.retrieve(
            classification=classification,
            fallback_law_types=payload.law_types or infer_law_types(payload.category),
            top_k=payload.top_k,
        )
    except RuntimeError as exc:
        logger.exception("RAG answer failed on classification/retrieval stage")
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    if not chunks:
        logger.warning("RAG answer request returned no reliable chunks")
        raise HTTPException(
            status_code=404,
            detail="No reliable law chunks were found for this question.",
        )

    try:
        answer_text = await ai_service.answer(
            question=payload.question,
            category=payload.category,
            region=payload.region,
            user_type=payload.user_type,
            history=payload.history,
            sources=chunks,
        )
    except RuntimeError as exc:
        logger.exception("RAG answer failed on generation stage")
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    # Without save_history the ids are echoed back as received / left unset.
    consultation_id = payload.consultation_id
    user_message_id = None
    assistant_message_id = None

    if payload.save_history:
        (
            consultation_id,
            user_message_id,
            assistant_message_id,
        ) = await _persist_exchange(
            orm, ai_service, payload, classification, chunks, answer_text
        )

    return AnswerResponse(
        classification=classification,
        generated_queries=classification.search_queries,
        retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
        answer=answer_text,
        consultation_id=consultation_id,
        user_message_id=user_message_id,
        assistant_message_id=assistant_message_id,
    )
|