Files
LawBot/shared/repositories.py
T
2026-05-25 01:12:43 +03:00

724 lines
28 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable
from decouple import config
from sqlalchemy import delete, func, select, text, update
from shared.engine import create_async_engine, get_session_maker
from shared.models import (
Admin,
BaseModel,
Blacklist,
Consultation,
LawChunk,
LawSource,
Message,
RagQuery,
Setting,
User,
)
from shared.types import JSONDict
def resolve_database_url() -> str:
    """Return the SQLAlchemy async URL of the application database.

    ``DATABASE_URL`` (read through ``decouple.config``) wins when it is set
    and non-empty. Otherwise the URL is assembled from the ``POSTGRES_*``
    settings; each one falls back to the value found in the ``postgres.env``
    file located one directory above this module's package.

    The user name and password are percent-encoded, so characters such as
    ``@``, ``:``, ``/`` or ``%`` in a password do not corrupt the URL.

    Raises:
        RuntimeError: if any ``POSTGRES_*`` setting is missing from both the
            configuration and ``postgres.env``.
    """
    from urllib.parse import quote

    database_url = config("DATABASE_URL", default=None)
    if database_url:
        return database_url

    env_values = load_env_file_values(
        Path(__file__).resolve().parent.parent / "postgres.env"
    )
    keys = (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    )
    settings = {key: config(key, default=env_values.get(key)) for key in keys}
    missing = [key for key, value in settings.items() if value is None]
    if missing:
        # Previously a missing value was rendered as the literal text "None",
        # producing a URL that only failed later with an obscure error.
        raise RuntimeError(
            "Database settings are missing: " + ", ".join(missing)
        )

    # safe="" also escapes "/" and "+", so the URL parser's unquoting
    # restores the exact original credentials.
    user = quote(str(settings["POSTGRES_USER"]), safe="")
    password = quote(str(settings["POSTGRES_PASSWORD"]), safe="")
    return (
        f"postgresql+asyncpg://{user}:{password}"
        f"@{settings['POSTGRES_HOST']}:{settings['POSTGRES_PORT']}/"
        f"{settings['POSTGRES_DB']}"
    )
def load_env_file_values(path: Path) -> dict[str, str]:
    """Parse a simple ``KEY=VALUE`` env file into a dict.

    Blank lines, lines starting with ``#``, lines without ``=`` and lines
    with an empty key are ignored. The line is split on the first ``=``. Keys
    and values are whitespace-stripped, and a single pair of *matching*
    surrounding quotes (``'...'`` or ``"..."``) is removed from the value.
    Later duplicate keys override earlier ones.

    Args:
        path: Location of the env file.

    Returns:
        The parsed key/value pairs, or an empty dict when ``path`` does not
        exist.
    """
    if not path.exists():
        return {}
    values: dict[str, str] = {}
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, separator, value = line.partition("=")
        key = key.strip()
        if not separator or not key:
            continue
        value = value.strip()
        # Remove only one matching pair of quotes. The former chained
        # .strip("'").strip('"') also ate unmatched quotes that belong to the
        # value, e.g. a password ending in '"'.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        values[key] = value
    return values
def normalize_law_types_arg(
    law_types: list[str] | str | None,
) -> list[str] | None:
    """Normalize an optional law-type filter into a clean list or ``None``.

    A single string is treated as a one-element list. Non-string items and
    items that are empty after stripping whitespace are dropped; the
    remaining items are returned stripped, in their original order.

    Returns:
        The normalized list, or ``None`` when there is nothing to filter on
        (``None`` input or no usable items).
    """
    if law_types is None:
        return None
    # Wrap a bare string (rather than iterating its characters) so it goes
    # through the same strip/empty check as list items. Previously "" came
    # back as [""] and "  civil " was returned unstripped.
    items = [law_types] if isinstance(law_types, str) else law_types
    normalized = [
        item.strip()
        for item in items
        if isinstance(item, str) and item.strip()
    ]
    return normalized or None
