89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
import chromadb
|
|
from chromadb.config import Settings as ChromaSettings
|
|
|
|
from api.config import settings
|
|
|
|
|
|
# Module-scoped logger named after this module, shared by every method below.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ChromaVectorStore:
    """Wrapper around a single collection on a remote Chroma server.

    The HTTP client is configured once from ``api.config.settings``
    (``chroma_host``, ``chroma_port``, ``chroma_ssl``, ``chroma_collection``)
    with anonymized telemetry disabled. Every data operation resolves the
    collection through :meth:`get_collection`, which creates it on demand
    with ``hnsw:space`` set to ``"cosine"``.
    """

    def __init__(self) -> None:
        """Create the Chroma HTTP client and remember the target collection name."""
        self._client = chromadb.HttpClient(
            host=settings.chroma_host,
            port=settings.chroma_port,
            ssl=settings.chroma_ssl,
            settings=ChromaSettings(anonymized_telemetry=False),
        )
        self._collection_name = settings.chroma_collection
        logger.info(
            "Chroma client configured: host=%s port=%s collection=%s ssl=%s",
            settings.chroma_host,
            settings.chroma_port,
            self._collection_name,
            settings.chroma_ssl,
        )

    @property
    def collection_name(self) -> str:
        """Name of the Chroma collection this store reads from and writes to."""
        return self._collection_name

    def get_collection(self):
        """Return the configured collection, creating it if it does not exist.

        The handle is looked up on every call instead of being cached. With an
        HTTP client this presumably costs one extra server round-trip per data
        operation, but it also means a handle never outlives a
        :meth:`reset_collection` that replaced the collection.

        NOTE(review): ``metadata={"hnsw:space": "cosine"}`` selects cosine
        distance; confirm whether Chroma applies it to an already-existing
        collection or only at creation time.
        """
        return self._client.get_or_create_collection(
            name=self._collection_name,
            metadata={"hnsw:space": "cosine"},
        )

    def count(self) -> int:
        """Return the number of records currently stored in the collection."""
        return int(self.get_collection().count())

    def reset_collection(self) -> None:
        """Delete the collection (best-effort) and re-acquire it via get_collection.

        Any exception from ``delete_collection`` is logged with its traceback
        and otherwise ignored, then :meth:`get_collection` creates the
        collection if it is missing.

        NOTE(review): if deletion failed for a reason other than the collection
        being absent, ``get_collection`` hands back the existing collection
        with its data intact, so the store is not actually emptied. Only the
        logged exception reveals this.
        """
        logger.warning("Resetting Chroma collection: %s", self._collection_name)
        try:
            self._client.delete_collection(self._collection_name)
        except Exception:
            # Deliberately best-effort: presumably the usual failure is that the
            # collection does not exist yet, which get_or_create below handles.
            logger.exception(
                "Could not delete Chroma collection before reset, continuing with create_collection"
            )
        self.get_collection()

    def upsert(
        self,
        ids: list[str],
        documents: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict[str, Any]],
    ) -> None:
        """Insert or update records in the collection, keyed by ``ids``.

        Args:
            ids: Record identifiers.
            documents: Text for each record, parallel to ``ids``.
            embeddings: Vector for each record, parallel to ``ids``.
            metadatas: Metadata mapping for each record, parallel to ``ids``.

        An empty batch (no ``ids`` and no other values) is a no-op and returns
        without contacting the server. Chroma itself rejects an empty id list.

        Raises:
            ValueError: if ``documents``, ``embeddings`` or ``metadatas`` is
                given with a length different from ``ids``. This is checked
                before any request is sent.
        """
        expected = len(ids)
        mismatched = {
            name: len(values)
            for name, values in (
                ("documents", documents),
                ("embeddings", embeddings),
                ("metadatas", metadatas),
            )
            # Skip None so callers that pass None straight through keep working.
            if values is not None and len(values) != expected
        }
        if mismatched:
            raise ValueError(
                f"upsert got {expected} ids but mismatched lengths: {mismatched}"
            )
        if not ids:
            # Nothing to write: avoid the collection lookup and Chroma's
            # rejection of an empty id list.
            return
        logger.info(
            "Upserting %s embeddings into Chroma collection=%s",
            len(ids),
            self._collection_name,
        )
        self.get_collection().upsert(
            ids=ids,
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
        )

    def query(
        self,
        query_embeddings: list[list[float]],
        n_results: int,
    ) -> dict[str, Any]:
        """Run a nearest-neighbour query and return Chroma's result unchanged.

        Args:
            query_embeddings: One vector per query.
            n_results: Number of neighbours requested per query.

        Returns:
            The mapping returned by Chroma's ``Collection.query``, with
            ``include`` limited to distances, metadatas and documents.
        """
        logger.info(
            "Running vector query against collection=%s n_results=%s",
            self._collection_name,
            n_results,
        )
        return self.get_collection().query(
            query_embeddings=query_embeddings,
            n_results=n_results,
            include=["distances", "metadatas", "documents"],
        )
|