from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any, Iterable from decouple import config from sqlalchemy import delete, func, select, text, update from shared.engine import create_async_engine, get_session_maker from shared.models import ( Admin, BaseModel, Blacklist, Consultation, LawChunk, LawSource, Message, RagQuery, Setting, User, ) from shared.types import JSONDict def resolve_database_url() -> str: database_url = config("DATABASE_URL", default=None) if database_url: return database_url env_values = load_env_file_values( Path(__file__).resolve().parent.parent / "postgres.env" ) return ( f"postgresql+asyncpg://{config('POSTGRES_USER', default=env_values.get('POSTGRES_USER'))}:" f"{config('POSTGRES_PASSWORD', default=env_values.get('POSTGRES_PASSWORD'))}" f"@{config('POSTGRES_HOST', default=env_values.get('POSTGRES_HOST'))}:" f"{config('POSTGRES_PORT', default=env_values.get('POSTGRES_PORT'))}/" f"{config('POSTGRES_DB', default=env_values.get('POSTGRES_DB'))}" ) def load_env_file_values(path: Path) -> dict[str, str]: if not path.exists(): return {} values: dict[str, str] = {} for raw_line in path.read_text(encoding="utf-8").splitlines(): line = raw_line.strip() if not line or line.startswith("#") or "=" not in line: continue key, value = line.split("=", 1) values[key.strip()] = value.strip().strip("'").strip('"') return values def normalize_law_types_arg( law_types: list[str] | str | None, ) -> list[str] | None: if law_types is None: return None if isinstance(law_types, str): return [law_types] normalized = [ item.strip() for item in law_types if isinstance(item, str) and item.strip() ] return normalized or None @dataclass(frozen=True) class SourceUpsertDecision: action: str should_replace_chunks: bool def classify_source_update( existing_version_hash: str | None, new_version_hash: str ) -> SourceUpsertDecision: if existing_version_hash is None: return SourceUpsertDecision(action="create", should_replace_chunks=True) if existing_version_hash == new_version_hash: return SourceUpsertDecision(action="reuse", should_replace_chunks=False) return SourceUpsertDecision(action="supersede", should_replace_chunks=True) class ORM: def __init__(self, database_url: str | None = None): self.database_url = database_url or resolve_database_url() self.async_engine = create_async_engine(url=self.database_url) self.session_maker = get_session_maker(self.async_engine) async def init_schema(self) -> None: async with self.async_engine.begin() as conn: await conn.run_sync(BaseModel.metadata.create_all) await conn.execute( text( "ALTER TABLE users " "ADD COLUMN IF NOT EXISTS country TEXT NOT NULL DEFAULT 'Россия'" ) ) await conn.execute( text("ALTER TABLE users ADD COLUMN IF NOT EXISTS region TEXT") ) await conn.execute( text( "ALTER TABLE users " "ADD COLUMN IF NOT EXISTS user_type TEXT NOT NULL DEFAULT 'physical_person'" ) ) await conn.execute( text( "ALTER TABLE users " "ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT now()" ) ) await conn.execute( text( "UPDATE users " "SET country = COALESCE(country, 'Россия'), " "user_type = COALESCE(user_type, 'physical_person'), " "updated_at = COALESCE(updated_at, register_date, now())" ) ) async def proceed_schemas(self) -> None: await self.init_schema() async def close(self) -> None: await self.async_engine.dispose() async def get_active_law_source_by_url(self, source_url: str) -> LawSource | None: async with self.session_maker() as session: async with session.begin(): result = await session.scalars( select(LawSource) .where( LawSource.source_url == source_url, LawSource.is_active.is_(True), ) .order_by(LawSource.loaded_at.desc()) ) return result.first() async def upsert_law_source(self, payload: JSONDict) -> tuple[LawSource, bool]: async with self.session_maker() as session: async with session.begin(): existing = await session.scalars( select(LawSource) .where( LawSource.source_url == payload["source_url"], LawSource.is_active.is_(True), ) .order_by(LawSource.loaded_at.desc()) ) active = existing.first() decision = classify_source_update( active.version_hash if active else None, payload["version_hash"], ) if decision.action == "reuse": return active, False if active is not None: active.is_active = False await session.flush() source = LawSource(**payload) session.add(source) await session.flush() return source, True async def replace_law_chunks( self, source_id: int, chunks: Iterable[JSONDict] ) -> int: chunk_rows = list(chunks) async with self.session_maker() as session: async with session.begin(): await session.execute( delete(LawChunk).where(LawChunk.source_id == source_id) ) for row in chunk_rows: session.add( LawChunk( source_id=source_id, chunk_index=row["chunk_index"], article_number=row.get("article_number"), article_title=row.get("article_title"), chunk_text=row["chunk_text"], chunk_metadata=row["metadata"], created_at=row.get( "created_at", datetime.now(timezone.utc) ), ) ) await session.flush() await session.execute( text( "UPDATE law_chunks " "SET tsv = to_tsvector('russian', chunk_text) " "WHERE source_id = :source_id" ), {"source_id": source_id}, ) return len(chunk_rows) async def list_active_law_sources( self, law_types: list[str] | None = None ) -> list[LawSource]: law_types = normalize_law_types_arg(law_types) async with self.session_maker() as session: async with session.begin(): stmt = select(LawSource).where(LawSource.is_active.is_(True)) if law_types: stmt = stmt.where(LawSource.law_type.in_(law_types)) result = await session.scalars(stmt.order_by(LawSource.title)) return result.all() async def get_law_source(self, source_id: int) -> LawSource | None: async with self.session_maker() as session: async with session.begin(): return await session.get(LawSource, source_id) async def get_chunks_by_source(self, source_id: int) -> list[LawChunk]: async with self.session_maker() as session: async with session.begin(): result = await session.scalars( select(LawChunk) .where(LawChunk.source_id == source_id) .order_by(LawChunk.chunk_index) ) return result.all() async def get_chunks_count_by_source(self, source_id: int) -> int: async with self.session_maker() as session: async with session.begin(): result = await session.scalar( select(func.count()) .select_from(LawChunk) .where(LawChunk.source_id == source_id) ) return int(result or 0) async def list_chunks_for_indexing( self, source_ids: list[int] | None = None, law_types: list[str] | None = None, active_only: bool = True, ) -> list[JSONDict]: law_types = normalize_law_types_arg(law_types) async with self.session_maker() as session: async with session.begin(): stmt = ( select(LawChunk, LawSource) .join(LawSource, LawSource.id == LawChunk.source_id) .order_by(LawSource.id, LawChunk.chunk_index) ) if active_only: stmt = stmt.where(LawSource.is_active.is_(True)) if source_ids: stmt = stmt.where(LawSource.id.in_(source_ids)) if law_types: stmt = stmt.where(LawSource.law_type.in_(law_types)) rows = await session.execute(stmt) payloads: list[JSONDict] = [] for chunk, source in rows.all(): payloads.append( { "chunk_id": chunk.id, "source_id": source.id, "source_title": source.title, "source_url": source.source_url, "law_type": source.law_type, "jurisdiction": source.jurisdiction, "version_hash": source.version_hash, "article_number": chunk.article_number, "article_title": chunk.article_title, "chunk_text": chunk.chunk_text, "metadata": chunk.chunk_metadata, } ) return payloads async def get_law_chunks_with_sources_by_ids( self, chunk_ids: list[int], law_types: list[str] | None = None, jurisdiction: str = "RU", ) -> list[JSONDict]: law_types = normalize_law_types_arg(law_types) if not chunk_ids: return [] async with self.session_maker() as session: async with session.begin(): stmt = ( select(LawChunk, LawSource) .join(LawSource, LawSource.id == LawChunk.source_id) .where( LawChunk.id.in_(chunk_ids), LawSource.is_active.is_(True), LawSource.jurisdiction == jurisdiction, ) ) if law_types: stmt = stmt.where(LawSource.law_type.in_(law_types)) rows = await session.execute(stmt) by_id: dict[int, JSONDict] = {} for chunk, source in rows.all(): by_id[chunk.id] = { "chunk_id": chunk.id, "source_id": source.id, "source_title": source.title, "source_url": source.source_url, "law_type": source.law_type, "jurisdiction": source.jurisdiction, "article_number": chunk.article_number, "article_title": chunk.article_title, "chunk_text": chunk.chunk_text, "metadata": chunk.chunk_metadata, } return [by_id[chunk_id] for chunk_id in chunk_ids if chunk_id in by_id] async def search_law_chunks_full_text( self, query: str, law_types: list[str] | None = None, jurisdiction: str = "RU", limit: int = 20, ) -> list[JSONDict]: law_types = normalize_law_types_arg(law_types) async with self.session_maker() as session: async with session.begin(): ts_query = func.plainto_tsquery("russian", query) score = func.ts_rank(LawChunk.tsv, ts_query).label("score") stmt = ( select(LawChunk, LawSource, score) .join(LawSource, LawSource.id == LawChunk.source_id) .where( LawSource.jurisdiction == jurisdiction, LawSource.is_active.is_(True), LawChunk.tsv.op("@@")(ts_query), ) .order_by(score.desc()) .limit(limit) ) if law_types: stmt = stmt.where(LawSource.law_type.in_(law_types)) rows = await session.execute(stmt) payloads: list[JSONDict] = [] for chunk, source, result_score in rows.all(): payloads.append( { "chunk_id": chunk.id, "source_id": source.id, "source_title": source.title, "source_url": source.source_url, "law_type": source.law_type, "jurisdiction": source.jurisdiction, "article_number": chunk.article_number, "article_title": chunk.article_title, "chunk_text": chunk.chunk_text, "metadata": chunk.chunk_metadata, "score": float(result_score or 0.0), } ) return payloads async def create_consultation( self, user_id: int, category: str, title: str | None = None, region: str | None = None, status: str = "active", ) -> Consultation: now = datetime.now(timezone.utc) async with self.session_maker() as session: async with session.begin(): consultation = Consultation( user_id=user_id, category=category, title=title, region=region, status=status, created_at=now, updated_at=now, ) session.add(consultation) await session.flush() return consultation async def create_message( self, consultation_id: int, role: str, content: str, sources_json: Any | None = None, ) -> Message: async with self.session_maker() as session: async with session.begin(): message = Message( consultation_id=consultation_id, role=role, content=content, sources_json=sources_json, created_at=datetime.now(timezone.utc), ) session.add(message) await session.execute( update(Consultation) .where(Consultation.id == consultation_id) .values(updated_at=datetime.now(timezone.utc)) ) await session.flush() return message async def create_rag_query( self, consultation_id: int | None, user_message_id: int | None, generated_queries: list[str], retrieved_chunks: list[JSONDict], ) -> RagQuery: async with self.session_maker() as session: async with session.begin(): rag_query = RagQuery( consultation_id=consultation_id, user_message_id=user_message_id, generated_queries=generated_queries, retrieved_chunks=retrieved_chunks, created_at=datetime.now(timezone.utc), ) session.add(rag_query) await session.flush() return rag_query async def is_user_exists(self, user_id: int) -> bool: async with self.session_maker() as session: async with session.begin(): query = await session.execute( select(User.user_id).where(User.user_id == user_id) ) return query.one_or_none() is not None async def create_user( self, user_id: int, username: str | None, fullname: str, register_date: datetime ) -> int | None: async with self.session_maker() as session: async with session.begin(): existing = await session.get(User, user_id) if existing is not None: existing.username = username existing.fullname = fullname existing.updated_at = datetime.now(timezone.utc) await session.flush() return existing.user_id user = User( user_id=user_id, username=username, fullname=fullname, country="Россия", user_type="physical_person", register_date=register_date, updated_at=register_date, ) session.add(user) await session.flush() return user.user_id async def set_users_field( self, user_id: int, field: str, value: int | str | bool ) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute( update(User) .where(User.user_id == user_id) .values( { getattr(User, field): value, User.updated_at: datetime.now(timezone.utc), } ) ) async def set_user_region(self, user_id: int, region: str) -> None: await self.set_users_field(user_id, "region", region) async def set_user_type(self, user_id: int, user_type: str) -> None: await self.set_users_field(user_id, "user_type", user_type) async def get_user(self, user_id: int) -> User | None: async with self.session_maker() as session: async with session.begin(): query = await session.scalars( select(User).where(User.user_id == user_id) ) return query.one_or_none() async def get_all_users(self) -> list[User]: async with self.session_maker() as session: async with session.begin(): query = await session.scalars(select(User)) return query.all() async def get_users_count(self) -> int: async with self.session_maker() as session: async with session.begin(): query = await session.scalar(select(func.count()).select_from(User)) return int(query or 0) async def get_all_user_ids(self) -> list[int]: async with self.session_maker() as session: async with session.begin(): query = await session.scalars(select(User.user_id)) return query.all() async def delete_user(self, user_id: int) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute(delete(User).where(User.user_id == user_id)) async def get_consultation( self, consultation_id: int, user_id: int | None = None ) -> Consultation | None: async with self.session_maker() as session: async with session.begin(): stmt = select(Consultation).where(Consultation.id == consultation_id) if user_id is not None: stmt = stmt.where(Consultation.user_id == user_id) query = await session.scalars(stmt) return query.one_or_none() async def list_user_consultations( self, user_id: int, limit: int = 20 ) -> list[Consultation]: async with self.session_maker() as session: async with session.begin(): query = await session.scalars( select(Consultation) .where(Consultation.user_id == user_id) .order_by(Consultation.updated_at.desc(), Consultation.id.desc()) .limit(limit) ) return query.all() async def get_consultation_messages( self, consultation_id: int, limit: int | None = None, ) -> list[Message]: async with self.session_maker() as session: async with session.begin(): stmt = ( select(Message) .where(Message.consultation_id == consultation_id) .order_by(Message.created_at.asc(), Message.id.asc()) ) if limit is not None: stmt = stmt.limit(limit) query = await session.scalars(stmt) return query.all() async def delete_consultation(self, consultation_id: int, user_id: int) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute( delete(Consultation).where( Consultation.id == consultation_id, Consultation.user_id == user_id, ) ) async def count_user_consultations_since( self, user_id: int, since: datetime ) -> int: async with self.session_maker() as session: async with session.begin(): query = await session.scalar( select(func.count()) .select_from(Consultation) .where( Consultation.user_id == user_id, Consultation.created_at >= since, ) ) return int(query or 0) async def count_user_messages_in_consultation(self, consultation_id: int) -> int: async with self.session_maker() as session: async with session.begin(): query = await session.scalar( select(func.count()) .select_from(Message) .where( Message.consultation_id == consultation_id, Message.role == "user", ) ) return int(query or 0) async def is_admin_exists(self, user_id: int) -> bool: async with self.session_maker() as session: async with session.begin(): query = await session.execute( select(Admin.user_id).where(Admin.user_id == user_id) ) return query.one_or_none() is not None async def create_admin(self, user_id: int, username: str, fullname: str) -> None: async with self.session_maker() as session: async with session.begin(): admin = Admin(user_id=user_id, username=username, fullname=fullname) await session.merge(admin) async def get_admin(self, user_id: int) -> Admin | None: async with self.session_maker() as session: async with session.begin(): query = await session.scalars( select(Admin).where(Admin.user_id == user_id) ) return query.one_or_none() async def get_all_admins(self) -> list[Admin]: async with self.session_maker() as session: async with session.begin(): query = await session.scalars(select(Admin)) return query.all() async def delete_admin(self, user_id: int) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute(delete(Admin).where(Admin.user_id == user_id)) async def set_admin_field( self, user_id: int, field: str, value: int | str | bool ) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute( update(Admin) .where(Admin.user_id == user_id) .values({getattr(Admin, field): value}) ) async def is_setting_exists(self, name: str) -> bool: async with self.session_maker() as session: async with session.begin(): query = await session.execute( select(Setting).where(Setting.name == name) ) return query.one_or_none() is not None async def create_setting(self, name: str, value: Any) -> None: async with self.session_maker() as session: async with session.begin(): setting = Setting(name=name, value=value) await session.merge(setting) async def init_settings(self) -> None: return None async def get_setting_value(self, name: str) -> Any: async with self.session_maker() as session: async with session.begin(): query = await session.scalars( select(Setting.value).where(Setting.name == name) ) return query.one_or_none() async def update_setting_value(self, name: str, value: dict | list) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute( update(Setting) .where(Setting.name == name) .values({getattr(Setting, "value"): value}) ) async def is_blacklisted(self, user_id: int) -> bool: async with self.session_maker() as session: async with session.begin(): query = await session.execute( select(Blacklist).where(Blacklist.user_id == user_id) ) return query.one_or_none() is not None async def create_blacklist(self, user_id: int) -> None: async with self.session_maker() as session: async with session.begin(): blacklist = Blacklist(user_id=user_id) await session.merge(blacklist) async def get_all_blacklist(self) -> list[int]: async with self.session_maker() as session: async with session.begin(): query = await session.scalars( select(Blacklist.user_id).order_by(Blacklist.user_id) ) return query.all() async def delete_blacklist(self, user_id: int) -> None: async with self.session_maker() as session: async with session.begin(): await session.execute( delete(Blacklist).where(Blacklist.user_id == user_id) )