first commit
This commit is contained in:
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from api.deps import get_ai_service, get_orm, get_retrieval_service
|
||||
from api.schemas import (
|
||||
AnswerRequest,
|
||||
AnswerResponse,
|
||||
ClassificationResult,
|
||||
RetrievedChunk,
|
||||
SearchRequest,
|
||||
SearchResponse,
|
||||
)
|
||||
from api.services.legal_ai import build_fallback_title, infer_law_types
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/rag", tags=["rag"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/search", response_model=SearchResponse)
|
||||
async def search(payload: SearchRequest) -> SearchResponse:
|
||||
logger.info(
|
||||
"RAG search request: category=%s region=%s top_k=%s question_length=%s",
|
||||
payload.category,
|
||||
payload.region,
|
||||
payload.top_k,
|
||||
len(payload.question),
|
||||
)
|
||||
ai_service = get_ai_service()
|
||||
retrieval = get_retrieval_service()
|
||||
|
||||
try:
|
||||
classification = await ai_service.classify(
|
||||
question=payload.question,
|
||||
category=payload.category,
|
||||
region=payload.region,
|
||||
user_type=payload.user_type,
|
||||
history=payload.history,
|
||||
)
|
||||
chunks = await retrieval.retrieve(
|
||||
classification=classification,
|
||||
fallback_law_types=payload.law_types or infer_law_types(payload.category),
|
||||
top_k=payload.top_k,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
logger.exception("RAG search failed")
|
||||
raise HTTPException(status_code=502, detail=str(exc)) from exc
|
||||
return SearchResponse(
|
||||
classification=classification,
|
||||
generated_queries=classification.search_queries,
|
||||
retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/answer", response_model=AnswerResponse)
|
||||
async def answer(payload: AnswerRequest) -> AnswerResponse:
|
||||
logger.info(
|
||||
"RAG answer request: user_id=%s consultation_id=%s save_history=%s category=%s region=%s top_k=%s question_length=%s",
|
||||
payload.user_id,
|
||||
payload.consultation_id,
|
||||
payload.save_history,
|
||||
payload.category,
|
||||
payload.region,
|
||||
payload.top_k,
|
||||
len(payload.question),
|
||||
)
|
||||
ai_service = get_ai_service()
|
||||
retrieval = get_retrieval_service()
|
||||
orm = get_orm()
|
||||
|
||||
try:
|
||||
classification = await ai_service.classify(
|
||||
question=payload.question,
|
||||
category=payload.category,
|
||||
region=payload.region,
|
||||
user_type=payload.user_type,
|
||||
history=payload.history,
|
||||
)
|
||||
chunks = await retrieval.retrieve(
|
||||
classification=classification,
|
||||
fallback_law_types=payload.law_types or infer_law_types(payload.category),
|
||||
top_k=payload.top_k,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
logger.exception("RAG answer failed on classification/retrieval stage")
|
||||
raise HTTPException(status_code=502, detail=str(exc)) from exc
|
||||
if not chunks:
|
||||
logger.warning("RAG answer request returned no reliable chunks")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No reliable law chunks were found for this question.",
|
||||
)
|
||||
|
||||
try:
|
||||
answer_text = await ai_service.answer(
|
||||
question=payload.question,
|
||||
category=payload.category,
|
||||
region=payload.region,
|
||||
user_type=payload.user_type,
|
||||
history=payload.history,
|
||||
sources=chunks,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
logger.exception("RAG answer failed on generation stage")
|
||||
raise HTTPException(status_code=502, detail=str(exc)) from exc
|
||||
|
||||
consultation_id = payload.consultation_id
|
||||
user_message_id = None
|
||||
assistant_message_id = None
|
||||
|
||||
if payload.save_history:
|
||||
if payload.user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="user_id is required when save_history=true",
|
||||
)
|
||||
user = await orm.get_user(payload.user_id)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="User was not found. Start the bot first so the profile is created.",
|
||||
)
|
||||
|
||||
if consultation_id is not None:
|
||||
consultation = await orm.get_consultation(
|
||||
consultation_id=consultation_id,
|
||||
user_id=payload.user_id,
|
||||
)
|
||||
if consultation is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Consultation was not found for this user.",
|
||||
)
|
||||
|
||||
if consultation_id is None:
|
||||
try:
|
||||
consultation_title = await ai_service.generate_consultation_title(
|
||||
question=payload.question,
|
||||
category=payload.category or classification.legal_domain,
|
||||
answer=answer_text,
|
||||
)
|
||||
except RuntimeError:
|
||||
logger.exception("Consultation title generation failed, using fallback title")
|
||||
consultation_title = build_fallback_title(payload.question)
|
||||
|
||||
consultation = await orm.create_consultation(
|
||||
user_id=payload.user_id,
|
||||
category=payload.category or classification.legal_domain,
|
||||
title=consultation_title,
|
||||
region=payload.region or classification.region,
|
||||
)
|
||||
consultation_id = consultation.id
|
||||
|
||||
user_message = await orm.create_message(
|
||||
consultation_id=consultation_id,
|
||||
role="user",
|
||||
content=payload.question,
|
||||
)
|
||||
user_message_id = user_message.id
|
||||
|
||||
await orm.create_rag_query(
|
||||
consultation_id=consultation_id,
|
||||
user_message_id=user_message_id,
|
||||
generated_queries=classification.search_queries,
|
||||
retrieved_chunks=chunks,
|
||||
)
|
||||
|
||||
assistant_message = await orm.create_message(
|
||||
consultation_id=consultation_id,
|
||||
role="assistant",
|
||||
content=answer_text,
|
||||
sources_json=chunks,
|
||||
)
|
||||
assistant_message_id = assistant_message.id
|
||||
|
||||
return AnswerResponse(
|
||||
classification=classification,
|
||||
generated_queries=classification.search_queries,
|
||||
retrieved_chunks=[RetrievedChunk(**chunk) for chunk in chunks],
|
||||
answer=answer_text,
|
||||
consultation_id=consultation_id,
|
||||
user_message_id=user_message_id,
|
||||
assistant_message_id=assistant_message_id,
|
||||
)
|
||||
Reference in New Issue
Block a user