307 lines
15 KiB
Python
307 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
from .config import settings
|
|
from .embeddings import RuBertMiniFridaEmbedder
|
|
from .menu_catalog import MenuCatalog
|
|
from .models import ChatRequest, ChatResponse, IndexResponse, KnowledgeDocument, SourceDocument
|
|
from .openrouter_client import OpenRouterClient
|
|
from .site_scraper import SiteKnowledgeScraper
|
|
from .vector_store import VectorStore
|
|
|
|
|
|
class RagService:
|
|
def __init__(self) -> None:
|
|
self.vector_store = VectorStore()
|
|
self.embedder = RuBertMiniFridaEmbedder()
|
|
self.site_scraper = SiteKnowledgeScraper()
|
|
self.menu_catalog = MenuCatalog()
|
|
self.openrouter = OpenRouterClient()
|
|
self.knowledge_collection = self.vector_store.get_collection(
|
|
settings.knowledge_collection
|
|
)
|
|
self.menu_collection = self.vector_store.get_collection(settings.menu_collection)
|
|
|
|
@staticmethod
|
|
def clear_collection(collection: Any) -> None:
|
|
ids = collection.get(include=[])["ids"]
|
|
if ids:
|
|
collection.delete(ids=ids)
|
|
|
|
async def reindex(self) -> IndexResponse:
|
|
knowledge_documents = await self.site_scraper.scrape()
|
|
self.clear_collection(self.knowledge_collection)
|
|
self.clear_collection(self.menu_collection)
|
|
|
|
if knowledge_documents:
|
|
knowledge_texts = [doc.text for doc in knowledge_documents]
|
|
self.knowledge_collection.add(
|
|
ids=[doc.doc_id for doc in knowledge_documents],
|
|
documents=knowledge_texts,
|
|
embeddings=self.embedder.embed_documents(knowledge_texts),
|
|
metadatas=[
|
|
{
|
|
"title": doc.title,
|
|
"source_type": doc.source_type,
|
|
"source_url": doc.source_url,
|
|
"published_at": doc.published_at.isoformat()
|
|
if doc.published_at
|
|
else "",
|
|
**doc.metadata,
|
|
}
|
|
for doc in knowledge_documents
|
|
],
|
|
)
|
|
|
|
menu_documents = self.menu_catalog.menu_documents()
|
|
if menu_documents:
|
|
menu_texts = [document for _, document in menu_documents]
|
|
self.menu_collection.add(
|
|
ids=[item.item_id for item, _ in menu_documents],
|
|
documents=menu_texts,
|
|
embeddings=self.embedder.embed_documents(menu_texts),
|
|
metadatas=[
|
|
{
|
|
"name": item.name,
|
|
"category": item.category,
|
|
"price": item.price if item.price is not None else -1,
|
|
"price_label": item.price_label,
|
|
"source_url": item.source_url,
|
|
"photo_url": item.photo_url,
|
|
}
|
|
for item, _ in menu_documents
|
|
],
|
|
)
|
|
|
|
return IndexResponse(
|
|
indexed_knowledge_documents=len(knowledge_documents),
|
|
indexed_menu_documents=len(menu_documents),
|
|
menu_items_loaded=len(menu_documents),
|
|
)
|
|
|
|
def retrieve_knowledge(self, query: str) -> list[SourceDocument]:
|
|
if self.knowledge_collection.count() == 0:
|
|
return []
|
|
|
|
query_embedding = self.embedder.embed_queries([query])[0]
|
|
result = self.knowledge_collection.query(
|
|
query_embeddings=[query_embedding],
|
|
n_results=settings.top_k,
|
|
)
|
|
documents = result.get("documents", [[]])[0]
|
|
metadatas = result.get("metadatas", [[]])[0]
|
|
distances = result.get("distances", [[]])[0]
|
|
ids = result.get("ids", [[]])[0]
|
|
|
|
sources: list[SourceDocument] = []
|
|
for index, document in enumerate(documents):
|
|
metadata = metadatas[index]
|
|
published_at = metadata.get("published_at") or None
|
|
sources.append(
|
|
SourceDocument(
|
|
source_id=ids[index],
|
|
source_type=str(metadata.get("source_type", "unknown")),
|
|
title=str(metadata.get("title", ids[index])),
|
|
source_url=str(metadata.get("source_url", settings.site_url)),
|
|
snippet=document[:400],
|
|
published_at=published_at,
|
|
score=distances[index] if index < len(distances) else None,
|
|
)
|
|
)
|
|
return sources
|
|
|
|
def build_system_prompt(self, sources: list[SourceDocument]) -> str:
|
|
context_parts = []
|
|
for source in sources:
|
|
published_label = (
|
|
f" | дата: {source.published_at.isoformat()}"
|
|
if source.published_at
|
|
else ""
|
|
)
|
|
context_parts.append(
|
|
f"[{source.source_type}] {source.title}{published_label}\n"
|
|
f"Источник: {source.source_url}\n"
|
|
f"{source.snippet}"
|
|
)
|
|
|
|
context_block = "\n\n".join(context_parts) if context_parts else "Нет найденного контекста."
|
|
return (
|
|
"Ты помощник шаурмечной Горыч из Волгограда.\n"
|
|
"Отвечай по-русски, дружелюбно, естественно и клиентоориентированно.\n"
|
|
"Не начинай каждый ответ с нового приветствия.\n"
|
|
"Отвечай только на текущий вопрос пользователя и не повторяй без необходимости уже сказанное ранее.\n"
|
|
"Не перечисляй ассортимент без запроса. Если человек не просил список позиций, не превращай ответ в каталог.\n"
|
|
"Для рекомендаций предлагай максимум 3 конкретные позиции.\n"
|
|
"Не выдумывай факты. Если данные расходятся, прямо скажи об этом и укажи источник.\n"
|
|
"Если вопрос про режим работы, доставку, контакты, адрес, соцсети, способы заказа или общую информацию о заведении, отвечай по контексту и не используй tool меню.\n"
|
|
"Используй tool find_menu_items только когда пользователь явно просит подобрать, перечислить, сравнить или найти блюда из меню:\n"
|
|
"- что есть в меню;\n"
|
|
"- что посоветуете из конкретной категории;\n"
|
|
"- что есть с определённым ингредиентом;\n"
|
|
"- что можно до определённого бюджета;\n"
|
|
"- что острое, сырное, мясное и так далее, если нужен именно подбор позиций.\n"
|
|
"Если вопрос общий и консультативный, например про вкус, выбор мяса или что лучше взять в целом, сначала ответь по-человечески и не вызывай tool, пока пользователь не попросит конкретные позиции.\n"
|
|
"Если всё же используешь tool и он вернул позиции, назови их по именам и по возможности укажи цену.\n"
|
|
"Если tool ничего не нашёл, честно скажи об этом и предложи уточнить запрос.\n"
|
|
"Если в контексте есть даты, ориентируйся на более свежие данные.\n\n"
|
|
"Формат ответа:\n"
|
|
"- Используй только HTML-теги, подходящие для Telegram/aiogram: <b>, <i>, <u>, <s>, <code>, <pre>, <a href=\"...\">.\n"
|
|
"- Не используй Markdown со звёздочками, подчёркиваниями или решётками.\n"
|
|
"- Не пиши служебные фразы вроде 'выберите вопрос ниже'.\n\n"
|
|
f"Контекст RAG:\n{context_block}"
|
|
)
|
|
|
|
def build_tools(self) -> list[dict[str, Any]]:
|
|
return [
|
|
{
|
|
"type": "function",
|
|
"function": {
|
|
"name": "find_menu_items",
|
|
"description": "Подбирает блюда из меню Горыча по описанию, бюджету, категории и ингредиентам. Использовать только для явных запросов о меню и конкретных позициях, не использовать для вопросов о режиме работы, доставке, контактах и общей информации о заведении.",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": "Свободное описание того, что хочет пользователь.",
|
|
},
|
|
"max_price": {
|
|
"type": "integer",
|
|
"description": "Максимальная цена в рублях.",
|
|
},
|
|
"category": {
|
|
"type": "string",
|
|
"description": "Категория блюда, например: пицца, донар, шаурма.",
|
|
},
|
|
"must_include": {
|
|
"type": "array",
|
|
"items": {"type": "string"},
|
|
"description": "Ингредиенты или слова, которые желательно включить.",
|
|
},
|
|
"must_not_include": {
|
|
"type": "array",
|
|
"items": {"type": "string"},
|
|
"description": "Ингредиенты или слова, которых нужно избегать.",
|
|
},
|
|
"limit": {
|
|
"type": "integer",
|
|
"description": "Максимум позиций в выдаче.",
|
|
"default": 5,
|
|
},
|
|
},
|
|
"required": [],
|
|
},
|
|
},
|
|
}
|
|
]
|
|
|
|
def run_tool(self, name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
|
|
if name != "find_menu_items":
|
|
return []
|
|
|
|
return self.search_menu(
|
|
query=arguments.get("query", ""),
|
|
max_price=arguments.get("max_price"),
|
|
category=arguments.get("category"),
|
|
must_include=arguments.get("must_include"),
|
|
must_not_include=arguments.get("must_not_include"),
|
|
limit=arguments.get("limit", 5),
|
|
)
|
|
|
|
def search_menu(
|
|
self,
|
|
query: str = "",
|
|
max_price: int | None = None,
|
|
category: str | None = None,
|
|
must_include: list[str] | None = None,
|
|
must_not_include: list[str] | None = None,
|
|
limit: int = 5,
|
|
) -> list[dict[str, Any]]:
|
|
candidate_ids: list[str] | None = None
|
|
semantic_ranks: dict[str, int] | None = None
|
|
|
|
if query and self.menu_collection.count() > 0:
|
|
query_embedding = self.embedder.embed_queries([query])[0]
|
|
semantic_result = self.menu_collection.query(
|
|
query_embeddings=[query_embedding],
|
|
n_results=min(max(limit * 4, 10), self.menu_collection.count()),
|
|
)
|
|
candidate_ids = semantic_result.get("ids", [[]])[0]
|
|
semantic_ranks = {
|
|
item_id: rank for rank, item_id in enumerate(candidate_ids, start=1)
|
|
}
|
|
|
|
return self.menu_catalog.search(
|
|
query=query,
|
|
max_price=max_price,
|
|
category=category,
|
|
must_include=must_include,
|
|
must_not_include=must_not_include,
|
|
limit=limit,
|
|
candidate_ids=candidate_ids,
|
|
semantic_ranks=semantic_ranks,
|
|
)
|
|
|
|
async def chat(self, request: ChatRequest) -> ChatResponse:
|
|
sources = self.retrieve_knowledge(request.message)
|
|
messages: list[dict[str, Any]] = [
|
|
{"role": "system", "content": self.build_system_prompt(sources)}
|
|
]
|
|
for message in request.history:
|
|
messages.append({"role": message.role, "content": message.content})
|
|
messages.append({"role": "user", "content": request.message})
|
|
|
|
tools = self.build_tools()
|
|
first_response = await self.openrouter.chat_completion(
|
|
messages=messages,
|
|
tools=tools,
|
|
)
|
|
choice_message = first_response["choices"][0]["message"]
|
|
tool_calls = choice_message.get("tool_calls", [])
|
|
tool_results: list[dict[str, Any]] = []
|
|
model = first_response.get("model", settings.openrouter_model)
|
|
|
|
if tool_calls:
|
|
messages.append(
|
|
{
|
|
"role": "assistant",
|
|
"content": choice_message.get("content", ""),
|
|
"tool_calls": tool_calls,
|
|
}
|
|
)
|
|
|
|
for tool_call in tool_calls:
|
|
function_name = tool_call["function"]["name"]
|
|
arguments = json.loads(tool_call["function"]["arguments"] or "{}")
|
|
result = self.run_tool(function_name, arguments)
|
|
tool_results.extend(result)
|
|
messages.append(
|
|
{
|
|
"role": "tool",
|
|
"tool_call_id": tool_call["id"],
|
|
"name": function_name,
|
|
"content": json.dumps(result, ensure_ascii=False),
|
|
}
|
|
)
|
|
|
|
final_response = await self.openrouter.chat_completion(messages=messages)
|
|
final_message = final_response["choices"][0]["message"]["content"]
|
|
model = final_response.get("model", settings.openrouter_model)
|
|
|
|
return ChatResponse(
|
|
answer=final_message,
|
|
model=model,
|
|
sources=sources,
|
|
tool_results=tool_results,
|
|
)
|
|
|
|
answer = choice_message.get("content", "")
|
|
return ChatResponse(
|
|
answer=answer,
|
|
model=model,
|
|
sources=sources,
|
|
tool_results=tool_results,
|
|
)
|