first commit
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
GORICH_SITE_URL=https://gorych34.ru/
|
||||
|
||||
# ChromaDB
|
||||
CHROMA_PATH=/data/chroma
|
||||
HUGGINGFACE_CACHE_DIR=/data/huggingface
|
||||
KNOWLEDGE_COLLECTION=gorich_knowledge
|
||||
MENU_COLLECTION=gorich_menu
|
||||
MENU_SNAPSHOT_PATH=/data/menu/gorich_menu.json
|
||||
ANONYMIZED_TELEMETRY=false
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_API_KEY=your_openrouter_api_key
|
||||
OPENROUTER_MODEL=mistralai/mistral-medium-3-5
|
||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
|
||||
# Public app metadata
|
||||
PUBLIC_APP_URL=http://localhost:8001
|
||||
PUBLIC_APP_NAME=Gorich Bot RAG
|
||||
|
||||
# Embeddings
|
||||
EMBEDDING_MODEL=sergeyzh/rubert-mini-frida
|
||||
EMBEDDING_QUERY_PREFIX="search_query: "
|
||||
EMBEDDING_DOCUMENT_PREFIX="search_document: "
|
||||
EMBEDDING_MAX_LENGTH=512
|
||||
EMBEDDING_BATCH_SIZE=32
|
||||
|
||||
# RAG
|
||||
REQUEST_TIMEOUT_SECONDS=60
|
||||
RAG_TOP_K=5
|
||||
INDEX_ON_STARTUP=true
|
||||
@@ -0,0 +1,13 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY rag_api/requirements.txt /app/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cpu torch==2.7.0 && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY rag_api/app /app/app
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Settings:
|
||||
app_name: str = "Gorich RAG API"
|
||||
site_url: str = os.getenv("GORICH_SITE_URL", "https://gorych34.ru/")
|
||||
chroma_path: str = os.getenv("CHROMA_PATH", "/data/chroma")
|
||||
huggingface_cache_dir: str = os.getenv("HUGGINGFACE_CACHE_DIR", "/data/huggingface")
|
||||
knowledge_collection: str = os.getenv("KNOWLEDGE_COLLECTION", "gorich_knowledge")
|
||||
menu_collection: str = os.getenv("MENU_COLLECTION", "gorich_menu")
|
||||
menu_snapshot_path: str = os.getenv("MENU_SNAPSHOT_PATH", "/data/menu/gorich_menu.json")
|
||||
openrouter_api_key: str = os.getenv("OPENROUTER_API_KEY", "")
|
||||
openrouter_model: str = os.getenv("OPENROUTER_MODEL", "mistralai/mistral-medium-3-5")
|
||||
openrouter_base_url: str = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
|
||||
public_app_url: str = os.getenv("PUBLIC_APP_URL", "http://localhost:8000")
|
||||
public_app_name: str = os.getenv("PUBLIC_APP_NAME", "Gorich Bot RAG")
|
||||
embedding_model: str = os.getenv(
|
||||
"EMBEDDING_MODEL",
|
||||
"sergeyzh/rubert-mini-frida",
|
||||
)
|
||||
embedding_query_prefix: str = os.getenv("EMBEDDING_QUERY_PREFIX", "search_query: ")
|
||||
embedding_document_prefix: str = os.getenv(
|
||||
"EMBEDDING_DOCUMENT_PREFIX",
|
||||
"search_document: ",
|
||||
)
|
||||
embedding_max_length: int = int(os.getenv("EMBEDDING_MAX_LENGTH", "512"))
|
||||
embedding_batch_size: int = int(os.getenv("EMBEDDING_BATCH_SIZE", "32"))
|
||||
request_timeout: float = float(os.getenv("REQUEST_TIMEOUT_SECONDS", "60"))
|
||||
top_k: int = int(os.getenv("RAG_TOP_K", "5"))
|
||||
index_on_startup: bool = os.getenv("INDEX_ON_STARTUP", "true").lower() == "true"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as functional
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class RuBertMiniFridaEmbedder:
    """CPU text embedder built on a Hugging Face encoder.

    Texts are prefixed, tokenized, run through the model, mean-pooled over
    the attention mask and L2-normalized.
    """

    def __init__(self) -> None:
        """Load the tokenizer and model named in the settings onto the CPU."""
        self.device = "cpu"
        self.max_length = settings.embedding_max_length
        self.batch_size = settings.embedding_batch_size
        self.cache_dir = Path(settings.huggingface_cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = AutoTokenizer.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model = AutoModel.from_pretrained(
            settings.embedding_model,
            cache_dir=str(self.cache_dir),
        )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def mean_pool(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average *hidden_state* along dim 1, counting only unmasked positions.

        Rows whose mask is entirely zero yield a zero vector instead of NaN.
        """
        weights = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * weights, dim=1)
        # Mask sums are whole numbers, so clamping at 1 only affects empty rows.
        counts = attention_mask.sum(dim=1, keepdim=True).float().clamp(min=1.0)
        return summed / counts

    def _encode(self, texts: Iterable[str], prompt: str) -> list[list[float]]:
        """Embed *texts*, each prefixed with *prompt*, in batches.

        Returns one normalized vector (as a list of floats) per input text,
        in input order; an empty input gives an empty list.
        """
        prepared_texts = [f"{prompt}{text}" for text in texts]
        if not prepared_texts:
            return []

        embeddings: list[list[float]] = []
        # Gradients are disabled only for this forward pass, rather than
        # switching autograd off for the whole process in the constructor.
        with torch.no_grad():
            for start in range(0, len(prepared_texts), self.batch_size):
                batch = prepared_texts[start : start + self.batch_size]
                encoded = self.tokenizer(
                    batch,
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                encoded = {key: value.to(self.device) for key, value in encoded.items()}
                outputs = self.model(**encoded)
                pooled = self.mean_pool(outputs.last_hidden_state, encoded["attention_mask"])
                normalized = functional.normalize(pooled, p=2, dim=1)
                embeddings.extend(normalized.cpu().tolist())
        return embeddings

    def embed_documents(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed texts using the configured document prefix."""
        return self._encode(texts, prompt=settings.embedding_document_prefix)

    def embed_queries(self, texts: Iterable[str]) -> list[list[float]]:
        """Embed texts using the configured query prefix."""
        return self._encode(texts, prompt=settings.embedding_query_prefix)
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .config import settings
|
||||
from .menu_catalog import MenuCatalog
|
||||
from .models import ChatRequest, ChatResponse, IndexResponse
|
||||
from .service import RagService
|
||||
|
||||
|
||||
# Shared instances, constructed once when this module is imported.
rag_service = RagService()
# Not referenced by the handlers visible in this module; kept as a module attribute.
menu_catalog = MenuCatalog()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_: FastAPI):
    # Startup phase: rebuild the index first when the settings ask for it.
    should_index = settings.index_on_startup
    if should_index:
        await rag_service.reindex()
    # Control returns to the framework here; nothing runs after the yield.
    yield
|
||||
|
||||
|
||||
# Application object; `lifespan` handles the optional reindex on startup.
app = FastAPI(
    title=settings.app_name,
    version="1.0.0",
    lifespan=lifespan,
)
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> dict[str, str]:
    # Liveness probe: always returns the same constant payload.
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
    # Hand the request to the shared RAG service and return its reply unchanged.
    reply = await rag_service.chat(request)
    return reply
|
||||
|
||||
|
||||
@app.post("/admin/reindex", response_model=IndexResponse)
async def reindex() -> IndexResponse:
    # Rebuild the index and report the resulting counts.
    # NOTE(review): no authentication is attached to this route in the code
    # visible here - confirm it is protected elsewhere (proxy / network policy).
    summary = await rag_service.reindex()
    return summary
|
||||
|
||||
|
||||
def _split_csv(raw: str | None) -> list[str] | None:
    """Split a comma-separated query parameter into trimmed, non-empty terms.

    Returns ``None`` when the parameter is missing or has no usable term,
    so inputs such as ``"cheese,"`` or ``","`` never produce an empty-string
    term (an empty term matches every menu item as a substring).
    """
    if not raw:
        return None
    terms = [part.strip() for part in raw.split(",")]
    terms = [term for term in terms if term]
    return terms or None


@app.get("/menu/search")
async def search_menu(
    query: str = "",
    max_price: int | None = None,
    category: str | None = None,
    must_include: str | None = None,
    must_not_include: str | None = None,
    limit: int = 5,
) -> dict[str, object]:
    """Search the menu.

    ``must_include`` and ``must_not_include`` are comma-separated lists;
    blank entries are ignored.  The matches are returned under ``items``.
    """
    return {
        "items": rag_service.search_menu(
            query=query,
            max_price=max_price,
            category=category,
            must_include=_split_csv(must_include),
            must_not_include=_split_csv(must_not_include),
            limit=limit,
        )
    }
|
||||
@@ -0,0 +1,214 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from .config import settings
|
||||
from .models import MenuItem, MenuSnapshot
|
||||
|
||||
|
||||
# Words that carry no search intent and are dropped from queries.
QUERY_STOPWORDS = {
    "что",
    "у",
    "вас",
    "есть",
    "из",
    "как",
    "ли",
    "мне",
    "могу",
    "хочу",
    "надо",
    "для",
    "под",
    "про",
    "или",
    "это",
    "эта",
    "этот",
    "какой",
    "какая",
    "какие",
    "посоветуй",
    "посоветуйте",
    "подбери",
    "подобрать",
    "вкусную",
    "вкусный",
    "вкусное",
}

# "ё"/"Ё" sit outside the contiguous а-я / А-Я ranges and must be listed
# explicitly, otherwise words such as "копчёный" are split at the "ё".
_TOKEN_PATTERN = re.compile(r"[a-zA-Zа-яА-ЯёЁ0-9]+")


def tokenize(value: str) -> list[str]:
    """Split *value* into lowercase search tokens.

    A token is a run of Latin letters, Cyrillic letters (including "ё") or
    ASCII digits.  Stopwords are removed, and so are tokens of one or two
    characters unless they are purely numeric.
    """
    raw_tokens = _TOKEN_PATTERN.findall(value.lower())
    return [
        token
        for token in raw_tokens
        if token not in QUERY_STOPWORDS and (len(token) > 2 or token.isdigit())
    ]
|
||||
|
||||
|
||||
# Each entry pairs the query words with the hint terms they expand to.
_HINT_GROUPS = (
    (
        ("шаурма", "шаурмы", "шаверма", "шавуха"),
        ("шаурма", "классика"),
    ),
    (
        ("острый", "острая", "острое", "острого", "пикантный"),
        ("халапеньо", "шрирача", "том", "ям"),
    ),
    (
        ("сыр", "сыром", "сыра", "сырный", "сырная"),
        ("сыр", "моцарелла", "пармезан", "крем", "чиз"),
    ),
    (
        ("рыбный", "рыбная"),
        ("лосось",),
    ),
    (
        ("мясной", "мясная"),
        ("свинина", "курица", "ростбиф", "колбаски", "пепперони"),
    ),
)

# Query word -> extra terms used when scoring menu items.
# Every key receives its own list object.
QUERY_HINTS = {
    word: list(hints)
    for words, hints in _HINT_GROUPS
    for word in words
}
|
||||
|
||||
# Alternative spellings of a category, mapped to the canonical category name.
CATEGORY_ALIASES = dict.fromkeys(("шаурмы", "шаверма", "шавуха"), "шаурма")
|
||||
|
||||
|
||||
class MenuCatalog:
    """Access to the menu snapshot file, plus keyword filtering and scoring."""

    def __init__(self) -> None:
        """Remember where the snapshot lives; nothing is read yet."""
        self.snapshot_path = Path(settings.menu_snapshot_path)

    def exists(self) -> bool:
        """Return whether the snapshot file is present on disk."""
        return self.snapshot_path.exists()

    def load_snapshot(self) -> MenuSnapshot:
        """Read and validate the snapshot JSON file."""
        data = json.loads(self.snapshot_path.read_text(encoding="utf-8"))
        return MenuSnapshot.model_validate(data)

    def menu_documents(self) -> list[tuple[MenuItem, str]]:
        """Return ``(item, searchable_text)`` pairs, or ``[]`` without a snapshot.

        The text joins name, category, description, ingredients, size and
        price label with ``" | "``.
        """
        if not self.exists():
            return []

        documents: list[tuple[MenuItem, str]] = []
        for item in self.load_snapshot().items:
            text = " | ".join(
                [
                    item.name,
                    item.category,
                    item.description,
                    ", ".join(item.ingredients),
                    item.size or "",
                    item.price_label,
                ]
            )
            documents.append((item, text))
        return documents

    def items_map(self) -> dict[str, MenuItem]:
        """Return menu items keyed by ``item_id`` (empty without a snapshot)."""
        if not self.exists():
            return {}
        return {item.item_id: item for item in self.load_snapshot().items}

    @staticmethod
    def _clean_terms(values: list[str] | None) -> list[str]:
        """Lowercase and trim filter terms, dropping blank ones.

        A blank term is a substring of every text, so keeping it in the
        exclusion list would filter out the whole menu.
        """
        cleaned = (value.strip().lower() for value in (values or []))
        return [value for value in cleaned if value]

    def search(
        self,
        query: str = "",
        max_price: int | None = None,
        category: str | None = None,
        must_include: list[str] | None = None,
        must_not_include: list[str] | None = None,
        limit: int = 5,
        candidate_ids: list[str] | None = None,
        semantic_ranks: dict[str, int] | None = None,
    ) -> list[dict[str, object]]:
        """Filter and rank menu items.

        Args:
            query: Free text; its tokens and their hint expansions add score.
            max_price: Upper price bound; items without a price are dropped
                when it is set.
            category: Exact category match (case-insensitive, aliases applied).
            must_include: Substrings that must all occur in the item text.
            must_not_include: Substrings that must not occur in the item text.
            limit: Maximum number of results; negative values give no results.
            candidate_ids: When non-empty, only these item ids are considered.
            semantic_ranks: Item id -> 1-based rank; better ranks add score.

        Returns:
            Result dictionaries ordered by score (highest first), then by
            ascending price, then by name in descending order.
        """
        if not self.exists():
            return []

        must_include = self._clean_terms(must_include)
        must_not_include = self._clean_terms(must_not_include)
        query_tokens = tokenize(query)

        normalized_category = (category.strip().lower() or None) if category else None
        if normalized_category in CATEGORY_ALIASES:
            normalized_category = CATEGORY_ALIASES[normalized_category]

        hint_tokens: list[str] = []
        for token in query_tokens:
            hint_tokens.extend(QUERY_HINTS.get(token, []))

        candidate_set = set(candidate_ids or [])
        semantic_ranks = semantic_ranks or {}
        # A negative limit would slice from the end of the list.
        if limit is not None and limit < 0:
            limit = 0

        scored_items: list[tuple[int, MenuItem]] = []
        for item, text in self.menu_documents():
            if candidate_set and item.item_id not in candidate_set:
                continue

            lowered = text.lower()
            item_category = item.category.lower()

            if normalized_category and item_category != normalized_category:
                continue
            if max_price is not None and (item.price is None or item.price > max_price):
                continue
            if any(value not in lowered for value in must_include):
                continue
            if any(value in lowered for value in must_not_include):
                continue

            score = 0
            item_name = item.name.lower()
            for token in query_tokens:
                if token in lowered:
                    score += 3
                if token in item_name:
                    score += 5

            for token in hint_tokens:
                if token in lowered:
                    score += 6
                if token == item_category:
                    score += 8

            for token in must_include:
                if token in lowered:
                    score += 4

            if item.item_id in semantic_ranks:
                score += max(0, 20 - semantic_ranks[item.item_id])

            # Category-only browsing: give matches a non-zero score.
            if not query_tokens and not must_include and category:
                score += 1

            scored_items.append((score, item))

        scored_items.sort(
            key=lambda row: (
                row[0],
                -(row[1].price or 0),
                row[1].name,
            ),
            reverse=True,
        )

        return [
            {
                "item_id": item.item_id,
                "name": item.name,
                "category": item.category,
                "description": item.description,
                "ingredients": item.ingredients,
                "price": item.price,
                "price_label": item.price_label,
                "size": item.size,
                "photo_url": item.photo_url,
                "source_url": item.source_url,
                "score": score,
            }
            for score, item in scored_items[:limit]
        ]
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    # One earlier turn of the conversation, supplied by the caller.
    # (Comments are used instead of a docstring so the generated schema
    # description stays as it was.)
    role: str  # role label of the turn
    content: str  # text of the turn
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    # Body of a chat call: the new user message plus optional earlier turns.
    message: str
    history: list[ChatMessage] = Field(
        default_factory=list,  # a fresh empty list per request
    )
|
||||
|
||||
|
||||
class SourceDocument(BaseModel):
    # A retrieved knowledge fragment returned alongside an answer.
    source_id: str
    source_type: str
    title: str
    source_url: str
    snippet: str  # excerpt of the stored document text
    published_at: datetime | None = Field(default=None)
    score: float | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    # Reply to a chat call.
    answer: str  # final text for the user
    model: str  # model identifier that produced the answer
    sources: list[SourceDocument]  # knowledge fragments used as context
    tool_results: list[dict[str, Any]] = Field(
        default_factory=list,  # stays empty when no tool was invoked
    )
|
||||
|
||||
|
||||
class IndexResponse(BaseModel):
    # Counters reported after a reindex run.
    indexed_knowledge_documents: int
    indexed_menu_documents: int
    menu_items_loaded: int
|
||||
|
||||
|
||||
class KnowledgeDocument(BaseModel):
    # A unit of site knowledge prepared for indexing.
    doc_id: str
    title: str
    text: str
    source_type: str
    source_url: str
    published_at: datetime | None = Field(default=None)
    metadata: dict[str, Any] = Field(
        default_factory=dict,  # extra key/value pairs
    )
|
||||
|
||||
|
||||
class MenuItem(BaseModel):
    # A single menu position as stored in the snapshot file.
    item_id: str
    name: str
    category: str
    description: str
    ingredients: list[str]
    price: int | None = Field(default=None)  # numeric price; may be absent
    price_label: str  # price as display text
    size: str | None = Field(default=None)
    photo_url: str
    source_url: str
    scraped_at: datetime
    metadata: dict[str, Any] = Field(
        default_factory=dict,  # extra key/value pairs
    )
|
||||
|
||||
|
||||
class MenuSnapshot(BaseModel):
    # Top-level structure of the menu snapshot file.
    source_url: str
    scraped_at: datetime
    total_items: int
    items: list[MenuItem]
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from chromadb.telemetry.product import ProductTelemetryClient, ProductTelemetryEvent
|
||||
from overrides import override
|
||||
|
||||
|
||||
class NoOpProductTelemetry(ProductTelemetryClient):
    """Telemetry client whose ``capture`` does nothing."""

    @override
    def capture(self, event: ProductTelemetryEvent) -> None:
        """Ignore *event*."""
        return
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class OpenRouterClient:
|
||||
def __init__(self) -> None:
|
||||
self.base_url = settings.openrouter_base_url.rstrip("/")
|
||||
self.api_key = settings.openrouter_api_key
|
||||
self.model = settings.openrouter_model
|
||||
self.timeout = settings.request_timeout
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
temperature: float = 0.2,
|
||||
) -> dict[str, Any]:
|
||||
if not self.api_key:
|
||||
raise RuntimeError("OPENROUTER_API_KEY is not configured")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": settings.public_app_url,
|
||||
"X-Title": settings.public_app_name,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .config import settings
|
||||
from .embeddings import RuBertMiniFridaEmbedder
|
||||
from .menu_catalog import MenuCatalog
|
||||
from .models import ChatRequest, ChatResponse, IndexResponse, KnowledgeDocument, SourceDocument
|
||||
from .openrouter_client import OpenRouterClient
|
||||
from .site_scraper import SiteKnowledgeScraper
|
||||
from .vector_store import VectorStore
|
||||
|
||||
|
||||
class RagService:
    """Coordinates indexing, retrieval, menu search and the LLM conversation."""

    def __init__(self) -> None:
        """Create the collaborators and open both vector collections."""
        self.vector_store = VectorStore()
        self.embedder = RuBertMiniFridaEmbedder()
        self.site_scraper = SiteKnowledgeScraper()
        self.menu_catalog = MenuCatalog()
        self.openrouter = OpenRouterClient()
        self.knowledge_collection = self.vector_store.get_collection(
            settings.knowledge_collection
        )
        self.menu_collection = self.vector_store.get_collection(settings.menu_collection)

    @staticmethod
    def clear_collection(collection: Any) -> None:
        """Delete every record currently stored in *collection*."""
        ids = collection.get(include=[])["ids"]
        if ids:
            collection.delete(ids=ids)

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode the ``arguments`` field of a tool call into a dict.

        Missing, empty, malformed or non-object arguments all yield ``{}``
        so that one bad tool call cannot fail the whole chat request.
        """
        if isinstance(raw, dict):
            return raw
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _string_list(value: Any) -> list[str] | None:
        """Coerce a tool argument to a list of strings.

        A bare string becomes a one-element list instead of being iterated
        character by character; unsupported types become ``None``.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return [str(item) for item in value]
        return None

    async def reindex(self) -> IndexResponse:
        """Scrape the site and rebuild both collections from scratch.

        Scraping runs first, so a scraping error leaves the existing index
        untouched.  Both collections are then cleared and refilled.
        """
        knowledge_documents = await self.site_scraper.scrape()
        self.clear_collection(self.knowledge_collection)
        self.clear_collection(self.menu_collection)

        if knowledge_documents:
            knowledge_texts = [doc.text for doc in knowledge_documents]
            self.knowledge_collection.add(
                ids=[doc.doc_id for doc in knowledge_documents],
                documents=knowledge_texts,
                embeddings=self.embedder.embed_documents(knowledge_texts),
                metadatas=[
                    {
                        "title": doc.title,
                        "source_type": doc.source_type,
                        "source_url": doc.source_url,
                        "published_at": doc.published_at.isoformat()
                        if doc.published_at
                        else "",
                        **doc.metadata,
                    }
                    for doc in knowledge_documents
                ],
            )

        menu_documents = self.menu_catalog.menu_documents()
        if menu_documents:
            menu_texts = [document for _, document in menu_documents]
            self.menu_collection.add(
                ids=[item.item_id for item, _ in menu_documents],
                documents=menu_texts,
                embeddings=self.embedder.embed_documents(menu_texts),
                metadatas=[
                    {
                        "name": item.name,
                        "category": item.category,
                        # -1 stands in for "no price" in the stored metadata.
                        "price": item.price if item.price is not None else -1,
                        "price_label": item.price_label,
                        "source_url": item.source_url,
                        "photo_url": item.photo_url,
                    }
                    for item, _ in menu_documents
                ],
            )

        return IndexResponse(
            indexed_knowledge_documents=len(knowledge_documents),
            indexed_menu_documents=len(menu_documents),
            menu_items_loaded=len(menu_documents),
        )

    def retrieve_knowledge(self, query: str) -> list[SourceDocument]:
        """Return the stored knowledge fragments closest to *query*.

        At most ``settings.top_k`` fragments are returned, each with its
        text cut to 400 characters.  An empty collection gives ``[]``.
        """
        total = self.knowledge_collection.count()
        if total == 0:
            return []

        query_embedding = self.embedder.embed_queries([query])[0]
        result = self.knowledge_collection.query(
            query_embeddings=[query_embedding],
            # Never request more neighbours than exist (as search_menu does).
            n_results=min(settings.top_k, total),
        )
        documents = result.get("documents", [[]])[0]
        metadatas = result.get("metadatas", [[]])[0]
        distances = result.get("distances", [[]])[0]
        ids = result.get("ids", [[]])[0]

        sources: list[SourceDocument] = []
        for index, document in enumerate(documents):
            metadata = metadatas[index]
            # An empty string is stored when there is no date; map it back to None.
            published_at = metadata.get("published_at") or None
            sources.append(
                SourceDocument(
                    source_id=ids[index],
                    source_type=str(metadata.get("source_type", "unknown")),
                    title=str(metadata.get("title", ids[index])),
                    source_url=str(metadata.get("source_url", settings.site_url)),
                    snippet=document[:400],
                    published_at=published_at,
                    score=distances[index] if index < len(distances) else None,
                )
            )
        return sources

    def build_system_prompt(self, sources: list[SourceDocument]) -> str:
        """Compose the system prompt: fixed instructions plus retrieved context."""
        context_parts = []
        for source in sources:
            published_label = (
                f" | дата: {source.published_at.isoformat()}"
                if source.published_at
                else ""
            )
            context_parts.append(
                f"[{source.source_type}] {source.title}{published_label}\n"
                f"Источник: {source.source_url}\n"
                f"{source.snippet}"
            )

        context_block = "\n\n".join(context_parts) if context_parts else "Нет найденного контекста."
        return (
            "Ты помощник шаурмечной Горыч из Волгограда.\n"
            "Отвечай по-русски, дружелюбно, естественно и клиентоориентированно.\n"
            "Не начинай каждый ответ с нового приветствия.\n"
            "Отвечай только на текущий вопрос пользователя и не повторяй без необходимости уже сказанное ранее.\n"
            "Не перечисляй ассортимент без запроса. Если человек не просил список позиций, не превращай ответ в каталог.\n"
            "Для рекомендаций предлагай максимум 3 конкретные позиции.\n"
            "Не выдумывай факты. Если данные расходятся, прямо скажи об этом и укажи источник.\n"
            "Если вопрос про режим работы, доставку, контакты, адрес, соцсети, способы заказа или общую информацию о заведении, отвечай по контексту и не используй tool меню.\n"
            "Используй tool find_menu_items только когда пользователь явно просит подобрать, перечислить, сравнить или найти блюда из меню:\n"
            "- что есть в меню;\n"
            "- что посоветуете из конкретной категории;\n"
            "- что есть с определённым ингредиентом;\n"
            "- что можно до определённого бюджета;\n"
            "- что острое, сырное, мясное и так далее, если нужен именно подбор позиций.\n"
            "Если вопрос общий и консультативный, например про вкус, выбор мяса или что лучше взять в целом, сначала ответь по-человечески и не вызывай tool, пока пользователь не попросит конкретные позиции.\n"
            "Если всё же используешь tool и он вернул позиции, назови их по именам и по возможности укажи цену.\n"
            "Если tool ничего не нашёл, честно скажи об этом и предложи уточнить запрос.\n"
            "Если в контексте есть даты, ориентируйся на более свежие данные.\n\n"
            "Формат ответа:\n"
            "- Используй только HTML-теги, подходящие для Telegram/aiogram: <b>, <i>, <u>, <s>, <code>, <pre>, <a href=\"...\">.\n"
            "- Не используй Markdown со звёздочками, подчёркиваниями или решётками.\n"
            "- Не пиши служебные фразы вроде 'выберите вопрос ниже'.\n\n"
            f"Контекст RAG:\n{context_block}"
        )

    def build_tools(self) -> list[dict[str, Any]]:
        """Return the tool definitions offered to the model (menu search only)."""
        return [
            {
                "type": "function",
                "function": {
                    "name": "find_menu_items",
                    "description": "Подбирает блюда из меню Горыча по описанию, бюджету, категории и ингредиентам. Использовать только для явных запросов о меню и конкретных позициях, не использовать для вопросов о режиме работы, доставке, контактах и общей информации о заведении.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Свободное описание того, что хочет пользователь.",
                            },
                            "max_price": {
                                "type": "integer",
                                "description": "Максимальная цена в рублях.",
                            },
                            "category": {
                                "type": "string",
                                "description": "Категория блюда, например: пицца, донар, шаурма.",
                            },
                            "must_include": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Ингредиенты или слова, которые желательно включить.",
                            },
                            "must_not_include": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Ингредиенты или слова, которых нужно избегать.",
                            },
                            "limit": {
                                "type": "integer",
                                "description": "Максимум позиций в выдаче.",
                                "default": 5,
                            },
                        },
                        "required": [],
                    },
                },
            }
        ]

    def run_tool(self, name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]:
        """Execute a model-requested tool; unknown tool names return ``[]``.

        Arguments come from the model and may be null or mistyped, so they
        are normalized before being passed to the menu search.
        """
        if name != "find_menu_items":
            return []

        limit = arguments.get("limit")
        return self.search_menu(
            query=arguments.get("query") or "",
            max_price=arguments.get("max_price"),
            category=arguments.get("category"),
            must_include=self._string_list(arguments.get("must_include")),
            must_not_include=self._string_list(arguments.get("must_not_include")),
            limit=5 if limit is None else limit,
        )

    def search_menu(
        self,
        query: str = "",
        max_price: int | None = None,
        category: str | None = None,
        must_include: list[str] | None = None,
        must_not_include: list[str] | None = None,
        limit: int = 5,
    ) -> list[dict[str, Any]]:
        """Search the menu, narrowing by vector similarity when a query is given.

        With a non-empty query and a non-empty menu collection, the nearest
        items become the candidate set and their ranks feed the catalog's
        scoring.  Otherwise the catalog searches the whole snapshot.
        """
        candidate_ids: list[str] | None = None
        semantic_ranks: dict[str, int] | None = None

        if query and self.menu_collection.count() > 0:
            query_embedding = self.embedder.embed_queries([query])[0]
            semantic_result = self.menu_collection.query(
                query_embeddings=[query_embedding],
                n_results=min(max(limit * 4, 10), self.menu_collection.count()),
            )
            candidate_ids = semantic_result.get("ids", [[]])[0]
            semantic_ranks = {
                item_id: rank for rank, item_id in enumerate(candidate_ids, start=1)
            }

        return self.menu_catalog.search(
            query=query,
            max_price=max_price,
            category=category,
            must_include=must_include,
            must_not_include=must_not_include,
            limit=limit,
            candidate_ids=candidate_ids,
            semantic_ranks=semantic_ranks,
        )

    async def chat(self, request: ChatRequest) -> ChatResponse:
        """Answer a chat request, running at most one round of tool calls.

        The first completion may request tools.  If it does, each call is
        executed, the results are appended to the conversation, and a second
        completion (without tools) produces the final answer.
        """
        sources = self.retrieve_knowledge(request.message)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self.build_system_prompt(sources)}
        ]
        # NOTE(review): history roles are forwarded exactly as the client sent
        # them - confirm callers are trusted or restrict the allowed roles.
        for message in request.history:
            messages.append({"role": message.role, "content": message.content})
        messages.append({"role": "user", "content": request.message})

        first_response = await self.openrouter.chat_completion(
            messages=messages,
            tools=self.build_tools(),
        )
        choice_message = first_response["choices"][0]["message"]
        # The field may be absent or null when no tool was requested.
        tool_calls = choice_message.get("tool_calls") or []
        tool_results: list[dict[str, Any]] = []
        model = first_response.get("model", settings.openrouter_model)

        if not tool_calls:
            return ChatResponse(
                # A null content would fail validation of the str field.
                answer=choice_message.get("content") or "",
                model=model,
                sources=sources,
                tool_results=tool_results,
            )

        messages.append(
            {
                "role": "assistant",
                "content": choice_message.get("content", ""),
                "tool_calls": tool_calls,
            }
        )

        for tool_call in tool_calls:
            function_name = tool_call["function"]["name"]
            arguments = self._parse_tool_arguments(tool_call["function"].get("arguments"))
            result = self.run_tool(function_name, arguments)
            tool_results.extend(result)
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "name": function_name,
                    "content": json.dumps(result, ensure_ascii=False),
                }
            )

        final_response = await self.openrouter.chat_completion(messages=messages)
        final_message = final_response["choices"][0]["message"].get("content") or ""
        model = final_response.get("model", settings.openrouter_model)

        return ChatResponse(
            answer=final_message,
            model=model,
            sources=sources,
            tool_results=tool_results,
        )
