Files
GorichBot/rag_api/app/service.py
T
2026-05-12 23:37:04 +03:00

307 lines
15 KiB
Python

from __future__ import annotations
import json
from typing import Any
from .config import settings
from .embeddings import RuBertMiniFridaEmbedder
from .menu_catalog import MenuCatalog
from .models import ChatRequest, ChatResponse, IndexResponse, KnowledgeDocument, SourceDocument
from .openrouter_client import OpenRouterClient
from .site_scraper import SiteKnowledgeScraper
from .vector_store import VectorStore
class RagService:
def __init__(self) -> None:
self.vector_store = VectorStore()
self.embedder = RuBertMiniFridaEmbedder()
self.site_scraper = SiteKnowledgeScraper()
self.menu_catalog = MenuCatalog()
self.openrouter = OpenRouterClient()
self.knowledge_collection = self.vector_store.get_collection(
settings.knowledge_collection
)
self.menu_collection = self.vector_store.get_collection(settings.menu_collection)
@staticmethod
def clear_collection(collection: Any) -> None:
ids = collection.get(include=[])["ids"]
if ids:
collection.delete(ids=ids)
async def reindex(self) -> IndexResponse:
knowledge_documents = await self.site_scraper.scrape()
self.clear_collection(self.knowledge_collection)
self.clear_collection(self.menu_collection)
if knowledge_documents:
knowledge_texts = [doc.text for doc in knowledge_documents]
self.knowledge_collection.add(
ids=[doc.doc_id for doc in knowledge_documents],
documents=knowledge_texts,
embeddings=self.embedder.embed_documents(knowledge_texts),
metadatas=[
{
"title": doc.title,
"source_type": doc.source_type,
"source_url": doc.source_url,
"published_at": doc.published_at.isoformat()
if doc.published_at
else "",
**doc.metadata,
}
for doc in knowledge_documents
],
)
menu_documents = self.menu_catalog.menu_documents()
if menu_documents:
menu_texts = [document for _, document in menu_documents]
self.menu_collection.add(
ids=[item.item_id for item, _ in menu_documents],
documents=menu_texts,
embeddings=self.embedder.embed_documents(menu_texts),
metadatas=[
{
"name": item.name,
"category": item.category,
"price": item.price if item.price is not None else -1,
"price_label": item.price_label,
"source_url": item.source_url,
"photo_url": item.photo_url,
}
for item, _ in menu_documents
],
)
return IndexResponse(
indexed_knowledge_documents=len(knowledge_documents),
indexed_menu_documents=len(menu_documents),
menu_items_loaded=len(menu_documents),
)
def retrieve_knowledge(self, query: str) -> list[SourceDocument]:
if self.knowledge_collection.count() == 0:
return []
query_embedding = self.embedder.embed_queries([query])[0]
result = self.knowledge_collection.query(
query_embeddings=[query_embedding],
n_results=settings.top_k,
)
documents = result.get("documents", [[]])[0]
metadatas = result.get("metadatas", [[]])[0]
distances = result.get("distances", [[]])[0]
ids = result.get("ids", [[]])[0]
sources: list[SourceDocument] = []
for index, document in enumerate(documents):
metadata = metadatas[index]
published_at = metadata.get("published_at") or None
sources.append(
SourceDocument(
source_id=ids[index],
source_type=str(metadata.get("source_type", "unknown")),
title=str(metadata.get("title", ids[index])),
source_url=str(metadata.get("source_url", settings.site_url)),
snippet=document[:400],
published_at=published_at,
score=distances[index] if index < len(distances) else None,
)
)
return sources
def build_system_prompt(self, sources: list[SourceDocument]) -> str:
context_parts = []
for source in sources:
published_label = (
f" | дата: {source.published_at.isoformat()}"
if source.published_at
else ""
)
context_parts.append(
f"[{source.source_type}] {source.title}{published_label}\n"
f"Источник: {source.source_url}\n"
f"{source.snippet}"
)
context_block = "\n\n".join(context_parts) if context_parts else "Нет найденного контекста."
return (
"Ты помощник шаурмечной Горыч из Волгограда.\n"
"Отвечай по-русски, дружелюбно, естественно и клиентоориентированно.\n"
"Не начинай каждый ответ с нового приветствия.\n"
"Отвечай только на текущий вопрос пользователя и не повторяй без необходимости уже сказанное ранее.\n"
"Не перечисляй ассортимент без запроса. Если человек не просил список позиций, не превращай ответ в каталог.\n"
"Для рекомендаций предлагай максимум 3 конкретные позиции.\n"
"Не выдумывай факты. Если данные расходятся, прямо скажи об этом и укажи источник.\n"
"Если вопрос про режим работы, доставку, контакты, адрес, соцсети, способы заказа или общую информацию о заведении, отвечай по контексту и не используй tool меню.\n"
"Используй tool find_menu_items только когда пользователь явно просит подобрать, перечислить, сравнить или найти блюда из меню:\n"
"- что есть в меню;\n"
"- что посоветуете из конкретной категории;\n"
"- что есть с определённым ингредиентом;\n"
"- что можно до определённого бюджета;\n"
"- что острое, сырное, мясное и так далее, если нужен именно подбор позиций.\n"
"Если вопрос общий и консультативный, например про вкус, выбор мяса или что лучше взять в целом, сначала ответь по-человечески и не вызывай tool, пока пользователь не попросит конкретные позиции.\n"
"Если всё же используешь tool и он вернул позиции, назови их по именам и по возможности укажи цену.\n"
"Если tool ничего не нашёл, честно скажи об этом и предложи уточнить запрос.\n"
"Если в контексте есть даты, ориентируйся на более свежие данные.\n\n"
"Формат ответа:\n"
"- Используй только HTML-теги, подходящие для Telegram/aiogram: <b>, <i>, <u>, <s>, <code>, <pre>, <a href=\"...\">.\n"
"- Не используй Markdown со звёздочками, подчёркиваниями или решётками.\n"
"- Не пиши служебные фразы вроде 'выберите вопрос ниже'.\n\n"
f"Контекст RAG:\n{context_block}"
)
def build_tools(self) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "find_menu_items",
"description": "Подбирает блюда из меню Горыча по описанию, бюджету, категории и ингредиентам. Использовать только для явных запросов о меню и конкретных позициях, не использовать для вопросов о режиме работы, доставке, контактах и общей информации о заведении.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Свободное описание того, что хочет пользователь.",
},
"max_price": {
"type": "integer",
"description": "Максимальная цена в рублях.",
},
"category": {
"type": "string",
"description": "Категория блюда, например: пицца, донар, шаурма.",
},
"must_include": {
"type": "array",
"items": {"type": "string"},
"description": "Ингредиенты или слова, которые желательно включить.",
},
"must_not_include": {
"type": "array",
"items": {"type": "string"},
"description": "Ингредиенты или слова, которых нужно избегать.",
},
"limit": {
"type": "integer",
"description": "Максимум позиций в выдаче.",
"default": 5,
},
},
"required": [],
},
},
}
]
def run_tool(self, name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
if name != "find_menu_items":
return []
return self.search_menu(
query=arguments.get("query", ""),
max_price=arguments.get("max_price"),
category=arguments.get("category"),
must_include=arguments.get("must_include"),
must_not_include=arguments.get("must_not_include"),
limit=arguments.get("limit", 5),
)
def search_menu(
self,
query: str = "",
max_price: int | None = None,
category: str | None = None,
must_include: list[str] | None = None,
must_not_include: list[str] | None = None,
limit: int = 5,
) -> list[dict[str, Any]]:
candidate_ids: list[str] | None = None
semantic_ranks: dict[str, int] | None = None
if query and self.menu_collection.count() > 0:
query_embedding = self.embedder.embed_queries([query])[0]
semantic_result = self.menu_collection.query(
query_embeddings=[query_embedding],
n_results=min(max(limit * 4, 10), self.menu_collection.count()),
)
candidate_ids = semantic_result.get("ids", [[]])[0]
semantic_ranks = {
item_id: rank for rank, item_id in enumerate(candidate_ids, start=1)
}
return self.menu_catalog.search(
query=query,
max_price=max_price,
category=category,
must_include=must_include,
must_not_include=must_not_include,
limit=limit,
candidate_ids=candidate_ids,
semantic_ranks=semantic_ranks,
)
async def chat(self, request: ChatRequest) -> ChatResponse:
sources = self.retrieve_knowledge(request.message)
messages: list[dict[str, Any]] = [
{"role": "system", "content": self.build_system_prompt(sources)}
]
for message in request.history:
messages.append({"role": message.role, "content": message.content})
messages.append({"role": "user", "content": request.message})
tools = self.build_tools()
first_response = await self.openrouter.chat_completion(
messages=messages,
tools=tools,
)
choice_message = first_response["choices"][0]["message"]
tool_calls = choice_message.get("tool_calls", [])
tool_results: list[dict[str, Any]] = []
model = first_response.get("model", settings.openrouter_model)
if tool_calls:
messages.append(
{
"role": "assistant",
"content": choice_message.get("content", ""),
"tool_calls": tool_calls,
}
)
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
arguments = json.loads(tool_call["function"]["arguments"] or "{}")
result = self.run_tool(function_name, arguments)
tool_results.extend(result)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call["id"],
"name": function_name,
"content": json.dumps(result, ensure_ascii=False),
}
)
final_response = await self.openrouter.chat_completion(messages=messages)
final_message = final_response["choices"][0]["message"]["content"]
model = final_response.get("model", settings.openrouter_model)
return ChatResponse(
answer=final_message,
model=model,
sources=sources,
tool_results=tool_results,
)
answer = choice_message.get("content", "")
return ChatResponse(
answer=answer,
model=model,
sources=sources,
tool_results=tool_results,
)