724 lines
28 KiB
Python
724 lines
28 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Iterable
|
|
|
|
from decouple import config
|
|
from sqlalchemy import delete, func, select, text, update
|
|
|
|
from shared.engine import create_async_engine, get_session_maker
|
|
from shared.models import (
|
|
Admin,
|
|
BaseModel,
|
|
Blacklist,
|
|
Consultation,
|
|
LawChunk,
|
|
LawSource,
|
|
Message,
|
|
RagQuery,
|
|
Setting,
|
|
User,
|
|
)
|
|
from shared.types import JSONDict
|
|
|
|
|
|
def resolve_database_url() -> str:
    """Return the database connection URL for the application.

    A non-empty ``DATABASE_URL`` config value is returned as-is. Otherwise an
    asyncpg PostgreSQL URL is assembled from the ``POSTGRES_USER``,
    ``POSTGRES_PASSWORD``, ``POSTGRES_HOST``, ``POSTGRES_PORT`` and
    ``POSTGRES_DB`` config values, each defaulting to the entry of the same
    name in ``postgres.env`` (looked up one directory above the directory
    that contains this module).

    NOTE(review): the credentials are interpolated verbatim, so they are
    presumably expected to be URL-safe already -- confirm with deployment.
    """
    explicit_url = config("DATABASE_URL", default=None)
    if explicit_url:
        return explicit_url

    env_file = Path(__file__).resolve().parent.parent / "postgres.env"
    file_values = load_env_file_values(env_file)

    def setting(name: str):
        # Config value first; the env-file entry is only the default.
        return config(name, default=file_values.get(name))

    user = setting("POSTGRES_USER")
    password = setting("POSTGRES_PASSWORD")
    host = setting("POSTGRES_HOST")
    port = setting("POSTGRES_PORT")
    database = setting("POSTGRES_DB")
    return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
