Files
GorichBot/rag_api/app/menu_catalog.py
T
2026-05-12 23:37:04 +03:00

215 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import re
from pathlib import Path
from .config import settings
from .models import MenuItem, MenuSnapshot
# Words dropped from user queries before matching: question words, pronouns,
# filler verbs ("хочу", "посоветуй") and generic adjectives ("вкусный").
# Entries of two characters or fewer ("у", "из", "ли") are already removed by
# the length filter in tokenize(); they are kept here for completeness.
QUERY_STOPWORDS = {
    "что",
    "у",
    "вас",
    "есть",
    "из",
    "как",
    "ли",
    "мне",
    "могу",
    "хочу",
    "надо",
    "для",
    "под",
    "про",
    "или",
    "это",
    "эта",
    "этот",
    "какой",
    "какая",
    "какие",
    "посоветуй",
    "посоветуйте",
    "подбери",
    "подобрать",
    "вкусную",
    "вкусный",
    "вкусное",
}

# Runs of Latin/Cyrillic letters and digits. "ё"/"Ё" lie outside the а-я / А-Я
# ranges and must be listed explicitly, otherwise words like "зелёный" are
# split into fragments.
_TOKEN_RE = re.compile(r"[a-zA-Zа-яА-ЯёЁ0-9]+")


def tokenize(value: str) -> list[str]:
    """Split *value* into lower-cased search tokens.

    A token is a run of Latin/Cyrillic letters (including "ё") or digits.
    Stopwords from ``QUERY_STOPWORDS`` are removed. Non-numeric tokens of two
    characters or fewer are also removed. Purely numeric tokens (e.g. sizes or
    prices such as "30") are always kept.

    Args:
        value: Free-form user text.

    Returns:
        The remaining tokens in their original order, duplicates preserved.
    """
    return [
        token
        for token in _TOKEN_RE.findall(value.lower())
        if token not in QUERY_STOPWORDS and (len(token) > 2 or token.isdigit())
    ]
# Query word -> extra tokens to look for in menu text. Each group below lists
# the inflected/colloquial query words first and the menu tokens they expand to.
# NOTE(review): MenuCatalog.search matches hints as substrings of the item
# text, so "том" (from "том ям") also matches words like "томат" — confirm
# this is acceptable.
_HINT_GROUPS = (
    (("шаурма", "шаурмы", "шаверма", "шавуха"), ("шаурма", "классика")),
    (
        ("острый", "острая", "острое", "острого", "пикантный"),
        ("халапеньо", "шрирача", "том", "ям"),
    ),
    (
        ("сыр", "сыром", "сыра", "сырный", "сырная"),
        ("сыр", "моцарелла", "пармезан", "крем", "чиз"),
    ),
    (("рыбный", "рыбная"), ("лосось",)),
    (
        ("мясной", "мясная"),
        ("свинина", "курица", "ростбиф", "колбаски", "пепперони"),
    ),
)

