127 lines
4.1 KiB
Python
127 lines
4.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from fastapi.concurrency import run_in_threadpool
|
|
|
|
from api.clients.chroma_store import ChromaVectorStore
|
|
from api.config import settings
|
|
from api.services.local_embeddings import LocalEmbeddingService
|
|
from shared import ORM
|
|
|
|
|
|
# Module-level logger named after this module; IndexingService.rebuild uses it
# to report rebuild progress.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def build_embedding_text(chunk: dict) -> str:
    """Build the text that is embedded for a single chunk.

    The result is a newline-separated list of ``label: value`` lines in the
    order source title, article title, chunk text. A line is omitted when its
    value is missing, ``None`` or whitespace-only, so empty labels such as
    ``"article: "`` never end up in the embedding input.

    Args:
        chunk: Mapping with a required ``"chunk_text"`` key and optional
            ``"source_title"`` / ``"article_title"`` keys.

    Returns:
        The embedding input string (empty if every field is blank).

    Raises:
        KeyError: If ``"chunk_text"`` is missing from ``chunk``.
    """
    fields = (
        ("source", chunk.get("source_title")),
        ("article", chunk.get("article_title")),
        # Subscripted (not .get) on purpose: a chunk without a text key is a
        # caller error and should fail loudly.
        ("text", chunk["chunk_text"]),
    )
    # Filter on the *value*, not on the rendered line: "label: " is never
    # blank, so testing the whole line would keep every empty label.
    return "\n".join(
        f"{label}: {value}"
        for label, value in fields
        if value is not None and str(value).strip()
    )
|
|
|
|
|
|
class IndexingService:
    """Rebuild the vector index from chunks stored in the relational database.

    Chunks are read through the ORM, embedded with the local embedding
    service and upserted into the vector store in fixed-size batches.
    Blocking embedder / vector-store calls are dispatched through
    ``run_in_threadpool`` so they do not run on the event loop thread.
    """

    def __init__(
        self,
        orm: ORM,
        embedder: LocalEmbeddingService,
        vector_store: ChromaVectorStore,
    ) -> None:
        self.orm = orm
        self.embedder = embedder
        self.vector_store = vector_store

    async def get_indexable_chunks_count(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
    ) -> int:
        """Return how many active chunks ``rebuild`` would index.

        Uses the same ORM query and filters as ``rebuild``.

        NOTE(review): this loads every matching chunk just to take ``len``;
        if the ORM exposes a COUNT query, prefer it for large corpora.
        """
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        return len(chunks)

    @staticmethod
    def _chunk_metadata(chunk: dict) -> dict:
        """Build the vector-store metadata record for one chunk.

        ``law_type``, ``article_number`` and ``article_title`` are coerced to
        ``""`` when falsy (presumably because the vector store does not accept
        ``None`` metadata values -- confirm); other fields are copied as-is.
        """
        return {
            "chunk_id": chunk["chunk_id"],
            "source_id": chunk["source_id"],
            "source_title": chunk["source_title"],
            "source_url": chunk["source_url"],
            "law_type": chunk["law_type"] or "",
            "jurisdiction": chunk["jurisdiction"],
            "article_number": chunk["article_number"] or "",
            "article_title": chunk["article_title"] or "",
            "version_hash": chunk["version_hash"],
        }

    async def rebuild(
        self,
        source_ids: list[int] | None = None,
        law_types: list[str] | None = None,
        reset_collection: bool = True,
        batch_size: int | None = None,
    ) -> dict:
        """Embed and upsert all active chunks matching the filters.

        Args:
            source_ids: Restrict indexing to these source ids (``None`` = all).
            law_types: Restrict indexing to these law types (``None`` = all).
            reset_collection: Clear the vector collection before indexing.
                The reset happens before any batch is written, so a failure
                part-way through leaves the collection partially populated.
            batch_size: Chunks per embed/upsert round trip. ``None`` or ``0``
                falls back to ``settings.index_batch_size``.

        Returns:
            ``{"indexed_chunks", "indexed_sources", "collection_name"}``.

        Raises:
            ValueError: If the resolved batch size is not positive. This is
                checked before the collection is touched.
        """
        # Resolve and validate up front: a negative step would make the batch
        # loop a silent no-op (and a zero step would crash) only *after*
        # reset_collection had already wiped the index.
        batch_size = batch_size or settings.index_batch_size
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        logger.info(
            "Index rebuild started: reset_collection=%s source_ids=%s law_types=%s",
            reset_collection,
            source_ids,
            law_types,
        )
        chunks = await self.orm.list_chunks_for_indexing(
            source_ids=source_ids,
            law_types=law_types,
            active_only=True,
        )
        total_chunks = len(chunks)
        logger.info("Loaded %s chunks from Postgres for indexing", total_chunks)

        if reset_collection:
            await run_in_threadpool(self.vector_store.reset_collection)

        indexed_chunks = 0
        indexed_sources = len({chunk["source_id"] for chunk in chunks})
        # Ceiling division; loop-invariant, so computed once.
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_number, start in enumerate(
            range(0, total_chunks, batch_size), start=1
        ):
            batch = chunks[start : start + batch_size]
            logger.info(
                "Indexing batch %s/%s with %s chunks",
                batch_number,
                total_batches,
                len(batch),
            )
            embeddings = await run_in_threadpool(
                self.embedder.encode_documents,
                [build_embedding_text(chunk) for chunk in batch],
            )
            await run_in_threadpool(
                self.vector_store.upsert,
                [str(chunk["chunk_id"]) for chunk in batch],
                [chunk["chunk_text"] for chunk in batch],
                embeddings,
                [self._chunk_metadata(chunk) for chunk in batch],
            )
            indexed_chunks += len(batch)
            logger.info(
                "Indexed %s/%s chunks into Chroma",
                indexed_chunks,
                total_chunks,
            )

        result = {
            "indexed_chunks": indexed_chunks,
            "indexed_sources": indexed_sources,
            "collection_name": self.vector_store.collection_name,
        }
        logger.info("Index rebuild completed: %s", result)
        return result
|