|
||||
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from .config import settings
|
||||
from .models import KnowledgeDocument
|
||||
|
||||
|
||||
# Matches a `yandexMaps.addMap('<a>', '<b>', '<c>')` call and captures the
# three single-quoted arguments.
MAP_PATTERN = re.compile(r"yandexMaps\.addMap\('([^']+)'\s*,\s*'([^']+)'\s*,\s*'([^']+)'\)")
|
||||
|
||||
|
||||
def normalize_spaces(value: str) -> str:
    """Collapse every whitespace run (including NBSP) to one space and trim."""
    without_nbsp = value.replace("\xa0", " ")
    words = without_nbsp.split()
    return " ".join(words)
|
||||
|
||||
|
||||
def deduplicate_preserving_order(values: list[str]) -> list[str]:
    """Return the non-empty values with repeats removed, keeping first-seen order."""
    # dict keys are unique and keep insertion order.
    unique = dict.fromkeys(value for value in values if value)
    return list(unique)
|
||||
|
||||
|
||||
def is_meaningful_value(value: str) -> bool:
    """Return True when *value* contains at least one letter or digit."""
    for char in value:
        if char.isalnum():
            return True
    return False
|
||||
|
||||
|
||||
class SiteKnowledgeScraper:
|
||||
ABOUT_MARKER = "ТЕРРИТОРИЯ БЫСТРОГО ПИТАНИЯ В ВОЛГОГРАДЕ"
|
||||
MENU_MARKER = "МЕНЮ"
|
||||
DELIVERY_MARKER = "ДОСТАВКА"
|
||||
CONTACT_MARKER = "КОНТАКТЫ"
|
||||
CONTACT_END_MARKERS = ("Закрыть", "OK")
|
||||
|
||||
def __init__(self) -> None:
    """Take the target URL and the HTTP timeout from the settings."""
    self.timeout = settings.request_timeout
    self.site_url = settings.site_url
