first commit
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from api.clients.chroma_store import ChromaVectorStore
|
||||
from api.config import settings
|
||||
from api.services.local_embeddings import LocalEmbeddingService
|
||||
from shared import ORM
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_embedding_text(chunk: dict) -> str:
|
||||
article_title = chunk.get("article_title") or ""
|
||||
source_title = chunk.get("source_title") or ""
|
||||
return "\n".join(
|
||||
part
|
||||
for part in [
|
||||
f"source: {source_title}",
|
||||
f"article: {article_title}",
|
||||
f"text: {chunk['chunk_text']}",
|
||||
]
|
||||
if part.strip()
|
||||
)
|
||||
|
||||
|
||||
class IndexingService:
|
||||
def __init__(
|
||||
self,
|
||||
orm: ORM,
|
||||
embedder: LocalEmbeddingService,
|
||||
vector_store: ChromaVectorStore,
|
||||
) -> None:
|
||||
self.orm = orm
|
||||
self.embedder = embedder
|
||||
self.vector_store = vector_store
|
||||
|
||||
async def get_indexable_chunks_count(
|
||||
self,
|
||||
source_ids: list[int] | None = None,
|
||||
law_types: list[str] | None = None,
|
||||
) -> int:
|
||||
chunks = await self.orm.list_chunks_for_indexing(
|
||||
source_ids=source_ids,
|
||||
law_types=law_types,
|
||||
active_only=True,
|
||||
)
|
||||
return len(chunks)
|
||||
|
||||
async def rebuild(
|
||||
self,
|
||||
source_ids: list[int] | None = None,
|
||||
law_types: list[str] | None = None,
|
||||
reset_collection: bool = True,
|
||||
batch_size: int | None = None,
|
||||
) -> dict:
|
||||
logger.info(
|
||||
"Index rebuild started: reset_collection=%s source_ids=%s law_types=%s",
|
||||
reset_collection,
|
||||
source_ids,
|
||||
law_types,
|
||||
)
|
||||
chunks = await self.orm.list_chunks_for_indexing(
|
||||
source_ids=source_ids,
|
||||
law_types=law_types,
|
||||
active_only=True,
|
||||
)
|
||||
logger.info("Loaded %s chunks from Postgres for indexing", len(chunks))
|
||||
if reset_collection:
|
||||
await run_in_threadpool(self.vector_store.reset_collection)
|
||||
|
||||
batch_size = batch_size or settings.index_batch_size
|
||||
indexed_chunks = 0
|
||||
indexed_sources = len({chunk["source_id"] for chunk in chunks})
|
||||
|
||||
for start in range(0, len(chunks), batch_size):
|
||||
batch = chunks[start : start + batch_size]
|
||||
batch_number = (start // batch_size) + 1
|
||||
total_batches = max(1, (len(chunks) + batch_size - 1) // batch_size)
|
||||
logger.info(
|
||||
"Indexing batch %s/%s with %s chunks",
|
||||
batch_number,
|
||||
total_batches,
|
||||
len(batch),
|
||||
)
|
||||
embeddings = await run_in_threadpool(
|
||||
self.embedder.encode_documents,
|
||||
[build_embedding_text(chunk) for chunk in batch],
|
||||
)
|
||||
await run_in_threadpool(
|
||||
self.vector_store.upsert,
|
||||
[str(chunk["chunk_id"]) for chunk in batch],
|
||||
[chunk["chunk_text"] for chunk in batch],
|
||||
embeddings,
|
||||
[
|
||||
{
|
||||
"chunk_id": chunk["chunk_id"],
|
||||
"source_id": chunk["source_id"],
|
||||
"source_title": chunk["source_title"],
|
||||
"source_url": chunk["source_url"],
|
||||
"law_type": chunk["law_type"] or "",
|
||||
"jurisdiction": chunk["jurisdiction"],
|
||||
"article_number": chunk["article_number"] or "",
|
||||
"article_title": chunk["article_title"] or "",
|
||||
"version_hash": chunk["version_hash"],
|
||||
}
|
||||
for chunk in batch
|
||||
],
|
||||
)
|
||||
indexed_chunks += len(batch)
|
||||
logger.info(
|
||||
"Indexed %s/%s chunks into Chroma",
|
||||
indexed_chunks,
|
||||
len(chunks),
|
||||
)
|
||||
|
||||
result = {
|
||||
"indexed_chunks": indexed_chunks,
|
||||
"indexed_sources": indexed_sources,
|
||||
"collection_name": self.vector_store.collection_name,
|
||||
}
|
||||
logger.info("Index rebuild completed: %s", result)
|
||||
return result
|
||||
Reference in New Issue
Block a user