from __future__ import annotations import json from typing import Any from .config import settings from .embeddings import RuBertMiniFridaEmbedder from .menu_catalog import MenuCatalog from .models import ChatRequest, ChatResponse, IndexResponse, KnowledgeDocument, SourceDocument from .openrouter_client import OpenRouterClient from .site_scraper import SiteKnowledgeScraper from .vector_store import VectorStore class RagService: def __init__(self) -> None: self.vector_store = VectorStore() self.embedder = RuBertMiniFridaEmbedder() self.site_scraper = SiteKnowledgeScraper() self.menu_catalog = MenuCatalog() self.openrouter = OpenRouterClient() self.knowledge_collection = self.vector_store.get_collection( settings.knowledge_collection ) self.menu_collection = self.vector_store.get_collection(settings.menu_collection) @staticmethod def clear_collection(collection: Any) -> None: ids = collection.get(include=[])["ids"] if ids: collection.delete(ids=ids) async def reindex(self) -> IndexResponse: knowledge_documents = await self.site_scraper.scrape() self.clear_collection(self.knowledge_collection) self.clear_collection(self.menu_collection) if knowledge_documents: knowledge_texts = [doc.text for doc in knowledge_documents] self.knowledge_collection.add( ids=[doc.doc_id for doc in knowledge_documents], documents=knowledge_texts, embeddings=self.embedder.embed_documents(knowledge_texts), metadatas=[ { "title": doc.title, "source_type": doc.source_type, "source_url": doc.source_url, "published_at": doc.published_at.isoformat() if doc.published_at else "", **doc.metadata, } for doc in knowledge_documents ], ) menu_documents = self.menu_catalog.menu_documents() if menu_documents: menu_texts = [document for _, document in menu_documents] self.menu_collection.add( ids=[item.item_id for item, _ in menu_documents], documents=menu_texts, embeddings=self.embedder.embed_documents(menu_texts), metadatas=[ { "name": item.name, "category": item.category, "price": item.price if item.price is not None else -1, "price_label": item.price_label, "source_url": item.source_url, "photo_url": item.photo_url, } for item, _ in menu_documents ], ) return IndexResponse( indexed_knowledge_documents=len(knowledge_documents), indexed_menu_documents=len(menu_documents), menu_items_loaded=len(menu_documents), ) def retrieve_knowledge(self, query: str) -> list[SourceDocument]: if self.knowledge_collection.count() == 0: return [] query_embedding = self.embedder.embed_queries([query])[0] result = self.knowledge_collection.query( query_embeddings=[query_embedding], n_results=settings.top_k, ) documents = result.get("documents", [[]])[0] metadatas = result.get("metadatas", [[]])[0] distances = result.get("distances", [[]])[0] ids = result.get("ids", [[]])[0] sources: list[SourceDocument] = [] for index, document in enumerate(documents): metadata = metadatas[index] published_at = metadata.get("published_at") or None sources.append( SourceDocument( source_id=ids[index], source_type=str(metadata.get("source_type", "unknown")), title=str(metadata.get("title", ids[index])), source_url=str(metadata.get("source_url", settings.site_url)), snippet=document[:400], published_at=published_at, score=distances[index] if index < len(distances) else None, ) ) return sources def build_system_prompt(self, sources: list[SourceDocument]) -> str: context_parts = [] for source in sources: published_label = ( f" | дата: {source.published_at.isoformat()}" if source.published_at else "" ) context_parts.append( f"[{source.source_type}] {source.title}{published_label}\n" f"Источник: {source.source_url}\n" f"{source.snippet}" ) context_block = "\n\n".join(context_parts) if context_parts else "Нет найденного контекста." return ( "Ты помощник шаурмечной Горыч из Волгограда.\n" "Отвечай по-русски, дружелюбно, естественно и клиентоориентированно.\n" "Не начинай каждый ответ с нового приветствия.\n" "Отвечай только на текущий вопрос пользователя и не повторяй без необходимости уже сказанное ранее.\n" "Не перечисляй ассортимент без запроса. Если человек не просил список позиций, не превращай ответ в каталог.\n" "Для рекомендаций предлагай максимум 3 конкретные позиции.\n" "Не выдумывай факты. Если данные расходятся, прямо скажи об этом и укажи источник.\n" "Если вопрос про режим работы, доставку, контакты, адрес, соцсети, способы заказа или общую информацию о заведении, отвечай по контексту и не используй tool меню.\n" "Используй tool find_menu_items только когда пользователь явно просит подобрать, перечислить, сравнить или найти блюда из меню:\n" "- что есть в меню;\n" "- что посоветуете из конкретной категории;\n" "- что есть с определённым ингредиентом;\n" "- что можно до определённого бюджета;\n" "- что острое, сырное, мясное и так далее, если нужен именно подбор позиций.\n" "Если вопрос общий и консультативный, например про вкус, выбор мяса или что лучше взять в целом, сначала ответь по-человечески и не вызывай tool, пока пользователь не попросит конкретные позиции.\n" "Если всё же используешь tool и он вернул позиции, назови их по именам и по возможности укажи цену.\n" "Если tool ничего не нашёл, честно скажи об этом и предложи уточнить запрос.\n" "Если в контексте есть даты, ориентируйся на более свежие данные.\n\n" "Формат ответа:\n" "- Используй только HTML-теги, подходящие для Telegram/aiogram: , , , , ,
, .\n"
            "- Не используй Markdown со звёздочками, подчёркиваниями или решётками.\n"
            "- Не пиши служебные фразы вроде 'выберите вопрос ниже'.\n\n"
            f"Контекст RAG:\n{context_block}"
        )

    def build_tools(self) -> list[dict[str, Any]]:
        return [
            {
                "type": "function",
                "function": {
                    "name": "find_menu_items",
                    "description": "Подбирает блюда из меню Горыча по описанию, бюджету, категории и ингредиентам. Использовать только для явных запросов о меню и конкретных позициях, не использовать для вопросов о режиме работы, доставке, контактах и общей информации о заведении.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Свободное описание того, что хочет пользователь.",
                            },
                            "max_price": {
                                "type": "integer",
                                "description": "Максимальная цена в рублях.",
                            },
                            "category": {
                                "type": "string",
                                "description": "Категория блюда, например: пицца, донар, шаурма.",
                            },
                            "must_include": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Ингредиенты или слова, которые желательно включить.",
                            },
                            "must_not_include": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Ингредиенты или слова, которых нужно избегать.",
                            },
                            "limit": {
                                "type": "integer",
                                "description": "Максимум позиций в выдаче.",
                                "default": 5,
                            },
                        },
                        "required": [],
                    },
                },
            }
        ]

    def run_tool(self, name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
        if name != "find_menu_items":
            return []

        return self.search_menu(
            query=arguments.get("query", ""),
            max_price=arguments.get("max_price"),
            category=arguments.get("category"),
            must_include=arguments.get("must_include"),
            must_not_include=arguments.get("must_not_include"),
            limit=arguments.get("limit", 5),
        )

    def search_menu(
        self,
        query: str = "",
        max_price: int | None = None,
        category: str | None = None,
        must_include: list[str] | None = None,
        must_not_include: list[str] | None = None,
        limit: int = 5,
    ) -> list[dict[str, Any]]:
        candidate_ids: list[str] | None = None
        semantic_ranks: dict[str, int] | None = None

        if query and self.menu_collection.count() > 0:
            query_embedding = self.embedder.embed_queries([query])[0]
            semantic_result = self.menu_collection.query(
                query_embeddings=[query_embedding],
                n_results=min(max(limit * 4, 10), self.menu_collection.count()),
            )
            candidate_ids = semantic_result.get("ids", [[]])[0]
            semantic_ranks = {
                item_id: rank for rank, item_id in enumerate(candidate_ids, start=1)
            }

        return self.menu_catalog.search(
            query=query,
            max_price=max_price,
            category=category,
            must_include=must_include,
            must_not_include=must_not_include,
            limit=limit,
            candidate_ids=candidate_ids,
            semantic_ranks=semantic_ranks,
        )

    async def chat(self, request: ChatRequest) -> ChatResponse:
        sources = self.retrieve_knowledge(request.message)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self.build_system_prompt(sources)}
        ]
        for message in request.history:
            messages.append({"role": message.role, "content": message.content})
        messages.append({"role": "user", "content": request.message})

        tools = self.build_tools()
        first_response = await self.openrouter.chat_completion(
            messages=messages,
            tools=tools,
        )
        choice_message = first_response["choices"][0]["message"]
        tool_calls = choice_message.get("tool_calls", [])
        tool_results: list[dict[str, Any]] = []
        model = first_response.get("model", settings.openrouter_model)

        if tool_calls:
            messages.append(
                {
                    "role": "assistant",
                    "content": choice_message.get("content", ""),
                    "tool_calls": tool_calls,
                }
            )

            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                arguments = json.loads(tool_call["function"]["arguments"] or "{}")
                result = self.run_tool(function_name, arguments)
                tool_results.extend(result)
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": json.dumps(result, ensure_ascii=False),
                    }
                )

            final_response = await self.openrouter.chat_completion(messages=messages)
            final_message = final_response["choices"][0]["message"]["content"]
            model = final_response.get("model", settings.openrouter_model)

            return ChatResponse(
                answer=final_message,
                model=model,
                sources=sources,
                    tool_results=tool_results,
            )

        answer = choice_message.get("content", "")
        return ChatResponse(
            answer=answer,
            model=model,
            sources=sources,
            tool_results=tool_results,
        )