|
||||
|
||||
async def fetch_homepage(self) -> str:
    """Download the configured site URL and return the response text.

    Redirects are followed.  An error status raises via
    ``raise_for_status``.
    """
    client_options = {
        "headers": {"User-Agent": "Mozilla/5.0"},
        "follow_redirects": True,
        "timeout": self.timeout,
    }
    async with httpx.AsyncClient(**client_options) as http:
        page = await http.get(self.site_url)
        page.raise_for_status()
        return page.text
|
||||
|
||||
def visible_strings(self, soup: BeautifulSoup) -> list[str]:
    """Return the document's text fragments, whitespace-normalized.

    Fragments that are empty after normalization, or that contain no
    letter or digit, are left out.
    """
    collected: list[str] = []
    for fragment in soup.stripped_strings:
        cleaned = normalize_spaces(fragment)
        if not cleaned:
            continue
        if is_meaningful_value(cleaned):
            collected.append(cleaned)
    return collected
|
||||
|
||||
def find_marker(self, values: list[str], marker: str, start: int = 0) -> int | None:
|
||||
for index in range(start, len(values)):
|
||||
if values[index] == marker:
|
||||
return index
|
||||
return None
|
||||
|
||||
def find_last_marker(self, values: list[str], marker: str, start: int = 0) -> int | None:
|
||||
for index in range(len(values) - 1, start - 1, -1):
|
||||
if values[index] == marker:
|
||||
return index
|
||||
return None
|
||||
|
||||
def slice_between_markers(
|
||||
self,
|
||||
values: list[str],
|
||||
start_marker: str,
|
||||
end_markers: tuple[str, ...],
|
||||
start_at: int = 0,
|
||||
) -> list[str]:
|
||||
start_index = self.find_marker(values, start_marker, start_at)
|
||||
if start_index is None:
|
||||
return []
|
||||
|
||||
end_index = len(values)
|
||||
for marker in end_markers:
|
||||
marker_index = self.find_marker(values, marker, start_index + 1)
|
||||
if marker_index is not None:
|
||||
end_index = min(end_index, marker_index)
|
||||
|
||||
return values[start_index:end_index]
|
||||
|
||||
def extract_social_links(self, soup: BeautifulSoup) -> list[str]:
|
||||
links: list[str] = []
|
||||
for node in soup.select("[data-page-link]"):
|
||||
href = node.get("data-page-link")
|
||||
label = normalize_spaces(node.get_text(" ", strip=True))
|
||||
if not href:
|
||||
continue
|
||||
if label:
|
||||
links.append(f"{label}: {href}")
|
||||
else:
|
||||
links.append(str(href))
|
||||
return deduplicate_preserving_order(links)
|
||||
|
||||
def extract_map_coordinates(self, html: str) -> str | None:
|
||||
match = MAP_PATTERN.search(html)
|
||||
if not match:
|
||||
return None
|
||||
latitude = normalize_spaces(match.group(2))
|
||||
longitude = normalize_spaces(match.group(3))
|
||||
return f"{latitude}, {longitude}"
|
||||
|
||||
def parse_homepage(self, html: str) -> list[KnowledgeDocument]:
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
strings = self.visible_strings(soup)
|
||||
documents: list[KnowledgeDocument] = []
|
||||
scraped_at = datetime.now(timezone.utc)
|
||||
|
||||
meta_description = soup.select_one('meta[name="description"]')
|
||||
if meta_description and meta_description.get("content"):
|
||||
documents.append(
|
||||
KnowledgeDocument(
|
||||
doc_id="site-meta-description",
|
||||
title="Краткое описание заведения",
|
||||
text=normalize_spaces(meta_description["content"]),
|
||||
source_type="about",
|
||||
source_url=self.site_url,
|
||||
metadata={"scraped_at": scraped_at.isoformat()},
|
||||
)
|
||||
)
|
||||
|
||||
about_section = self.slice_between_markers(
|
||||
strings,
|
||||
self.ABOUT_MARKER,
|
||||
(self.MENU_MARKER,),
|
||||
)
|
||||
if about_section:
|
||||
documents.append(
|
||||
KnowledgeDocument(
|
||||
doc_id="site-about",
|
||||
title=about_section[0],
|
||||
text="\n".join(deduplicate_preserving_order(about_section[1:])),
|
||||
source_type="about",
|
||||
source_url=self.site_url,
|
||||
metadata={"scraped_at": scraped_at.isoformat()},
|
||||
)
|
||||
)
|
||||
|
||||
social_links = self.extract_social_links(soup)
|
||||
if social_links:
|
||||
documents.append(
|
||||
KnowledgeDocument(
|
||||
doc_id="site-links",
|
||||
title="Соцсети и внешние площадки",
|
||||
text="\n".join(social_links),
|
||||
source_type="links",
|
||||
source_url=self.site_url,
|
||||
metadata={"scraped_at": scraped_at.isoformat()},
|
||||
)
|
||||
)
|
||||
|
||||
menu_index = self.find_marker(strings, self.MENU_MARKER)
|
||||
delivery_start = self.find_last_marker(
|
||||
strings,
|
||||
self.DELIVERY_MARKER,
|
||||
start=(menu_index + 1) if menu_index is not None else 0,
|
||||
)
|
||||
contact_start = self.find_last_marker(
|
||||
strings,
|
||||
self.CONTACT_MARKER,
|
||||
start=(delivery_start + 1) if delivery_start is not None else 0,
|
||||
)
|
||||
delivery_section = (
|
||||
strings[delivery_start:contact_start]
|
||||
if delivery_start is not None and contact_start is not None and contact_start > delivery_start
|
||||
else []
|
||||
)
|
||||
if delivery_section:
|
||||
documents.append(
|
||||
KnowledgeDocument(
|
||||
doc_id="site-delivery",
|
||||
title=delivery_section[0],
|
||||
text="\n".join(deduplicate_preserving_order(delivery_section[1:])),
|
||||
source_type="delivery",
|
||||
source_url=self.site_url,
|
||||
metadata={"scraped_at": scraped_at.isoformat()},
|
||||
)
|
||||
)
|
||||
|
||||
auth_index = len(strings)
|
||||
if contact_start is not None:
|
||||
for marker in self.CONTACT_END_MARKERS:
|
||||
marker_index = self.find_marker(strings, marker, contact_start + 1)
|
||||
if marker_index is not None:
|
||||
auth_index = min(auth_index, marker_index)
|
||||
contact_section = (
|
||||
strings[contact_start:auth_index]
|
||||
if contact_start is not None and auth_index > contact_start
|
||||
else []
|
||||
)
|
||||
if contact_section:
|
||||
metadata = {"scraped_at": scraped_at.isoformat()}
|
||||
coordinates = self.extract_map_coordinates(html)
|
||||
if coordinates:
|
||||
metadata["map_coordinates"] = coordinates
|
||||
|
||||
documents.append(
|
||||
KnowledgeDocument(
|
||||
doc_id="site-contact",
|
||||
title=contact_section[0],
|
||||
text="\n".join(deduplicate_preserving_order(contact_section[1:])),
|
||||
source_type="contact",
|
||||
source_url=self.site_url,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
async def scrape(self) -> list[KnowledgeDocument]:
|
||||
html = await self.fetch_homepage()
|
||||
return self.parse_homepage(html)
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from chromadb import PersistentClient
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
class VectorStore:
    """Wrapper around a persistent ChromaDB client stored at ``settings.chroma_path``."""

    def __init__(self) -> None:
        """Open the persistent client with telemetry switched off.

        Both telemetry implementation settings point at the same no-op class
        path, and ``anonymized_telemetry`` is disabled.
        """
        noop_telemetry = "app.noop_telemetry.NoOpProductTelemetry"
        telemetry_off = ChromaSettings(
            anonymized_telemetry=False,
            chroma_product_telemetry_impl=noop_telemetry,
            chroma_telemetry_impl=noop_telemetry,
        )
        self.client = PersistentClient(path=settings.chroma_path, settings=telemetry_off)

    def get_collection(self, name: str) -> Collection:
        """Return the collection called *name*, creating it if it does not exist."""
        collection = self.client.get_or_create_collection(name=name)
        return collection
|
||||
@@ -0,0 +1,7 @@
|
||||
beautifulsoup4==4.12.3
|
||||
chromadb==1.0.8
|
||||
fastapi==0.115.9
|
||||
httpx==0.28.1
|
||||
pydantic==2.11.4
|
||||
transformers==4.57.1
|
||||
uvicorn==0.34.2
|
||||
Reference in New Issue
Block a user