# Every key gets its own fresh list so mutating one entry never affects another.
QUERY_HINTS = {
    query_word: list(menu_tokens)
    for query_words, menu_tokens in _HINT_GROUPS
    for query_word in query_words
}
# Alternative spellings / inflections of a category name, mapped to the
# canonical (lower-cased) category value that menu items carry.
CATEGORY_ALIASES = dict.fromkeys(("шаурмы", "шаверма", "шавуха"), "шаурма")
class MenuCatalog:
    """Keyword search (optionally boosted by semantic ranks) over a menu snapshot.

    The snapshot is a JSON file at ``settings.menu_snapshot_path``. It is
    re-read on every public call, so changes to the file are picked up
    without recreating the catalog.
    """

    def __init__(self) -> None:
        self.snapshot_path = Path(settings.menu_snapshot_path)

    def exists(self) -> bool:
        """Return True if the snapshot file is present on disk."""
        return self.snapshot_path.exists()

    def load_snapshot(self) -> MenuSnapshot:
        """Read the snapshot file as UTF-8 JSON and validate it into a MenuSnapshot.

        Errors from reading, JSON decoding, or model validation propagate
        unchanged (presumably pydantic ``ValidationError`` for the latter).
        """
        data = json.loads(self.snapshot_path.read_text(encoding="utf-8"))
        return MenuSnapshot.model_validate(data)

    def menu_documents(self) -> list[tuple[MenuItem, str]]:
        """Return ``(item, text)`` pairs for every item in the snapshot.

        ``text`` joins name, category, description, ingredients, size and price
        label with " | " and is what keyword search matches against. Returns an
        empty list when the snapshot file does not exist.
        """
        if not self.exists():
            return []
        snapshot = self.load_snapshot()
        return [(item, self._document_text(item)) for item in snapshot.items]

    def items_map(self) -> dict[str, MenuItem]:
        """Return snapshot items keyed by ``item_id`` (empty if no snapshot).

        If two items share an id, the later one wins.
        """
        if not self.exists():
            return {}
        snapshot = self.load_snapshot()
        return {item.item_id: item for item in snapshot.items}

    def search(
        self,
        query: str = "",
        max_price: int | None = None,
        category: str | None = None,
        must_include: list[str] | None = None,
        must_not_include: list[str] | None = None,
        limit: int = 5,
        candidate_ids: list[str] | None = None,
        semantic_ranks: dict[str, int] | None = None,
    ) -> list[dict[str, object]]:
        """Filter and rank menu items, returning the top ``limit`` as plain dicts.

        Args:
            query: Free text; tokenized and expanded through ``QUERY_HINTS``.
            max_price: Drop items priced above this, and items without a price.
            category: Keep only this category (case-insensitive, aliases resolved
                through ``CATEGORY_ALIASES``). Blank means no category filter.
            must_include: Substrings that must all appear in the item text.
            must_not_include: Substrings that must not appear in the item text.
                Blank entries are ignored in both lists; otherwise "" would
                match every item.
            limit: Maximum number of results; ``<= 0`` yields ``[]``.
            candidate_ids: Restrict to these item ids. ``None`` or an empty list
                means no restriction.
            semantic_ranks: ``item_id -> rank``. An item gets a bonus of
                ``max(0, 20 - rank)``, so lower ranks score higher.

        Returns:
            Result dicts ordered by score (desc). Ties are ordered with priced
            items before unpriced ones, then by price (asc), then by name (asc).
        """
        documents = self.menu_documents()
        if not documents or limit <= 0:
            return []

        required = self._normalize_terms(must_include)
        forbidden = self._normalize_terms(must_not_include)
        query_tokens = tokenize(query)
        hint_tokens = [
            hint for token in query_tokens for hint in QUERY_HINTS.get(token, [])
        ]

        normalized_category = category.strip().lower() if category else None
        if normalized_category:
            normalized_category = CATEGORY_ALIASES.get(
                normalized_category, normalized_category
            )
        else:
            normalized_category = None
        # A pure category browse (no text, no required terms) gets a token +1.
        category_only = bool(normalized_category) and not query_tokens and not required

        candidate_set = set(candidate_ids or [])
        ranks = semantic_ranks or {}

        scored_items: list[tuple[int, MenuItem]] = []
        for item, text in documents:
            if candidate_set and item.item_id not in candidate_set:
                continue
            lowered = text.lower()
            if not self._passes_filters(
                item, lowered, max_price, normalized_category, required, forbidden
            ):
                continue
            score = self._score(
                item, lowered, query_tokens, hint_tokens, required, ranks, category_only
            )
            scored_items.append((score, item))

        # Ascending sort on a composite key. Negating the score gives highest
        # score first, and names then stay in normal alphabetical order.
        # Unpriced items go after priced ones instead of being treated as free.
        scored_items.sort(
            key=lambda row: (
                -row[0],
                row[1].price is None,
                row[1].price or 0,
                row[1].name,
            )
        )
        return [self._to_result(item, score) for score, item in scored_items[:limit]]

    @staticmethod
    def _document_text(item: MenuItem) -> str:
        """Flatten the searchable fields of *item* into one " | "-separated string."""
        return " | ".join(
            [
                item.name,
                item.category,
                item.description,
                ", ".join(item.ingredients),
                item.size or "",
                item.price_label,
            ]
        )

    @staticmethod
    def _normalize_terms(values: list[str] | None) -> list[str]:
        """Strip and lower-case *values*, dropping blanks (which would match anything)."""
        stripped = (value.strip().lower() for value in values or [])
        return [term for term in stripped if term]

    @staticmethod
    def _passes_filters(
        item: MenuItem,
        lowered: str,
        max_price: int | None,
        category: str | None,
        required: list[str],
        forbidden: list[str],
    ) -> bool:
        """Return True if *item* satisfies the category, price and substring filters."""
        if category and item.category.lower() != category:
            return False
        if max_price is not None and (item.price is None or item.price > max_price):
            return False
        if any(term not in lowered for term in required):
            return False
        return not any(term in lowered for term in forbidden)

    @staticmethod
    def _score(
        item: MenuItem,
        lowered: str,
        query_tokens: list[str],
        hint_tokens: list[str],
        required: list[str],
        semantic_ranks: dict[str, int],
        category_only: bool,
    ) -> int:
        """Compute the keyword + semantic relevance score of one filtered item."""
        name = item.name.lower()
        item_category = item.category.lower()
        score = 0
        for token in query_tokens:
            if token in lowered:
                score += 3
            if token in name:  # name hits weigh more than hits elsewhere
                score += 5
        for token in hint_tokens:
            if token in lowered:
                score += 6
            if token == item_category:
                score += 8
        # Filtering guarantees every required term is present, so this is a
        # uniform boost; kept for score parity with earlier versions.
        score += 4 * sum(term in lowered for term in required)
        rank = semantic_ranks.get(item.item_id)
        if rank is not None:
            score += max(0, 20 - rank)
        if category_only:
            score += 1
        return score

    @staticmethod
    def _to_result(item: MenuItem, score: int) -> dict[str, object]:
        """Serialize *item* and its *score* into the public result dict."""
        return {
            "item_id": item.item_id,
            "name": item.name,
            "category": item.category,
            "description": item.description,
            "ingredients": item.ingredients,
            "price": item.price,
            "price_label": item.price_label,
            "size": item.size,
            "photo_url": item.photo_url,
            "source_url": item.source_url,
            "score": score,
        }