first commit
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""API clients."""
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
|
||||
from api.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaVectorStore:
|
||||
def __init__(self) -> None:
|
||||
self._client = chromadb.HttpClient(
|
||||
host=settings.chroma_host,
|
||||
port=settings.chroma_port,
|
||||
ssl=settings.chroma_ssl,
|
||||
settings=ChromaSettings(anonymized_telemetry=False),
|
||||
)
|
||||
self._collection_name = settings.chroma_collection
|
||||
logger.info(
|
||||
"Chroma client configured: host=%s port=%s collection=%s ssl=%s",
|
||||
settings.chroma_host,
|
||||
settings.chroma_port,
|
||||
self._collection_name,
|
||||
settings.chroma_ssl,
|
||||
)
|
||||
|
||||
@property
|
||||
def collection_name(self) -> str:
|
||||
return self._collection_name
|
||||
|
||||
def get_collection(self):
|
||||
return self._client.get_or_create_collection(
|
||||
name=self._collection_name,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
|
||||
def count(self) -> int:
|
||||
return int(self.get_collection().count())
|
||||
|
||||
def reset_collection(self) -> None:
|
||||
logger.warning("Resetting Chroma collection: %s", self._collection_name)
|
||||
try:
|
||||
self._client.delete_collection(self._collection_name)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Could not delete Chroma collection before reset, continuing with create_collection"
|
||||
)
|
||||
self.get_collection()
|
||||
|
||||
def upsert(
|
||||
self,
|
||||
ids: list[str],
|
||||
documents: list[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Upserting %s embeddings into Chroma collection=%s",
|
||||
len(ids),
|
||||
self._collection_name,
|
||||
)
|
||||
self.get_collection().upsert(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
query_embeddings: list[list[float]],
|
||||
n_results: int,
|
||||
) -> dict[str, Any]:
|
||||
logger.info(
|
||||
"Running vector query against collection=%s n_results=%s",
|
||||
self._collection_name,
|
||||
n_results,
|
||||
)
|
||||
return self.get_collection().query(
|
||||
query_embeddings=query_embeddings,
|
||||
n_results=n_results,
|
||||
include=["distances", "metadatas", "documents"],
|
||||
)
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from api.config import settings
|
||||
|
||||
|
||||
def build_openai_client() -> AsyncOpenAI:
|
||||
return AsyncOpenAI(
|
||||
api_key=settings.openai_api_key,
|
||||
base_url=settings.openai_base_url,
|
||||
timeout=settings.llm_timeout_seconds,
|
||||
)
|
||||
Reference in New Issue
Block a user