|
|
|
|
|
|
def _unquote_env_value(value: str) -> str:
    """Remove one pair of matching single or double quotes around *value*."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        return value[1:-1]
    return value


def load_env_file_values(path: Path) -> dict[str, str]:
    """Parse ``KEY=VALUE`` lines from a UTF-8 env file.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Keys and values are whitespace-trimmed, and a value wrapped in a
    matching pair of quotes has that single pair removed. Quotes that are not
    balanced (for example a password ending in ``'``) are kept, because they
    are part of the value.

    Args:
        path: Location of the env file.

    Returns:
        Mapping of keys to values; empty when *path* is not an existing
        regular file.
    """
    # is_file() also rejects directories, which read_text() cannot open.
    if not path.is_file():
        return {}

    values: dict[str, str] = {}
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first "=" only so values may contain "=" themselves.
        key, separator, value = line.partition("=")
        if not separator:
            continue
        values[key.strip()] = _unquote_env_value(value.strip())
    return values
|
|
|
|
|
|
def normalize_law_types_arg(
    law_types: list[str] | str | None,
) -> list[str] | None:
    """Normalize a law-type filter argument to a clean list or ``None``.

    A single string is treated as a one-element list. Items are
    whitespace-trimmed; non-string and blank items are dropped.

    Args:
        law_types: One law type, a list of law types, or ``None``.

    Returns:
        The cleaned list, or ``None`` when nothing usable remains, so callers
        can use the result directly as an "apply filter?" flag.
    """
    if law_types is None:
        return None

    # Route a lone string through the same cleaning as list input; otherwise
    # "" would become [""] and produce a filter that matches nothing.
    candidates = [law_types] if isinstance(law_types, str) else law_types
    normalized = [
        item.strip()
        for item in candidates
        if isinstance(item, str) and item.strip()
    ]
    return normalized or None
|
|
|
|
|
|
@dataclass(frozen=True)
class SourceUpsertDecision:
    """Immutable verdict on how an incoming law source version is handled."""

    # Decision label; classify_source_update produces "create", "reuse"
    # or "supersede".
    action: str
    # True when the chunks stored for the source have to be written again.
    should_replace_chunks: bool
|
|
|
|
|
|
def classify_source_update(
    existing_version_hash: str | None, new_version_hash: str
) -> SourceUpsertDecision:
    """Decide what to do with an incoming source version.

    Args:
        existing_version_hash: Hash of the stored version, or ``None`` when
            no version is stored.
        new_version_hash: Hash of the incoming version.

    Returns:
        ``"create"`` when nothing is stored, ``"reuse"`` when the hashes are
        equal, ``"supersede"`` otherwise. Chunks are replaced in every case
        except ``"reuse"``.
    """
    is_new_source = existing_version_hash is None
    if not is_new_source and existing_version_hash == new_version_hash:
        return SourceUpsertDecision(action="reuse", should_replace_chunks=False)

    action = "create" if is_new_source else "supersede"
    return SourceUpsertDecision(action=action, should_replace_chunks=True)
|
|
|
|
|
|
class ORM:
    """Async data-access layer over the project's SQLAlchemy models.

    Each query method opens its own session from ``self.session_maker`` and
    does its work inside one ``session.begin()`` transaction block.
    """

    def __init__(self, database_url: str | None = None):
        """Create the async engine and the session factory.

        Args:
            database_url: Connection URL; when falsy, the result of
                ``resolve_database_url()`` is used.
        """
        self.database_url = database_url or resolve_database_url()
        self.async_engine = create_async_engine(url=self.database_url)
        self.session_maker = get_session_maker(self.async_engine)

    @staticmethod
    def _chunk_payload(
        chunk: LawChunk,
        source: LawSource,
        *,
        include_version_hash: bool = False,
    ) -> JSONDict:
        """Build the plain-dict view of a chunk joined with its source.

        ``version_hash`` is inserted right after ``jurisdiction`` when
        requested, so the key order of each caller's payload is stable.
        """
        payload: JSONDict = {
            "chunk_id": chunk.id,
            "source_id": source.id,
            "source_title": source.title,
            "source_url": source.source_url,
            "law_type": source.law_type,
            "jurisdiction": source.jurisdiction,
        }
        if include_version_hash:
            payload["version_hash"] = source.version_hash
        payload["article_number"] = chunk.article_number
        payload["article_title"] = chunk.article_title
        payload["chunk_text"] = chunk.chunk_text
        payload["metadata"] = chunk.chunk_metadata
        return payload

    # ------------------------------------------------------------------
    # Schema / lifecycle
    # ------------------------------------------------------------------

    async def init_schema(self) -> None:
        """Create missing tables and bring the ``users`` table up to date.

        Runs ``BaseModel.metadata.create_all``, then adds the ``country``,
        ``region``, ``user_type`` and ``updated_at`` columns if they do not
        exist, and finally fills NULLs left in those columns.
        """
        add_column_statements = (
            (
                "ALTER TABLE users "
                "ADD COLUMN IF NOT EXISTS country TEXT NOT NULL DEFAULT 'Россия'"
            ),
            "ALTER TABLE users ADD COLUMN IF NOT EXISTS region TEXT",
            (
                "ALTER TABLE users "
                "ADD COLUMN IF NOT EXISTS user_type TEXT NOT NULL DEFAULT 'physical_person'"
            ),
            (
                "ALTER TABLE users "
                "ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT now()"
            ),
        )
        # The WHERE clause limits the backfill to rows that actually contain
        # a NULL; without it every row is rewritten on each startup even
        # though COALESCE leaves the values unchanged.
        backfill_statement = (
            "UPDATE users "
            "SET country = COALESCE(country, 'Россия'), "
            "user_type = COALESCE(user_type, 'physical_person'), "
            "updated_at = COALESCE(updated_at, register_date, now()) "
            "WHERE country IS NULL OR user_type IS NULL OR updated_at IS NULL"
        )

        async with self.async_engine.begin() as conn:
            await conn.run_sync(BaseModel.metadata.create_all)
            for statement in add_column_statements:
                await conn.execute(text(statement))
            await conn.execute(text(backfill_statement))

    async def proceed_schemas(self) -> None:
        """Alias that delegates to :meth:`init_schema`."""
        await self.init_schema()

    async def close(self) -> None:
        """Dispose of the async engine."""
        await self.async_engine.dispose()

    # ------------------------------------------------------------------
    # Law sources and chunks
    # ------------------------------------------------------------------

    async def get_active_law_source_by_url(self, source_url: str) -> LawSource | None:
        """Return the most recently loaded active source for *source_url*."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalars(
                    select(LawSource)
                    .where(
                        LawSource.source_url == source_url,
                        LawSource.is_active.is_(True),
                    )
                    .order_by(LawSource.loaded_at.desc())
                )
                return result.first()

    async def upsert_law_source(self, payload: JSONDict) -> tuple[LawSource, bool]:
        """Store a law source version unless the same version is already active.

        Args:
            payload: Keyword arguments for ``LawSource``; must contain
                ``"source_url"`` and ``"version_hash"``.

        Returns:
            ``(source, created)``. ``created`` is ``False`` when the active
            source already has the same version hash and is returned as-is;
            otherwise the previous active source (if any) is deactivated and
            a new row is inserted.
        """
        async with self.session_maker() as session:
            async with session.begin():
                existing = await session.scalars(
                    select(LawSource)
                    .where(
                        LawSource.source_url == payload["source_url"],
                        LawSource.is_active.is_(True),
                    )
                    .order_by(LawSource.loaded_at.desc())
                )
                active = existing.first()
                decision = classify_source_update(
                    active.version_hash if active else None,
                    payload["version_hash"],
                )

                if decision.action == "reuse":
                    return active, False

                if active is not None:
                    # Deactivate the old version before inserting the new one.
                    active.is_active = False
                    await session.flush()

                source = LawSource(**payload)
                session.add(source)
                await session.flush()
                return source, True

    async def replace_law_chunks(
        self, source_id: int, chunks: Iterable[JSONDict]
    ) -> int:
        """Replace all chunks of a source and rebuild their search vectors.

        Args:
            source_id: Source whose chunks are replaced.
            chunks: Rows with ``"chunk_index"``, ``"chunk_text"`` and
                ``"metadata"`` keys; ``"article_number"``,
                ``"article_title"`` and ``"created_at"`` are optional.

        Returns:
            Number of chunks written.
        """
        chunk_rows = list(chunks)
        # One default timestamp for the whole batch instead of one per row.
        default_created_at = datetime.now(timezone.utc)

        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(LawChunk).where(LawChunk.source_id == source_id)
                )

                session.add_all(
                    LawChunk(
                        source_id=source_id,
                        chunk_index=row["chunk_index"],
                        article_number=row.get("article_number"),
                        article_title=row.get("article_title"),
                        chunk_text=row["chunk_text"],
                        chunk_metadata=row["metadata"],
                        created_at=row.get("created_at", default_created_at),
                    )
                    for row in chunk_rows
                )

                # Flush so the rows exist before the raw SQL below touches them.
                await session.flush()
                await session.execute(
                    text(
                        "UPDATE law_chunks "
                        "SET tsv = to_tsvector('russian', chunk_text) "
                        "WHERE source_id = :source_id"
                    ),
                    {"source_id": source_id},
                )

        return len(chunk_rows)

    async def list_active_law_sources(
        self, law_types: list[str] | None = None
    ) -> list[LawSource]:
        """Return active sources ordered by title, optionally filtered by law type."""
        law_types = normalize_law_types_arg(law_types)
        async with self.session_maker() as session:
            async with session.begin():
                stmt = select(LawSource).where(LawSource.is_active.is_(True))
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))

                result = await session.scalars(stmt.order_by(LawSource.title))
                return result.all()

    async def get_law_source(self, source_id: int) -> LawSource | None:
        """Return the source with primary key *source_id*, if any."""
        async with self.session_maker() as session:
            async with session.begin():
                return await session.get(LawSource, source_id)

    async def get_chunks_by_source(self, source_id: int) -> list[LawChunk]:
        """Return the chunks of a source ordered by ``chunk_index``."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalars(
                    select(LawChunk)
                    .where(LawChunk.source_id == source_id)
                    .order_by(LawChunk.chunk_index)
                )
                return result.all()

    async def get_chunks_count_by_source(self, source_id: int) -> int:
        """Return how many chunks are stored for a source."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalar(
                    select(func.count())
                    .select_from(LawChunk)
                    .where(LawChunk.source_id == source_id)
                )
                return int(result or 0)

    async def list_chunks_for_indexing(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
        active_only: bool = True,
    ) -> list[JSONDict]:
        """Return chunk payloads (including ``version_hash``) for indexing.

        Rows are ordered by source id, then chunk index. Empty or ``None``
        *source_ids* / *law_types* mean "no filter".
        """
        law_types = normalize_law_types_arg(law_types)
        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(LawChunk, LawSource)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .order_by(LawSource.id, LawChunk.chunk_index)
                )
                if active_only:
                    stmt = stmt.where(LawSource.is_active.is_(True))
                if source_ids:
                    stmt = stmt.where(LawSource.id.in_(source_ids))
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))

                rows = await session.execute(stmt)
                return [
                    self._chunk_payload(chunk, source, include_version_hash=True)
                    for chunk, source in rows.all()
                ]

    async def get_law_chunks_with_sources_by_ids(
        self,
        chunk_ids: list[int],
        law_types: list[str] | None = None,
        jurisdiction: str = "RU",
    ) -> list[JSONDict]:
        """Return payloads for the given chunk ids, in the order of *chunk_ids*.

        Only chunks of active sources in *jurisdiction* (and, when given, of
        the listed law types) are returned; other ids are silently skipped.
        """
        law_types = normalize_law_types_arg(law_types)
        if not chunk_ids:
            return []

        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(LawChunk, LawSource)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .where(
                        LawChunk.id.in_(chunk_ids),
                        LawSource.is_active.is_(True),
                        LawSource.jurisdiction == jurisdiction,
                    )
                )
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))

                rows = await session.execute(stmt)
                by_id: dict[int, JSONDict] = {
                    chunk.id: self._chunk_payload(chunk, source)
                    for chunk, source in rows.all()
                }

        # Re-emit in the caller's order; the query itself is unordered.
        return [by_id[chunk_id] for chunk_id in chunk_ids if chunk_id in by_id]

    async def search_law_chunks_full_text(
        self,
        query: str,
        law_types: list[str] | None = None,
        jurisdiction: str = "RU",
        limit: int = 20,
    ) -> list[JSONDict]:
        """Full-text search over chunk ``tsv`` vectors.

        Builds a ``plainto_tsquery('russian', query)``, keeps chunks of active
        sources in *jurisdiction* that match it, and returns at most *limit*
        payloads ordered by ``ts_rank`` descending. Each payload carries the
        rank under ``"score"``.
        """
        law_types = normalize_law_types_arg(law_types)
        async with self.session_maker() as session:
            async with session.begin():
                ts_query = func.plainto_tsquery("russian", query)
                score = func.ts_rank(LawChunk.tsv, ts_query).label("score")
                stmt = (
                    select(LawChunk, LawSource, score)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .where(
                        LawSource.jurisdiction == jurisdiction,
                        LawSource.is_active.is_(True),
                        LawChunk.tsv.op("@@")(ts_query),
                    )
                )
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))
                stmt = stmt.order_by(score.desc()).limit(limit)

                rows = await session.execute(stmt)
                payloads: list[JSONDict] = []
                for chunk, source, result_score in rows.all():
                    payload = self._chunk_payload(chunk, source)
                    payload["score"] = float(result_score or 0.0)
                    payloads.append(payload)
                return payloads

    # ------------------------------------------------------------------
    # Consultations, messages, RAG queries
    # ------------------------------------------------------------------

    async def create_consultation(
        self,
        user_id: int,
        category: str,
        title: str | None = None,
        region: str | None = None,
        status: str = "active",
    ) -> Consultation:
        """Insert a consultation whose created/updated timestamps are "now" (UTC)."""
        now = datetime.now(timezone.utc)
        async with self.session_maker() as session:
            async with session.begin():
                consultation = Consultation(
                    user_id=user_id,
                    category=category,
                    title=title,
                    region=region,
                    status=status,
                    created_at=now,
                    updated_at=now,
                )
                session.add(consultation)
                await session.flush()
                return consultation

    async def create_message(
        self,
        consultation_id: int,
        role: str,
        content: str,
        sources_json: Any | None = None,
    ) -> Message:
        """Insert a message and bump its consultation's ``updated_at``.

        The message ``created_at`` and the consultation ``updated_at`` get
        the same timestamp.
        """
        now = datetime.now(timezone.utc)
        async with self.session_maker() as session:
            async with session.begin():
                message = Message(
                    consultation_id=consultation_id,
                    role=role,
                    content=content,
                    sources_json=sources_json,
                    created_at=now,
                )
                session.add(message)
                await session.execute(
                    update(Consultation)
                    .where(Consultation.id == consultation_id)
                    .values(updated_at=now)
                )
                await session.flush()
                return message

    async def create_rag_query(
        self,
        consultation_id: int | None,
        user_message_id: int | None,
        generated_queries: list[str],
        retrieved_chunks: list[JSONDict],
    ) -> RagQuery:
        """Insert a record of generated queries and the chunks retrieved for them."""
        async with self.session_maker() as session:
            async with session.begin():
                rag_query = RagQuery(
                    consultation_id=consultation_id,
                    user_message_id=user_message_id,
                    generated_queries=generated_queries,
                    retrieved_chunks=retrieved_chunks,
                    created_at=datetime.now(timezone.utc),
                )
                session.add(rag_query)
                await session.flush()
                return rag_query

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    async def is_user_exists(self, user_id: int) -> bool:
        """Return ``True`` when a user row with *user_id* exists."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(User.user_id).where(User.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_user(
        self, user_id: int, username: str | None, fullname: str, register_date: datetime
    ) -> int | None:
        """Insert a user, or refresh the name fields of an existing one.

        For an existing user only ``username``, ``fullname`` and
        ``updated_at`` change. A new user gets the default country and user
        type, with ``updated_at`` equal to *register_date*.

        Returns:
            The ``user_id`` of the stored row.
        """
        async with self.session_maker() as session:
            async with session.begin():
                existing = await session.get(User, user_id)
                if existing is not None:
                    existing.username = username
                    existing.fullname = fullname
                    existing.updated_at = datetime.now(timezone.utc)
                    await session.flush()
                    return existing.user_id

                user = User(
                    user_id=user_id,
                    username=username,
                    fullname=fullname,
                    country="Россия",
                    user_type="physical_person",
                    register_date=register_date,
                    updated_at=register_date,
                )
                session.add(user)
                await session.flush()
                return user.user_id

    async def set_users_field(
        self, user_id: int, field: str, value: int | str | bool
    ) -> None:
        """Set one ``User`` attribute and refresh ``updated_at``.

        Args:
            user_id: Row to update.
            field: Name of the ``User`` attribute, resolved with ``getattr``.
            value: New value.

        Raises:
            AttributeError: If ``User`` has no attribute named *field*.
        """
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(User)
                    .where(User.user_id == user_id)
                    .values(
                        {
                            getattr(User, field): value,
                            User.updated_at: datetime.now(timezone.utc),
                        }
                    )
                )

    async def set_user_region(self, user_id: int, region: str) -> None:
        """Set the user's ``region``."""
        await self.set_users_field(user_id, "region", region)

    async def set_user_type(self, user_id: int, user_type: str) -> None:
        """Set the user's ``user_type``."""
        await self.set_users_field(user_id, "user_type", user_type)

    async def get_user(self, user_id: int) -> User | None:
        """Return the user with *user_id*, if any."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(User).where(User.user_id == user_id)
                )
                return query.one_or_none()

    async def get_all_users(self) -> list[User]:
        """Return every user row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(User))
                return query.all()

    async def get_users_count(self) -> int:
        """Return the number of user rows."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(select(func.count()).select_from(User))
                return int(query or 0)

    async def get_all_user_ids(self) -> list[int]:
        """Return the ``user_id`` of every user row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(User.user_id))
                return query.all()

    async def delete_user(self, user_id: int) -> None:
        """Delete the user row with *user_id*."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(delete(User).where(User.user_id == user_id))

    # ------------------------------------------------------------------
    # Consultation queries
    # ------------------------------------------------------------------

    async def get_consultation(
        self, consultation_id: int, user_id: int | None = None
    ) -> Consultation | None:
        """Return a consultation by id, optionally requiring it to belong to *user_id*."""
        async with self.session_maker() as session:
            async with session.begin():
                stmt = select(Consultation).where(Consultation.id == consultation_id)
                if user_id is not None:
                    stmt = stmt.where(Consultation.user_id == user_id)
                query = await session.scalars(stmt)
                return query.one_or_none()

    async def list_user_consultations(
        self, user_id: int, limit: int = 20
    ) -> list[Consultation]:
        """Return up to *limit* consultations of a user, most recently updated first."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Consultation)
                    .where(Consultation.user_id == user_id)
                    .order_by(Consultation.updated_at.desc(), Consultation.id.desc())
                    .limit(limit)
                )
                return query.all()

    async def get_consultation_messages(
        self,
        consultation_id: int,
        limit: int | None = None,
    ) -> list[Message]:
        """Return a consultation's messages, oldest first.

        When *limit* is given, only the first *limit* messages in that order
        are returned.
        """
        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(Message)
                    .where(Message.consultation_id == consultation_id)
                    .order_by(Message.created_at.asc(), Message.id.asc())
                )
                if limit is not None:
                    stmt = stmt.limit(limit)
                query = await session.scalars(stmt)
                return query.all()

    async def delete_consultation(self, consultation_id: int, user_id: int) -> None:
        """Delete the consultation only if it belongs to *user_id*."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(Consultation).where(
                        Consultation.id == consultation_id,
                        Consultation.user_id == user_id,
                    )
                )

    async def count_user_consultations_since(
        self, user_id: int, since: datetime
    ) -> int:
        """Count a user's consultations created at or after *since*."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(
                    select(func.count())
                    .select_from(Consultation)
                    .where(
                        Consultation.user_id == user_id,
                        Consultation.created_at >= since,
                    )
                )
                return int(query or 0)

    async def count_user_messages_in_consultation(self, consultation_id: int) -> int:
        """Count the messages with role ``"user"`` in a consultation."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(
                    select(func.count())
                    .select_from(Message)
                    .where(
                        Message.consultation_id == consultation_id,
                        Message.role == "user",
                    )
                )
                return int(query or 0)

    # ------------------------------------------------------------------
    # Admins
    # ------------------------------------------------------------------

    async def is_admin_exists(self, user_id: int) -> bool:
        """Return ``True`` when an admin row with *user_id* exists."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(Admin.user_id).where(Admin.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_admin(self, user_id: int, username: str, fullname: str) -> None:
        """Merge an admin row built from the given values into the session."""
        async with self.session_maker() as session:
            async with session.begin():
                admin = Admin(user_id=user_id, username=username, fullname=fullname)
                await session.merge(admin)

    async def get_admin(self, user_id: int) -> Admin | None:
        """Return the admin with *user_id*, if any."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Admin).where(Admin.user_id == user_id)
                )
                return query.one_or_none()

    async def get_all_admins(self) -> list[Admin]:
        """Return every admin row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(Admin))
                return query.all()

    async def delete_admin(self, user_id: int) -> None:
        """Delete the admin row with *user_id*."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(delete(Admin).where(Admin.user_id == user_id))

    async def set_admin_field(
        self, user_id: int, field: str, value: int | str | bool
    ) -> None:
        """Set one ``Admin`` attribute, resolved by name with ``getattr``.

        Raises:
            AttributeError: If ``Admin`` has no attribute named *field*.
        """
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(Admin)
                    .where(Admin.user_id == user_id)
                    .values({getattr(Admin, field): value})
                )

    # ------------------------------------------------------------------
    # Settings
    # ------------------------------------------------------------------

    async def is_setting_exists(self, name: str) -> bool:
        """Return ``True`` when a setting named *name* exists."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(Setting).where(Setting.name == name)
                )
                return query.one_or_none() is not None

    async def create_setting(self, name: str, value: Any) -> None:
        """Merge a setting row with the given name and value into the session."""
        async with self.session_maker() as session:
            async with session.begin():
                setting = Setting(name=name, value=value)
                await session.merge(setting)

    async def init_settings(self) -> None:
        """Do nothing; currently a placeholder that returns ``None``."""
        return None

    async def get_setting_value(self, name: str) -> Any:
        """Return the value of the setting named *name*, or ``None`` if absent."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Setting.value).where(Setting.name == name)
                )
                return query.one_or_none()

    async def update_setting_value(self, name: str, value: dict | list) -> None:
        """Overwrite the value of the setting named *name*."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(Setting)
                    .where(Setting.name == name)
                    .values({Setting.value: value})
                )

    # ------------------------------------------------------------------
    # Blacklist
    # ------------------------------------------------------------------

    async def is_blacklisted(self, user_id: int) -> bool:
        """Return ``True`` when *user_id* has a blacklist row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(Blacklist).where(Blacklist.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_blacklist(self, user_id: int) -> None:
        """Merge a blacklist row for *user_id* into the session."""
        async with self.session_maker() as session:
            async with session.begin():
                blacklist = Blacklist(user_id=user_id)
                await session.merge(blacklist)

    async def get_all_blacklist(self) -> list[int]:
        """Return all blacklisted user ids in ascending order."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Blacklist.user_id).order_by(Blacklist.user_id)
                )
                return query.all()

    async def delete_blacklist(self, user_id: int) -> None:
        """Delete the blacklist row of *user_id*."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(Blacklist).where(Blacklist.user_id == user_id)
                )
|