from __future__ import annotations

import logging

from fastapi.concurrency import run_in_threadpool

from api.clients.chroma_store import ChromaVectorStore
from api.config import settings
from api.services.local_embeddings import LocalEmbeddingService
from shared import ORM

logger = logging.getLogger(__name__)


def build_embedding_text(chunk: dict) -> str:
    """Build the text that is embedded for a single chunk.

    The result is newline-joined labelled lines: ``source: ...``,
    ``article: ...`` and ``text: ...``. The source and article lines are
    omitted when the corresponding title is missing, ``None`` or
    whitespace-only. The text line is always present.

    Args:
        chunk: Chunk row as returned by ``ORM.list_chunks_for_indexing``.
            ``chunk_text`` is required; ``source_title`` and ``article_title``
            may be absent or ``None``.

    Returns:
        The string handed to the document embedder for this chunk.

    Raises:
        KeyError: If ``chunk`` has no ``"chunk_text"`` key.
    """
    optional_parts = (
        ("source", chunk.get("source_title") or ""),
        ("article", chunk.get("article_title") or ""),
    )
    lines = [
        f"{label}: {value}"
        for label, value in optional_parts
        # Test the value, not the labelled line: "source: " is never blank,
        # so filtering on the formatted line would never drop anything.
        if str(value).strip()
    ]
    # The chunk body itself is always embedded.
    lines.append(f"text: {chunk['chunk_text']}")
    return "\n".join(lines)


def _build_chunk_metadata(chunk: dict) -> dict:
    """Return the metadata stored in the vector store alongside a chunk.

    ``law_type``, ``article_number`` and ``article_title`` are coerced from
    falsy values to ``""``; every other field is copied unchanged.
    """
    # NOTE(review): the remaining fields are copied without a fallback, which
    # presumably relies on them never being None in Postgres (Chroma metadata
    # may reject None values) — confirm against the schema / store wrapper.
    return {
        "chunk_id": chunk["chunk_id"],
        "source_id": chunk["source_id"],
        "source_title": chunk["source_title"],
        "source_url": chunk["source_url"],
        "law_type": chunk["law_type"] or "",
        "jurisdiction": chunk["jurisdiction"],
        "article_number": chunk["article_number"] or "",
        "article_title": chunk["article_title"] or "",
        "version_hash": chunk["version_hash"],
    }


class IndexingService:
    """Index chunks loaded from Postgres into the Chroma vector store.

    Chunks come from ``ORM.list_chunks_for_indexing``. They are embedded with
    ``LocalEmbeddingService.encode_documents`` and written with
    ``ChromaVectorStore.upsert``. The embedder and vector-store calls are
    dispatched through ``run_in_threadpool`` so they do not run on the event
    loop thread.
    """

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    async def get_indexable_chunks_count(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
    ) -> int:
        """Return how many active chunks match the given filters.

        Args:
            source_ids: Optional source-id filter forwarded to the ORM.
            law_types: Optional law-type filter forwarded to the ORM.

        Returns:
            The number of chunks ``rebuild`` would index with the same filters.
        """
        # NOTE(review): this loads every matching chunk just to count them; a
        # dedicated COUNT query in the ORM would be cheaper if one exists.
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        return len(chunks)

    async def rebuild(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
        reset_collection: bool = True,
        batch_size: int | None = None,
    ) -> dict:
        """Embed active chunks and upsert them into the vector store.

        Args:
            source_ids: Optional source-id filter forwarded to the ORM.
            law_types: Optional law-type filter forwarded to the ORM.
            reset_collection: If true, ``vector_store.reset_collection`` is
                called before indexing.
            batch_size: Chunks per embed/upsert round trip. ``None`` or ``0``
                falls back to ``settings.index_batch_size``.

        Returns:
            A dict with ``indexed_chunks``, ``indexed_sources`` (distinct
            ``source_id`` values among the loaded chunks) and
            ``collection_name``.

        Raises:
            ValueError: If the effective batch size is not a positive integer.
                This is checked before the collection is reset or any chunk is
                loaded, so a bad value cannot leave the index emptied.
        """
        batch_size = batch_size or settings.index_batch_size
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        logger.info(
            "Index rebuild started: reset_collection=%s source_ids=%s law_types=%s",
            reset_collection,
            source_ids,
            law_types,
        )
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        total_chunks = len(chunks)
        logger.info("Loaded %s chunks from Postgres for indexing", total_chunks)

        if reset_collection:
            # NOTE(review): this resets the whole collection even when
            # source_ids / law_types narrow what is reloaded below, so a
            # filtered rebuild with the default reset drops all other chunks.
            await run_in_threadpool(self.vector_store.reset_collection)
        # Without a reset, this method only upserts: vectors for chunks that
        # are no longer active are left in the collection.

        total_batches = max(1, (total_chunks + batch_size - 1) // batch_size)
        indexed_sources = len({chunk["source_id"] for chunk in chunks})
        indexed_chunks = 0

        for batch_number, start in enumerate(
            range(0, total_chunks, batch_size), start=1
        ):
            batch = chunks[start : start + batch_size]
            logger.info(
                "Indexing batch %s/%s with %s chunks",
                batch_number,
                total_batches,
                len(batch),
            )
            embeddings = await run_in_threadpool(
                self.embedder.encode_documents,
                [build_embedding_text(chunk) for chunk in batch],
            )
            await run_in_threadpool(
                self.vector_store.upsert,
                [str(chunk["chunk_id"]) for chunk in batch],
                [chunk["chunk_text"] for chunk in batch],
                embeddings,
                [_build_chunk_metadata(chunk) for chunk in batch],
            )
            indexed_chunks += len(batch)
            logger.info(
                "Indexed %s/%s chunks into Chroma",
                indexed_chunks,
                total_chunks,
            )

        result = {
            "indexed_chunks": indexed_chunks,
            "indexed_sources": indexed_sources,
            "collection_name": self.vector_store.collection_name,
        }
        logger.info("Index rebuild completed: %s", result)
        return result