215 lines
6.9 KiB
Python
215 lines
6.9 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from .config import settings
|
||
from .models import MenuItem, MenuSnapshot
|
||
|
||
|
||
def tokenize(value: str) -> list[str]:
    """Split *value* into lowercase search tokens.

    The input is lowercased and cut into runs of Latin letters, Cyrillic
    letters and ASCII digits. A run is kept when it is not listed in
    ``QUERY_STOPWORDS`` and is either longer than two characters or consists
    of digits only.
    """
    # "ё"/"Ё" sit outside the contiguous а-я / А-Я code point ranges, so they
    # are listed explicitly; without them a word such as "зелёный" would be
    # split into the two bogus tokens "зел" and "ный".
    raw_tokens = re.findall(r"[a-zA-Zа-яА-ЯёЁ0-9]+", value.lower())
    return [
        token
        for token in raw_tokens
        if token not in QUERY_STOPWORDS and (len(token) > 2 or token.isdigit())
    ]
|
||
|
||
|
||
# Words that carry no search intent; tokenize() filters them out of queries.
QUERY_STOPWORDS = set().union(
    # Question and filler words, prepositions, conjunctions.
    ("что", "у", "вас", "есть", "из", "как", "ли", "мне"),
    ("для", "под", "про", "или"),
    # "I can" / "I want" / "need to" phrasing.
    ("могу", "хочу", "надо"),
    # Demonstratives and the forms of "which".
    ("это", "эта", "этот", "какой", "какая", "какие"),
    # Requests for a recommendation.
    ("посоветуй", "посоветуйте", "подбери", "подобрать"),
    # Generic "tasty" adjectives.
    ("вкусную", "вкусный", "вкусное"),
)
|
||
|
||
|
||
# Query word -> menu keywords that MenuCatalog.search() scores as extra hints.
# Every spelling in a group maps to its own fresh copy of the keyword list.
QUERY_HINTS = {
    word: list(keywords)
    for words, keywords in (
        (
            ("шаурма", "шаурмы", "шаверма", "шавуха"),
            ("шаурма", "классика"),
        ),
        (
            ("острый", "острая", "острое", "острого", "пикантный"),
            ("халапеньо", "шрирача", "том", "ям"),
        ),
        (
            ("сыр", "сыром", "сыра", "сырный", "сырная"),
            ("сыр", "моцарелла", "пармезан", "крем", "чиз"),
        ),
        (
            ("рыбный", "рыбная"),
            ("лосось",),
        ),
        (
            ("мясной", "мясная"),
            ("свинина", "курица", "ростбиф", "колбаски", "пепперони"),
        ),
    )
    for word in words
}
|
||
|
||
# Alternative spellings of a category name, each mapped to the category value
# that MenuCatalog.search() compares against item categories.
CATEGORY_ALIASES = dict.fromkeys(("шаурмы", "шаверма", "шавуха"), "шаурма")
|
||
|
||
|
||
class MenuCatalog:
    """File-backed menu catalog with keyword search.

    The snapshot location is taken from ``settings.menu_snapshot_path``. The
    snapshot file is read again on every call that needs menu data; nothing
    is cached on the instance.
    """

    def __init__(self) -> None:
        self.snapshot_path = Path(settings.menu_snapshot_path)

    def exists(self) -> bool:
        """Return True when the snapshot path exists."""
        return self.snapshot_path.exists()

    def load_snapshot(self) -> MenuSnapshot:
        """Read the snapshot file as UTF-8 JSON and validate it.

        The parsed data is passed to ``MenuSnapshot.model_validate``. Errors
        from reading, JSON decoding or validation are not caught here.
        """
        data = json.loads(self.snapshot_path.read_text(encoding="utf-8"))
        return MenuSnapshot.model_validate(data)

    def menu_documents(self) -> list[tuple[MenuItem, str]]:
        """Return ``(item, text)`` pairs for every item in the snapshot.

        ``text`` is the searchable string built by ``_document_text``. An
        empty list is returned when the snapshot path does not exist.
        """
        if not self.exists():
            return []

        snapshot = self.load_snapshot()
        return [(item, self._document_text(item)) for item in snapshot.items]

    def items_map(self) -> dict[str, MenuItem]:
        """Return the snapshot items keyed by ``item_id``.

        An empty dict is returned when the snapshot path does not exist.
        """
        if not self.exists():
            return {}

        snapshot = self.load_snapshot()
        return {item.item_id: item for item in snapshot.items}

    def search(
        self,
        query: str = "",
        max_price: int | None = None,
        category: str | None = None,
        must_include: list[str] | None = None,
        must_not_include: list[str] | None = None,
        limit: int = 5,
        candidate_ids: list[str] | None = None,
        semantic_ranks: dict[str, int] | None = None,
    ) -> list[dict[str, object]]:
        """Filter and rank menu items, returning the best ``limit`` as dicts.

        Args:
            query: Free text; it is split by ``tokenize`` and every token is
                matched as a substring of the item's text and name.
            max_price: When given, items priced above it and items without a
                price are skipped.
            category: When given, only items whose category equals it
                (case-insensitively, after ``CATEGORY_ALIASES`` mapping) are
                kept.
            must_include: Substrings that must all occur in the item text
                (case-insensitive). Each one adds 4 to the score.
            must_not_include: Substrings of which none may occur in the item
                text (case-insensitive).
            limit: Upper slice bound applied to the ranked list.
            candidate_ids: When non-empty, only items with these ids are
                considered. ``None`` or an empty list applies no restriction.
            semantic_ranks: Optional ``item_id -> rank`` mapping; a listed
                item gains ``max(0, 20 - rank)`` points.

        Returns:
            A list of dicts with the keys ``item_id``, ``name``, ``category``,
            ``description``, ``ingredients``, ``price``, ``price_label``,
            ``size``, ``photo_url``, ``source_url`` and ``score``. Items that
            pass the filters are returned even when their score is zero. The
            list is empty when the snapshot path does not exist.
        """
        if not self.exists():
            return []

        required = [value.lower() for value in (must_include or [])]
        forbidden = [value.lower() for value in (must_not_include or [])]
        query_tokens = tokenize(query)

        wanted_category = category.lower() if category else None
        wanted_category = CATEGORY_ALIASES.get(wanted_category, wanted_category)

        hint_tokens = [
            hint
            for token in query_tokens
            for hint in QUERY_HINTS.get(token, [])
        ]
        allowed_ids = set(candidate_ids or [])
        ranks = semantic_ranks or {}

        # Every item that survives the must_include filter contains all of
        # the required substrings, so that bonus is the same for each item.
        include_bonus = 4 * len(required)
        # A request with only a category (no query words, no required
        # substrings) gives each matching item one point.
        browse_bonus = 1 if (not query_tokens and not required and category) else 0

        scored_items: list[tuple[int, MenuItem]] = []
        for item, text in self.menu_documents():
            if allowed_ids and item.item_id not in allowed_ids:
                continue

            haystack = text.lower()
            item_category = item.category.lower()

            if wanted_category and item_category != wanted_category:
                continue
            if max_price is not None and (
                item.price is None or item.price > max_price
            ):
                continue
            if any(value not in haystack for value in required):
                continue
            if any(value in haystack for value in forbidden):
                continue

            score = include_bonus + browse_bonus
            score += self._keyword_score(
                haystack,
                item.name.lower(),
                item_category,
                query_tokens,
                hint_tokens,
            )
            if item.item_id in ranks:
                score += max(0, 20 - ranks[item.item_id])

            scored_items.append((score, item))

        # Highest score first. The price is negated so that, under
        # reverse=True, cheaper items win ties; a missing price counts as 0.
        # NOTE(review): reverse=True also orders equal-score, equal-price
        # items by name Z -> A; confirm that this is intended.
        scored_items.sort(
            key=lambda row: (row[0], -(row[1].price or 0), row[1].name),
            reverse=True,
        )

        return [self._as_result(item, score) for score, item in scored_items[:limit]]

    @staticmethod
    def _document_text(item: MenuItem) -> str:
        """Join the item's text fields into one " | "-separated string."""
        return " | ".join(
            [
                item.name,
                item.category,
                item.description,
                ", ".join(item.ingredients),
                item.size or "",
                item.price_label,
            ]
        )

    @staticmethod
    def _keyword_score(
        haystack: str,
        name: str,
        category: str,
        query_tokens: list[str],
        hint_tokens: list[str],
    ) -> int:
        """Score lowercase item text against the query and hint tokens."""
        score = 0
        for token in query_tokens:
            if token in haystack:
                score += 3
            if token in name:
                score += 5
        for hint in hint_tokens:
            if hint in haystack:
                score += 6
            if hint == category:
                score += 8
        return score

    @staticmethod
    def _as_result(item: MenuItem, score: int) -> dict[str, object]:
        """Convert a ranked item into the dict returned by ``search``."""
        return {
            "item_id": item.item_id,
            "name": item.name,
            "category": item.category,
            "description": item.description,
            "ingredients": item.ingredients,
            "price": item.price,
            "price_label": item.price_label,
            "size": item.size,
            "photo_url": item.photo_url,
            "source_url": item.source_url,
            "score": score,
        }
|