@dataclass(frozen=True)
class SourceUpsertDecision:
    """Immutable verdict on how a (re)fetched law source should be stored.

    Instances are produced by ``classify_source_update``.
    """

    # Outcome label: "create", "reuse" or "supersede".
    action: str
    # True when the source's chunks have to be (re)built.
    should_replace_chunks: bool
def classify_source_update(
    existing_version_hash: str | None, new_version_hash: str
) -> SourceUpsertDecision:
    """Decide what to do with a source given its stored and new version hashes.

    - no stored version              -> "create", chunks must be built
    - stored hash equals the new one -> "reuse", chunks stay as they are
    - stored hash differs            -> "supersede", chunks must be rebuilt
    """
    unchanged = (
        existing_version_hash is not None
        and existing_version_hash == new_version_hash
    )
    if unchanged:
        return SourceUpsertDecision(action="reuse", should_replace_chunks=False)
    label = "create" if existing_version_hash is None else "supersede"
    return SourceUpsertDecision(action=label, should_replace_chunks=True)
class ORM:
    """Async repository facade over the LawBot PostgreSQL schema.

    Every public coroutine opens its own session from ``self.session_maker``
    and runs inside ``session.begin()``, so each call is an independent unit
    of work committed when the block exits. ORM instances are returned after
    the session has closed. NOTE(review): this presumes the session maker is
    configured with ``expire_on_commit=False``; confirm in ``shared.engine``.
    """

    def __init__(self, database_url: str | None = None):
        """Create the async engine and session factory.

        Args:
            database_url: Explicit SQLAlchemy URL. When falsy, the URL is
                resolved from configuration by ``resolve_database_url``.
        """
        self.database_url = database_url or resolve_database_url()
        self.async_engine = create_async_engine(url=self.database_url)
        self.session_maker = get_session_maker(self.async_engine)

    async def init_schema(self) -> None:
        """Create missing tables and apply additive ``users`` migrations.

        Runs ``metadata.create_all`` and then ``ADD COLUMN IF NOT EXISTS``
        statements. It finishes with a backfill of NULLs in those columns,
        so it can be called on every start-up.
        """
        async with self.async_engine.begin() as conn:
            await conn.run_sync(BaseModel.metadata.create_all)
            await conn.execute(
                text(
                    "ALTER TABLE users "
                    "ADD COLUMN IF NOT EXISTS country TEXT NOT NULL DEFAULT 'Россия'"
                )
            )
            await conn.execute(
                text("ALTER TABLE users ADD COLUMN IF NOT EXISTS region TEXT")
            )
            await conn.execute(
                text(
                    "ALTER TABLE users "
                    "ADD COLUMN IF NOT EXISTS user_type TEXT NOT NULL DEFAULT 'physical_person'"
                )
            )
            await conn.execute(
                text(
                    "ALTER TABLE users "
                    "ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT now()"
                )
            )
            # Only touch rows that actually contain a NULL. Without the WHERE
            # clause every user row was rewritten on every start-up even
            # though COALESCE left its values unchanged.
            await conn.execute(
                text(
                    "UPDATE users "
                    "SET country = COALESCE(country, 'Россия'), "
                    "user_type = COALESCE(user_type, 'physical_person'), "
                    "updated_at = COALESCE(updated_at, register_date, now()) "
                    "WHERE country IS NULL "
                    "OR user_type IS NULL "
                    "OR updated_at IS NULL"
                )
            )

    async def proceed_schemas(self) -> None:
        """Alias of ``init_schema`` kept for existing callers."""
        await self.init_schema()

    async def close(self) -> None:
        """Dispose of the engine and its connection pool."""
        await self.async_engine.dispose()

    @staticmethod
    def _chunk_payload(
        chunk: LawChunk, source: LawSource, *, include_version_hash: bool = False
    ) -> JSONDict:
        """Build the dict describing ``chunk`` together with its parent source.

        Args:
            chunk: The law chunk row.
            source: The ``LawSource`` the chunk belongs to.
            include_version_hash: Also emit ``version_hash`` (right after
                ``jurisdiction``) when True.
        """
        payload: JSONDict = {
            "chunk_id": chunk.id,
            "source_id": source.id,
            "source_title": source.title,
            "source_url": source.source_url,
            "law_type": source.law_type,
            "jurisdiction": source.jurisdiction,
        }
        if include_version_hash:
            payload["version_hash"] = source.version_hash
        payload["article_number"] = chunk.article_number
        payload["article_title"] = chunk.article_title
        payload["chunk_text"] = chunk.chunk_text
        payload["metadata"] = chunk.chunk_metadata
        return payload

    async def get_active_law_source_by_url(self, source_url: str) -> LawSource | None:
        """Return the most recently loaded active source for ``source_url``, if any."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalars(
                    select(LawSource)
                    .where(
                        LawSource.source_url == source_url,
                        LawSource.is_active.is_(True),
                    )
                    .order_by(LawSource.loaded_at.desc())
                )
                return result.first()

    async def upsert_law_source(self, payload: JSONDict) -> tuple[LawSource, bool]:
        """Store a law source, reusing or superseding the active one.

        The active source with the same ``source_url`` is looked up. If its
        ``version_hash`` matches ``payload["version_hash"]``, it is returned
        unchanged. Otherwise it (if present) is deactivated and a new
        ``LawSource(**payload)`` is inserted.

        Returns:
            ``(source, created)``, where ``created`` is False only when an
            existing source was reused.
        """
        async with self.session_maker() as session:
            async with session.begin():
                existing = await session.scalars(
                    select(LawSource)
                    .where(
                        LawSource.source_url == payload["source_url"],
                        LawSource.is_active.is_(True),
                    )
                    .order_by(LawSource.loaded_at.desc())
                )
                active = existing.first()
                decision = classify_source_update(
                    active.version_hash if active else None,
                    payload["version_hash"],
                )
                if decision.action == "reuse":
                    return active, False
                if active is not None:
                    active.is_active = False
                    # Flush the deactivation before inserting the new row.
                    await session.flush()
                source = LawSource(**payload)
                session.add(source)
                await session.flush()
                return source, True

    async def replace_law_chunks(
        self, source_id: int, chunks: Iterable[JSONDict]
    ) -> int:
        """Replace all chunks of ``source_id`` with ``chunks`` in one transaction.

        Each chunk dict must provide ``chunk_index``, ``chunk_text`` and
        ``metadata``. ``article_number``, ``article_title`` and
        ``created_at`` are optional. The ``tsv`` full-text column is then
        recomputed with the ``russian`` configuration.

        Returns:
            The number of chunks inserted.
        """
        chunk_rows = list(chunks)
        # One timestamp for the whole batch, so chunks from a single
        # replacement share created_at unless a row supplies its own.
        default_created_at = datetime.now(timezone.utc)
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(LawChunk).where(LawChunk.source_id == source_id)
                )
                session.add_all(
                    LawChunk(
                        source_id=source_id,
                        chunk_index=row["chunk_index"],
                        article_number=row.get("article_number"),
                        article_title=row.get("article_title"),
                        chunk_text=row["chunk_text"],
                        chunk_metadata=row["metadata"],
                        created_at=row.get("created_at", default_created_at),
                    )
                    for row in chunk_rows
                )
                await session.flush()
                await session.execute(
                    text(
                        "UPDATE law_chunks "
                        "SET tsv = to_tsvector('russian', chunk_text) "
                        "WHERE source_id = :source_id"
                    ),
                    {"source_id": source_id},
                )
        return len(chunk_rows)

    async def list_active_law_sources(
        self, law_types: list[str] | None = None
    ) -> list[LawSource]:
        """Return active sources ordered by title, optionally filtered by law type."""
        law_types = normalize_law_types_arg(law_types)
        async with self.session_maker() as session:
            async with session.begin():
                stmt = select(LawSource).where(LawSource.is_active.is_(True))
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))
                result = await session.scalars(stmt.order_by(LawSource.title))
                return result.all()

    async def get_law_source(self, source_id: int) -> LawSource | None:
        """Return the source with primary key ``source_id``, or None."""
        async with self.session_maker() as session:
            async with session.begin():
                return await session.get(LawSource, source_id)

    async def get_chunks_by_source(self, source_id: int) -> list[LawChunk]:
        """Return all chunks of ``source_id`` ordered by ``chunk_index``."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalars(
                    select(LawChunk)
                    .where(LawChunk.source_id == source_id)
                    .order_by(LawChunk.chunk_index)
                )
                return result.all()

    async def get_chunks_count_by_source(self, source_id: int) -> int:
        """Return how many chunks belong to ``source_id``."""
        async with self.session_maker() as session:
            async with session.begin():
                result = await session.scalar(
                    select(func.count())
                    .select_from(LawChunk)
                    .where(LawChunk.source_id == source_id)
                )
                return int(result or 0)

    async def list_chunks_for_indexing(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
        active_only: bool = True,
    ) -> list[JSONDict]:
        """Return chunk payloads (including ``version_hash``) for indexing.

        Args:
            source_ids: Restrict to these source ids when non-empty.
            law_types: Restrict to these law types when non-empty.
            active_only: Skip chunks of inactive sources when True.

        Returns:
            Payload dicts ordered by source id, then chunk index.
        """
        law_types = normalize_law_types_arg(law_types)
        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(LawChunk, LawSource)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .order_by(LawSource.id, LawChunk.chunk_index)
                )
                if active_only:
                    stmt = stmt.where(LawSource.is_active.is_(True))
                if source_ids:
                    stmt = stmt.where(LawSource.id.in_(source_ids))
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))
                rows = await session.execute(stmt)
                return [
                    self._chunk_payload(chunk, source, include_version_hash=True)
                    for chunk, source in rows.all()
                ]

    async def get_law_chunks_with_sources_by_ids(
        self,
        chunk_ids: list[int],
        law_types: list[str] | None = None,
        jurisdiction: str = "RU",
    ) -> list[JSONDict]:
        """Return payloads for ``chunk_ids`` in the order the ids were given.

        Only chunks of active sources in ``jurisdiction`` (and, if given,
        ``law_types``) are included. Ids that do not match are silently
        skipped.
        """
        law_types = normalize_law_types_arg(law_types)
        if not chunk_ids:
            return []
        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(LawChunk, LawSource)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .where(
                        LawChunk.id.in_(chunk_ids),
                        LawSource.is_active.is_(True),
                        LawSource.jurisdiction == jurisdiction,
                    )
                )
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))
                rows = await session.execute(stmt)
                by_id: dict[int, JSONDict] = {
                    chunk.id: self._chunk_payload(chunk, source)
                    for chunk, source in rows.all()
                }
                return [by_id[chunk_id] for chunk_id in chunk_ids if chunk_id in by_id]

    async def search_law_chunks_full_text(
        self,
        query: str,
        law_types: list[str] | None = None,
        jurisdiction: str = "RU",
        limit: int = 20,
    ) -> list[JSONDict]:
        """Full-text search over active chunks, best ``ts_rank`` first.

        The query is parsed with ``plainto_tsquery('russian', query)``. Each
        payload carries an extra float ``score`` key.

        Returns:
            At most ``limit`` payloads. A blank query returns ``[]`` without
            touching the database.
        """
        law_types = normalize_law_types_arg(law_types)
        # An empty/whitespace query yields an empty tsquery that matches
        # nothing, so skip the round-trip.
        if not query or not query.strip():
            return []
        async with self.session_maker() as session:
            async with session.begin():
                ts_query = func.plainto_tsquery("russian", query)
                score = func.ts_rank(LawChunk.tsv, ts_query).label("score")
                stmt = (
                    select(LawChunk, LawSource, score)
                    .join(LawSource, LawSource.id == LawChunk.source_id)
                    .where(
                        LawSource.jurisdiction == jurisdiction,
                        LawSource.is_active.is_(True),
                        LawChunk.tsv.op("@@")(ts_query),
                    )
                )
                if law_types:
                    stmt = stmt.where(LawSource.law_type.in_(law_types))
                rows = await session.execute(
                    stmt.order_by(score.desc()).limit(limit)
                )
                payloads: list[JSONDict] = []
                for chunk, source, result_score in rows.all():
                    payload = self._chunk_payload(chunk, source)
                    payload["score"] = float(result_score or 0.0)
                    payloads.append(payload)
                return payloads

    async def create_consultation(
        self,
        user_id: int,
        category: str,
        title: str | None = None,
        region: str | None = None,
        status: str = "active",
    ) -> Consultation:
        """Insert a consultation for ``user_id`` and return it with its id."""
        now = datetime.now(timezone.utc)
        async with self.session_maker() as session:
            async with session.begin():
                consultation = Consultation(
                    user_id=user_id,
                    category=category,
                    title=title,
                    region=region,
                    status=status,
                    created_at=now,
                    updated_at=now,
                )
                session.add(consultation)
                await session.flush()
                return consultation

    async def create_message(
        self,
        consultation_id: int,
        role: str,
        content: str,
        sources_json: Any | None = None,
    ) -> Message:
        """Append a message to a consultation and bump its ``updated_at``."""
        # A single timestamp keeps the message time and the consultation's
        # updated_at identical.
        now = datetime.now(timezone.utc)
        async with self.session_maker() as session:
            async with session.begin():
                message = Message(
                    consultation_id=consultation_id,
                    role=role,
                    content=content,
                    sources_json=sources_json,
                    created_at=now,
                )
                session.add(message)
                await session.execute(
                    update(Consultation)
                    .where(Consultation.id == consultation_id)
                    .values(updated_at=now)
                )
                await session.flush()
                return message

    async def create_rag_query(
        self,
        consultation_id: int | None,
        user_message_id: int | None,
        generated_queries: list[str],
        retrieved_chunks: list[JSONDict],
    ) -> RagQuery:
        """Persist the queries and retrieved chunks of one RAG lookup."""
        async with self.session_maker() as session:
            async with session.begin():
                rag_query = RagQuery(
                    consultation_id=consultation_id,
                    user_message_id=user_message_id,
                    generated_queries=generated_queries,
                    retrieved_chunks=retrieved_chunks,
                    created_at=datetime.now(timezone.utc),
                )
                session.add(rag_query)
                await session.flush()
                return rag_query

    async def is_user_exists(self, user_id: int) -> bool:
        """Return True when a user row with ``user_id`` exists."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(User.user_id).where(User.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_user(
        self, user_id: int, username: str | None, fullname: str, register_date: datetime
    ) -> int | None:
        """Insert a user, or refresh username/fullname if it already exists.

        New users get ``country="Россия"`` and ``user_type="physical_person"``.
        For an existing user, ``register_date`` is left untouched.

        Returns:
            The user's id.
        """
        async with self.session_maker() as session:
            async with session.begin():
                existing = await session.get(User, user_id)
                if existing is not None:
                    existing.username = username
                    existing.fullname = fullname
                    existing.updated_at = datetime.now(timezone.utc)
                    await session.flush()
                    return existing.user_id
                user = User(
                    user_id=user_id,
                    username=username,
                    fullname=fullname,
                    country="Россия",
                    user_type="physical_person",
                    register_date=register_date,
                    updated_at=register_date,
                )
                session.add(user)
                await session.flush()
                return user.user_id

    async def set_users_field(
        self, user_id: int, field: str, value: int | str | bool
    ) -> None:
        """Set one column of a user and refresh ``updated_at``.

        ``field`` is resolved with ``getattr(User, field)``, so it must be a
        mapped ``User`` attribute name chosen by trusted code (never raw user
        input).
        """
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(User)
                    .where(User.user_id == user_id)
                    .values(
                        {
                            getattr(User, field): value,
                            User.updated_at: datetime.now(timezone.utc),
                        }
                    )
                )

    async def set_user_region(self, user_id: int, region: str) -> None:
        """Set the user's ``region``."""
        await self.set_users_field(user_id, "region", region)

    async def set_user_type(self, user_id: int, user_type: str) -> None:
        """Set the user's ``user_type``."""
        await self.set_users_field(user_id, "user_type", user_type)

    async def get_user(self, user_id: int) -> User | None:
        """Return the user with ``user_id``, or None."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(User).where(User.user_id == user_id)
                )
                return query.one_or_none()

    async def get_all_users(self) -> list[User]:
        """Return every user row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(User))
                return query.all()

    async def get_users_count(self) -> int:
        """Return the number of users."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(select(func.count()).select_from(User))
                return int(query or 0)

    async def get_all_user_ids(self) -> list[int]:
        """Return the ids of all users."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(User.user_id))
                return query.all()

    async def delete_user(self, user_id: int) -> None:
        """Delete the user with ``user_id`` (no-op if absent)."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(delete(User).where(User.user_id == user_id))

    async def get_consultation(
        self, consultation_id: int, user_id: int | None = None
    ) -> Consultation | None:
        """Return a consultation, optionally requiring it to belong to ``user_id``."""
        async with self.session_maker() as session:
            async with session.begin():
                stmt = select(Consultation).where(Consultation.id == consultation_id)
                if user_id is not None:
                    stmt = stmt.where(Consultation.user_id == user_id)
                query = await session.scalars(stmt)
                return query.one_or_none()

    async def list_user_consultations(
        self, user_id: int, limit: int = 20
    ) -> list[Consultation]:
        """Return up to ``limit`` of the user's consultations, most recently updated first."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Consultation)
                    .where(Consultation.user_id == user_id)
                    .order_by(Consultation.updated_at.desc(), Consultation.id.desc())
                    .limit(limit)
                )
                return query.all()

    async def get_consultation_messages(
        self,
        consultation_id: int,
        limit: int | None = None,
    ) -> list[Message]:
        """Return a consultation's messages in chronological order.

        NOTE(review): the order is ascending, so ``limit`` keeps the *oldest*
        ``limit`` messages. If callers want the most recent N as conversation
        context, this needs a DESC subquery; confirm the intended semantics.
        """
        async with self.session_maker() as session:
            async with session.begin():
                stmt = (
                    select(Message)
                    .where(Message.consultation_id == consultation_id)
                    .order_by(Message.created_at.asc(), Message.id.asc())
                )
                if limit is not None:
                    stmt = stmt.limit(limit)
                query = await session.scalars(stmt)
                return query.all()

    async def delete_consultation(self, consultation_id: int, user_id: int) -> None:
        """Delete a consultation only if it belongs to ``user_id``."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(Consultation).where(
                        Consultation.id == consultation_id,
                        Consultation.user_id == user_id,
                    )
                )

    async def count_user_consultations_since(
        self, user_id: int, since: datetime
    ) -> int:
        """Count the user's consultations created at or after ``since``."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(
                    select(func.count())
                    .select_from(Consultation)
                    .where(
                        Consultation.user_id == user_id,
                        Consultation.created_at >= since,
                    )
                )
                return int(query or 0)

    async def count_user_messages_in_consultation(self, consultation_id: int) -> int:
        """Count messages with role ``"user"`` in a consultation."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalar(
                    select(func.count())
                    .select_from(Message)
                    .where(
                        Message.consultation_id == consultation_id,
                        Message.role == "user",
                    )
                )
                return int(query or 0)

    async def is_admin_exists(self, user_id: int) -> bool:
        """Return True when ``user_id`` is registered as an admin."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(Admin.user_id).where(Admin.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_admin(self, user_id: int, username: str, fullname: str) -> None:
        """Insert or overwrite an admin row via ``Session.merge``."""
        async with self.session_maker() as session:
            async with session.begin():
                admin = Admin(user_id=user_id, username=username, fullname=fullname)
                await session.merge(admin)

    async def get_admin(self, user_id: int) -> Admin | None:
        """Return the admin row for ``user_id``, or None."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Admin).where(Admin.user_id == user_id)
                )
                return query.one_or_none()

    async def get_all_admins(self) -> list[Admin]:
        """Return every admin row."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(select(Admin))
                return query.all()

    async def delete_admin(self, user_id: int) -> None:
        """Remove ``user_id`` from the admins (no-op if absent)."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(delete(Admin).where(Admin.user_id == user_id))

    async def set_admin_field(
        self, user_id: int, field: str, value: int | str | bool
    ) -> None:
        """Set one column of an admin row.

        ``field`` is resolved with ``getattr(Admin, field)``, so it must be a
        mapped attribute name chosen by trusted code.
        """
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(Admin)
                    .where(Admin.user_id == user_id)
                    .values({getattr(Admin, field): value})
                )

    async def is_setting_exists(self, name: str) -> bool:
        """Return True when a setting called ``name`` exists."""
        async with self.session_maker() as session:
            async with session.begin():
                # Only the key is needed to test existence.
                query = await session.execute(
                    select(Setting.name).where(Setting.name == name)
                )
                return query.one_or_none() is not None

    async def create_setting(self, name: str, value: Any) -> None:
        """Insert or overwrite a setting via ``Session.merge``."""
        async with self.session_maker() as session:
            async with session.begin():
                setting = Setting(name=name, value=value)
                await session.merge(setting)

    async def init_settings(self) -> None:
        """Placeholder hook for seeding default settings; currently does nothing."""
        return None

    async def get_setting_value(self, name: str) -> Any:
        """Return the value of setting ``name``.

        Returns None when the setting does not exist. A stored NULL (if the
        column allows it) is indistinguishable from a missing setting.
        """
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Setting.value).where(Setting.name == name)
                )
                return query.one_or_none()

    async def update_setting_value(self, name: str, value: dict | list) -> None:
        """Overwrite the value of an existing setting.

        Silently does nothing when no setting called ``name`` exists (the
        UPDATE matches no rows); use ``create_setting`` to insert.
        """
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    update(Setting)
                    .where(Setting.name == name)
                    .values({Setting.value: value})
                )

    async def is_blacklisted(self, user_id: int) -> bool:
        """Return True when ``user_id`` is on the blacklist."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.execute(
                    select(Blacklist).where(Blacklist.user_id == user_id)
                )
                return query.one_or_none() is not None

    async def create_blacklist(self, user_id: int) -> None:
        """Add ``user_id`` to the blacklist (idempotent via ``Session.merge``)."""
        async with self.session_maker() as session:
            async with session.begin():
                blacklist = Blacklist(user_id=user_id)
                await session.merge(blacklist)

    async def get_all_blacklist(self) -> list[int]:
        """Return all blacklisted user ids in ascending order."""
        async with self.session_maker() as session:
            async with session.begin():
                query = await session.scalars(
                    select(Blacklist.user_id).order_by(Blacklist.user_id)
                )
                return query.all()

    async def delete_blacklist(self, user_id: int) -> None:
        """Remove ``user_id`` from the blacklist (no-op if absent)."""
        async with self.session_maker() as session:
            async with session.begin():
                await session.execute(
                    delete(Blacklist).where(Blacklist.user_id == user_id)
                )