first commit
This commit is contained in:
@@ -0,0 +1,723 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
from decouple import config
|
||||
from sqlalchemy import delete, func, select, text, update
|
||||
|
||||
from shared.engine import create_async_engine, get_session_maker
|
||||
from shared.models import (
|
||||
Admin,
|
||||
BaseModel,
|
||||
Blacklist,
|
||||
Consultation,
|
||||
LawChunk,
|
||||
LawSource,
|
||||
Message,
|
||||
RagQuery,
|
||||
Setting,
|
||||
User,
|
||||
)
|
||||
from shared.types import JSONDict
|
||||
|
||||
|
||||
def resolve_database_url() -> str:
|
||||
database_url = config("DATABASE_URL", default=None)
|
||||
if database_url:
|
||||
return database_url
|
||||
|
||||
env_values = load_env_file_values(
|
||||
Path(__file__).resolve().parent.parent / "postgres.env"
|
||||
)
|
||||
|
||||
return (
|
||||
f"postgresql+asyncpg://{config('POSTGRES_USER', default=env_values.get('POSTGRES_USER'))}:"
|
||||
f"{config('POSTGRES_PASSWORD', default=env_values.get('POSTGRES_PASSWORD'))}"
|
||||
f"@{config('POSTGRES_HOST', default=env_values.get('POSTGRES_HOST'))}:"
|
||||
f"{config('POSTGRES_PORT', default=env_values.get('POSTGRES_PORT'))}/"
|
||||
f"{config('POSTGRES_DB', default=env_values.get('POSTGRES_DB'))}"
|
||||
)
|
||||
|
||||
|
||||
def load_env_file_values(path: Path) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
|
||||
values: dict[str, str] = {}
|
||||
for raw_line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
values[key.strip()] = value.strip().strip("'").strip('"')
|
||||
return values
|
||||
|
||||
|
||||
def normalize_law_types_arg(
|
||||
law_types: list[str] | str | None,
|
||||
) -> list[str] | None:
|
||||
if law_types is None:
|
||||
return None
|
||||
if isinstance(law_types, str):
|
||||
return [law_types]
|
||||
|
||||
normalized = [
|
||||
item.strip()
|
||||
for item in law_types
|
||||
if isinstance(item, str) and item.strip()
|
||||
]
|
||||
return normalized or None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SourceUpsertDecision:
|
||||
action: str
|
||||
should_replace_chunks: bool
|
||||
|
||||
|
||||
def classify_source_update(
|
||||
existing_version_hash: str | None, new_version_hash: str
|
||||
) -> SourceUpsertDecision:
|
||||
if existing_version_hash is None:
|
||||
return SourceUpsertDecision(action="create", should_replace_chunks=True)
|
||||
if existing_version_hash == new_version_hash:
|
||||
return SourceUpsertDecision(action="reuse", should_replace_chunks=False)
|
||||
return SourceUpsertDecision(action="supersede", should_replace_chunks=True)
|
||||
|
||||
|
||||
class ORM:
|
||||
def __init__(self, database_url: str | None = None):
|
||||
self.database_url = database_url or resolve_database_url()
|
||||
self.async_engine = create_async_engine(url=self.database_url)
|
||||
self.session_maker = get_session_maker(self.async_engine)
|
||||
|
||||
async def init_schema(self) -> None:
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(BaseModel.metadata.create_all)
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE users "
|
||||
"ADD COLUMN IF NOT EXISTS country TEXT NOT NULL DEFAULT 'Россия'"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text("ALTER TABLE users ADD COLUMN IF NOT EXISTS region TEXT")
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE users "
|
||||
"ADD COLUMN IF NOT EXISTS user_type TEXT NOT NULL DEFAULT 'physical_person'"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE users "
|
||||
"ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT now()"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"UPDATE users "
|
||||
"SET country = COALESCE(country, 'Россия'), "
|
||||
"user_type = COALESCE(user_type, 'physical_person'), "
|
||||
"updated_at = COALESCE(updated_at, register_date, now())"
|
||||
)
|
||||
)
|
||||
|
||||
async def proceed_schemas(self) -> None:
|
||||
await self.init_schema()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.async_engine.dispose()
|
||||
|
||||
async def get_active_law_source_by_url(self, source_url: str) -> LawSource | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
result = await session.scalars(
|
||||
select(LawSource)
|
||||
.where(
|
||||
LawSource.source_url == source_url,
|
||||
LawSource.is_active.is_(True),
|
||||
)
|
||||
.order_by(LawSource.loaded_at.desc())
|
||||
)
|
||||
return result.first()
|
||||
|
||||
async def upsert_law_source(self, payload: JSONDict) -> tuple[LawSource, bool]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
existing = await session.scalars(
|
||||
select(LawSource)
|
||||
.where(
|
||||
LawSource.source_url == payload["source_url"],
|
||||
LawSource.is_active.is_(True),
|
||||
)
|
||||
.order_by(LawSource.loaded_at.desc())
|
||||
)
|
||||
active = existing.first()
|
||||
decision = classify_source_update(
|
||||
active.version_hash if active else None,
|
||||
payload["version_hash"],
|
||||
)
|
||||
|
||||
if decision.action == "reuse":
|
||||
return active, False
|
||||
|
||||
if active is not None:
|
||||
active.is_active = False
|
||||
await session.flush()
|
||||
|
||||
source = LawSource(**payload)
|
||||
session.add(source)
|
||||
await session.flush()
|
||||
return source, True
|
||||
|
||||
async def replace_law_chunks(
|
||||
self, source_id: int, chunks: Iterable[JSONDict]
|
||||
) -> int:
|
||||
chunk_rows = list(chunks)
|
||||
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(LawChunk).where(LawChunk.source_id == source_id)
|
||||
)
|
||||
|
||||
for row in chunk_rows:
|
||||
session.add(
|
||||
LawChunk(
|
||||
source_id=source_id,
|
||||
chunk_index=row["chunk_index"],
|
||||
article_number=row.get("article_number"),
|
||||
article_title=row.get("article_title"),
|
||||
chunk_text=row["chunk_text"],
|
||||
chunk_metadata=row["metadata"],
|
||||
created_at=row.get(
|
||||
"created_at", datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
await session.flush()
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE law_chunks "
|
||||
"SET tsv = to_tsvector('russian', chunk_text) "
|
||||
"WHERE source_id = :source_id"
|
||||
),
|
||||
{"source_id": source_id},
|
||||
)
|
||||
|
||||
return len(chunk_rows)
|
||||
|
||||
async def list_active_law_sources(
|
||||
self, law_types: list[str] | None = None
|
||||
) -> list[LawSource]:
|
||||
law_types = normalize_law_types_arg(law_types)
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = select(LawSource).where(LawSource.is_active.is_(True))
|
||||
if law_types:
|
||||
stmt = stmt.where(LawSource.law_type.in_(law_types))
|
||||
|
||||
result = await session.scalars(stmt.order_by(LawSource.title))
|
||||
return result.all()
|
||||
|
||||
async def get_law_source(self, source_id: int) -> LawSource | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
return await session.get(LawSource, source_id)
|
||||
|
||||
async def get_chunks_by_source(self, source_id: int) -> list[LawChunk]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
result = await session.scalars(
|
||||
select(LawChunk)
|
||||
.where(LawChunk.source_id == source_id)
|
||||
.order_by(LawChunk.chunk_index)
|
||||
)
|
||||
return result.all()
|
||||
|
||||
async def get_chunks_count_by_source(self, source_id: int) -> int:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
result = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(LawChunk)
|
||||
.where(LawChunk.source_id == source_id)
|
||||
)
|
||||
return int(result or 0)
|
||||
|
||||
async def list_chunks_for_indexing(
|
||||
self,
|
||||
source_ids: list[int] | None = None,
|
||||
law_types: list[str] | None = None,
|
||||
active_only: bool = True,
|
||||
) -> list[JSONDict]:
|
||||
law_types = normalize_law_types_arg(law_types)
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
select(LawChunk, LawSource)
|
||||
.join(LawSource, LawSource.id == LawChunk.source_id)
|
||||
.order_by(LawSource.id, LawChunk.chunk_index)
|
||||
)
|
||||
if active_only:
|
||||
stmt = stmt.where(LawSource.is_active.is_(True))
|
||||
if source_ids:
|
||||
stmt = stmt.where(LawSource.id.in_(source_ids))
|
||||
if law_types:
|
||||
stmt = stmt.where(LawSource.law_type.in_(law_types))
|
||||
|
||||
rows = await session.execute(stmt)
|
||||
payloads: list[JSONDict] = []
|
||||
for chunk, source in rows.all():
|
||||
payloads.append(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"source_id": source.id,
|
||||
"source_title": source.title,
|
||||
"source_url": source.source_url,
|
||||
"law_type": source.law_type,
|
||||
"jurisdiction": source.jurisdiction,
|
||||
"version_hash": source.version_hash,
|
||||
"article_number": chunk.article_number,
|
||||
"article_title": chunk.article_title,
|
||||
"chunk_text": chunk.chunk_text,
|
||||
"metadata": chunk.chunk_metadata,
|
||||
}
|
||||
)
|
||||
return payloads
|
||||
|
||||
async def get_law_chunks_with_sources_by_ids(
|
||||
self,
|
||||
chunk_ids: list[int],
|
||||
law_types: list[str] | None = None,
|
||||
jurisdiction: str = "RU",
|
||||
) -> list[JSONDict]:
|
||||
law_types = normalize_law_types_arg(law_types)
|
||||
if not chunk_ids:
|
||||
return []
|
||||
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
select(LawChunk, LawSource)
|
||||
.join(LawSource, LawSource.id == LawChunk.source_id)
|
||||
.where(
|
||||
LawChunk.id.in_(chunk_ids),
|
||||
LawSource.is_active.is_(True),
|
||||
LawSource.jurisdiction == jurisdiction,
|
||||
)
|
||||
)
|
||||
if law_types:
|
||||
stmt = stmt.where(LawSource.law_type.in_(law_types))
|
||||
|
||||
rows = await session.execute(stmt)
|
||||
by_id: dict[int, JSONDict] = {}
|
||||
for chunk, source in rows.all():
|
||||
by_id[chunk.id] = {
|
||||
"chunk_id": chunk.id,
|
||||
"source_id": source.id,
|
||||
"source_title": source.title,
|
||||
"source_url": source.source_url,
|
||||
"law_type": source.law_type,
|
||||
"jurisdiction": source.jurisdiction,
|
||||
"article_number": chunk.article_number,
|
||||
"article_title": chunk.article_title,
|
||||
"chunk_text": chunk.chunk_text,
|
||||
"metadata": chunk.chunk_metadata,
|
||||
}
|
||||
|
||||
return [by_id[chunk_id] for chunk_id in chunk_ids if chunk_id in by_id]
|
||||
|
||||
async def search_law_chunks_full_text(
|
||||
self,
|
||||
query: str,
|
||||
law_types: list[str] | None = None,
|
||||
jurisdiction: str = "RU",
|
||||
limit: int = 20,
|
||||
) -> list[JSONDict]:
|
||||
law_types = normalize_law_types_arg(law_types)
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
ts_query = func.plainto_tsquery("russian", query)
|
||||
score = func.ts_rank(LawChunk.tsv, ts_query).label("score")
|
||||
stmt = (
|
||||
select(LawChunk, LawSource, score)
|
||||
.join(LawSource, LawSource.id == LawChunk.source_id)
|
||||
.where(
|
||||
LawSource.jurisdiction == jurisdiction,
|
||||
LawSource.is_active.is_(True),
|
||||
LawChunk.tsv.op("@@")(ts_query),
|
||||
)
|
||||
.order_by(score.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
if law_types:
|
||||
stmt = stmt.where(LawSource.law_type.in_(law_types))
|
||||
|
||||
rows = await session.execute(stmt)
|
||||
payloads: list[JSONDict] = []
|
||||
for chunk, source, result_score in rows.all():
|
||||
payloads.append(
|
||||
{
|
||||
"chunk_id": chunk.id,
|
||||
"source_id": source.id,
|
||||
"source_title": source.title,
|
||||
"source_url": source.source_url,
|
||||
"law_type": source.law_type,
|
||||
"jurisdiction": source.jurisdiction,
|
||||
"article_number": chunk.article_number,
|
||||
"article_title": chunk.article_title,
|
||||
"chunk_text": chunk.chunk_text,
|
||||
"metadata": chunk.chunk_metadata,
|
||||
"score": float(result_score or 0.0),
|
||||
}
|
||||
)
|
||||
return payloads
|
||||
|
||||
async def create_consultation(
|
||||
self,
|
||||
user_id: int,
|
||||
category: str,
|
||||
title: str | None = None,
|
||||
region: str | None = None,
|
||||
status: str = "active",
|
||||
) -> Consultation:
|
||||
now = datetime.now(timezone.utc)
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
consultation = Consultation(
|
||||
user_id=user_id,
|
||||
category=category,
|
||||
title=title,
|
||||
region=region,
|
||||
status=status,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(consultation)
|
||||
await session.flush()
|
||||
return consultation
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
consultation_id: int,
|
||||
role: str,
|
||||
content: str,
|
||||
sources_json: Any | None = None,
|
||||
) -> Message:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
message = Message(
|
||||
consultation_id=consultation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
sources_json=sources_json,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(message)
|
||||
await session.execute(
|
||||
update(Consultation)
|
||||
.where(Consultation.id == consultation_id)
|
||||
.values(updated_at=datetime.now(timezone.utc))
|
||||
)
|
||||
await session.flush()
|
||||
return message
|
||||
|
||||
async def create_rag_query(
|
||||
self,
|
||||
consultation_id: int | None,
|
||||
user_message_id: int | None,
|
||||
generated_queries: list[str],
|
||||
retrieved_chunks: list[JSONDict],
|
||||
) -> RagQuery:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
rag_query = RagQuery(
|
||||
consultation_id=consultation_id,
|
||||
user_message_id=user_message_id,
|
||||
generated_queries=generated_queries,
|
||||
retrieved_chunks=retrieved_chunks,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(rag_query)
|
||||
await session.flush()
|
||||
return rag_query
|
||||
|
||||
async def is_user_exists(self, user_id: int) -> bool:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.execute(
|
||||
select(User.user_id).where(User.user_id == user_id)
|
||||
)
|
||||
return query.one_or_none() is not None
|
||||
|
||||
async def create_user(
|
||||
self, user_id: int, username: str | None, fullname: str, register_date: datetime
|
||||
) -> int | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
existing = await session.get(User, user_id)
|
||||
if existing is not None:
|
||||
existing.username = username
|
||||
existing.fullname = fullname
|
||||
existing.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
return existing.user_id
|
||||
|
||||
user = User(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
fullname=fullname,
|
||||
country="Россия",
|
||||
user_type="physical_person",
|
||||
register_date=register_date,
|
||||
updated_at=register_date,
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
return user.user_id
|
||||
|
||||
async def set_users_field(
|
||||
self, user_id: int, field: str, value: int | str | bool
|
||||
) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
update(User)
|
||||
.where(User.user_id == user_id)
|
||||
.values(
|
||||
{
|
||||
getattr(User, field): value,
|
||||
User.updated_at: datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def set_user_region(self, user_id: int, region: str) -> None:
|
||||
await self.set_users_field(user_id, "region", region)
|
||||
|
||||
async def set_user_type(self, user_id: int, user_type: str) -> None:
|
||||
await self.set_users_field(user_id, "user_type", user_type)
|
||||
|
||||
async def get_user(self, user_id: int) -> User | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(
|
||||
select(User).where(User.user_id == user_id)
|
||||
)
|
||||
return query.one_or_none()
|
||||
|
||||
async def get_all_users(self) -> list[User]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(select(User))
|
||||
return query.all()
|
||||
|
||||
async def get_users_count(self) -> int:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalar(select(func.count()).select_from(User))
|
||||
return int(query or 0)
|
||||
|
||||
async def get_all_user_ids(self) -> list[int]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(select(User.user_id))
|
||||
return query.all()
|
||||
|
||||
async def delete_user(self, user_id: int) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(User).where(User.user_id == user_id))
|
||||
|
||||
async def get_consultation(
|
||||
self, consultation_id: int, user_id: int | None = None
|
||||
) -> Consultation | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = select(Consultation).where(Consultation.id == consultation_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(Consultation.user_id == user_id)
|
||||
query = await session.scalars(stmt)
|
||||
return query.one_or_none()
|
||||
|
||||
async def list_user_consultations(
|
||||
self, user_id: int, limit: int = 20
|
||||
) -> list[Consultation]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(
|
||||
select(Consultation)
|
||||
.where(Consultation.user_id == user_id)
|
||||
.order_by(Consultation.updated_at.desc(), Consultation.id.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return query.all()
|
||||
|
||||
async def get_consultation_messages(
|
||||
self,
|
||||
consultation_id: int,
|
||||
limit: int | None = None,
|
||||
) -> list[Message]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.consultation_id == consultation_id)
|
||||
.order_by(Message.created_at.asc(), Message.id.asc())
|
||||
)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
query = await session.scalars(stmt)
|
||||
return query.all()
|
||||
|
||||
async def delete_consultation(self, consultation_id: int, user_id: int) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Consultation).where(
|
||||
Consultation.id == consultation_id,
|
||||
Consultation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def count_user_consultations_since(
|
||||
self, user_id: int, since: datetime
|
||||
) -> int:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Consultation)
|
||||
.where(
|
||||
Consultation.user_id == user_id,
|
||||
Consultation.created_at >= since,
|
||||
)
|
||||
)
|
||||
return int(query or 0)
|
||||
|
||||
async def count_user_messages_in_consultation(self, consultation_id: int) -> int:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Message)
|
||||
.where(
|
||||
Message.consultation_id == consultation_id,
|
||||
Message.role == "user",
|
||||
)
|
||||
)
|
||||
return int(query or 0)
|
||||
|
||||
async def is_admin_exists(self, user_id: int) -> bool:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.execute(
|
||||
select(Admin.user_id).where(Admin.user_id == user_id)
|
||||
)
|
||||
return query.one_or_none() is not None
|
||||
|
||||
async def create_admin(self, user_id: int, username: str, fullname: str) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
admin = Admin(user_id=user_id, username=username, fullname=fullname)
|
||||
await session.merge(admin)
|
||||
|
||||
async def get_admin(self, user_id: int) -> Admin | None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(
|
||||
select(Admin).where(Admin.user_id == user_id)
|
||||
)
|
||||
return query.one_or_none()
|
||||
|
||||
async def get_all_admins(self) -> list[Admin]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(select(Admin))
|
||||
return query.all()
|
||||
|
||||
async def delete_admin(self, user_id: int) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(delete(Admin).where(Admin.user_id == user_id))
|
||||
|
||||
async def set_admin_field(
|
||||
self, user_id: int, field: str, value: int | str | bool
|
||||
) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
update(Admin)
|
||||
.where(Admin.user_id == user_id)
|
||||
.values({getattr(Admin, field): value})
|
||||
)
|
||||
|
||||
async def is_setting_exists(self, name: str) -> bool:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.execute(
|
||||
select(Setting).where(Setting.name == name)
|
||||
)
|
||||
return query.one_or_none() is not None
|
||||
|
||||
async def create_setting(self, name: str, value: Any) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
setting = Setting(name=name, value=value)
|
||||
await session.merge(setting)
|
||||
|
||||
async def init_settings(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_setting_value(self, name: str) -> Any:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(
|
||||
select(Setting.value).where(Setting.name == name)
|
||||
)
|
||||
return query.one_or_none()
|
||||
|
||||
async def update_setting_value(self, name: str, value: dict | list) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
update(Setting)
|
||||
.where(Setting.name == name)
|
||||
.values({getattr(Setting, "value"): value})
|
||||
)
|
||||
|
||||
async def is_blacklisted(self, user_id: int) -> bool:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.execute(
|
||||
select(Blacklist).where(Blacklist.user_id == user_id)
|
||||
)
|
||||
return query.one_or_none() is not None
|
||||
|
||||
async def create_blacklist(self, user_id: int) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
blacklist = Blacklist(user_id=user_id)
|
||||
await session.merge(blacklist)
|
||||
|
||||
async def get_all_blacklist(self) -> list[int]:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
query = await session.scalars(
|
||||
select(Blacklist.user_id).order_by(Blacklist.user_id)
|
||||
)
|
||||
return query.all()
|
||||
|
||||
async def delete_blacklist(self, user_id: int) -> None:
|
||||
async with self.session_maker() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Blacklist).where(Blacklist.user_id == user_id)
|
||||
)
|
||||
Reference in New Issue
Block a user