ff-aitoolkit 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aitoolkit/__init__.py +66 -0
- aitoolkit/config.py +107 -0
- aitoolkit/embeddings/__init__.py +5 -0
- aitoolkit/embeddings/client.py +133 -0
- aitoolkit/exceptions.py +35 -0
- aitoolkit/integrations/__init__.py +1 -0
- aitoolkit/integrations/langchain.py +69 -0
- aitoolkit/llm/__init__.py +5 -0
- aitoolkit/llm/client.py +230 -0
- aitoolkit/py.typed +0 -0
- aitoolkit/rag/__init__.py +25 -0
- aitoolkit/rag/agent.py +165 -0
- aitoolkit/rag/query_expansion.py +147 -0
- aitoolkit/rag/retriever.py +141 -0
- aitoolkit/rag/vector_store.py +245 -0
- aitoolkit/retry.py +51 -0
- aitoolkit/stt/__init__.py +5 -0
- aitoolkit/stt/client.py +147 -0
- aitoolkit/tts/__init__.py +10 -0
- aitoolkit/tts/audio.py +68 -0
- aitoolkit/tts/client.py +219 -0
- aitoolkit/types.py +66 -0
- ff_aitoolkit-0.2.0.dist-info/METADATA +159 -0
- ff_aitoolkit-0.2.0.dist-info/RECORD +26 -0
- ff_aitoolkit-0.2.0.dist-info/WHEEL +4 -0
- ff_aitoolkit-0.2.0.dist-info/licenses/LICENSE +21 -0
aitoolkit/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""aitoolkit — centralized AI clients for self-hosted OpenAI-compatible services.
|
|
2
|
+
|
|
3
|
+
Core capabilities (LLM, embeddings, STT, TTS) are always available. RAG and the
|
|
4
|
+
LangChain bridge live behind extras and are imported from their own subpackages
|
|
5
|
+
(``aitoolkit.rag``, ``aitoolkit.integrations.langchain``) so that importing the
|
|
6
|
+
top-level package never forces an optional dependency.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from aitoolkit.config import AIToolkitSettings, configure, get_settings
|
|
12
|
+
from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
|
|
13
|
+
from aitoolkit.exceptions import (
|
|
14
|
+
AIToolkitError,
|
|
15
|
+
ConfigurationError,
|
|
16
|
+
EmbeddingsError,
|
|
17
|
+
LLMError,
|
|
18
|
+
STTError,
|
|
19
|
+
TTSError,
|
|
20
|
+
VectorStoreError,
|
|
21
|
+
)
|
|
22
|
+
from aitoolkit.llm import LLMClient, get_llm_client
|
|
23
|
+
from aitoolkit.stt import STTClient, get_stt_client
|
|
24
|
+
from aitoolkit.tts import TTSClient, concat_wav, get_tts_client
|
|
25
|
+
from aitoolkit.types import (
|
|
26
|
+
ChatMessage,
|
|
27
|
+
DialogueTurn,
|
|
28
|
+
RetrievedChunk,
|
|
29
|
+
TranscriptionResult,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
__version__ = "0.2.0"
|
|
33
|
+
|
|
34
|
+
__all__ = [
|
|
35
|
+
"__version__",
|
|
36
|
+
# config
|
|
37
|
+
"AIToolkitSettings",
|
|
38
|
+
"configure",
|
|
39
|
+
"get_settings",
|
|
40
|
+
# llm
|
|
41
|
+
"LLMClient",
|
|
42
|
+
"get_llm_client",
|
|
43
|
+
# embeddings
|
|
44
|
+
"EmbeddingsClient",
|
|
45
|
+
"get_embeddings_client",
|
|
46
|
+
# stt
|
|
47
|
+
"STTClient",
|
|
48
|
+
"get_stt_client",
|
|
49
|
+
# tts
|
|
50
|
+
"TTSClient",
|
|
51
|
+
"get_tts_client",
|
|
52
|
+
"concat_wav",
|
|
53
|
+
# types
|
|
54
|
+
"ChatMessage",
|
|
55
|
+
"DialogueTurn",
|
|
56
|
+
"RetrievedChunk",
|
|
57
|
+
"TranscriptionResult",
|
|
58
|
+
# exceptions
|
|
59
|
+
"AIToolkitError",
|
|
60
|
+
"ConfigurationError",
|
|
61
|
+
"LLMError",
|
|
62
|
+
"EmbeddingsError",
|
|
63
|
+
"STTError",
|
|
64
|
+
"TTSError",
|
|
65
|
+
"VectorStoreError",
|
|
66
|
+
]
|
aitoolkit/config.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Central configuration for aitoolkit.
|
|
2
|
+
|
|
3
|
+
All settings are read from ``AITOOLKIT_*`` environment variables (or a ``.env``
|
|
4
|
+
file) but every client also accepts explicit overrides, so the package can be
|
|
5
|
+
used with zero environment configuration.
|
|
6
|
+
|
|
7
|
+
Defaults intentionally point at ``localhost`` — they are NOT specific to any one
|
|
8
|
+
deployment. Production endpoints are supplied by the consuming application via
|
|
9
|
+
environment variables (see ``.env.example``).
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from functools import lru_cache
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
19
|
+
|
|
20
|
+
# A placeholder key for OpenAI-compatible servers that perform no app-layer auth
|
|
21
|
+
# (our GPU services are firewalled, not key-gated). The openai SDK requires a
|
|
22
|
+
# non-empty key, so we provide one.
|
|
23
|
+
_NO_AUTH = "no-auth"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AIToolkitSettings(BaseSettings):
    """Settings model holding endpoint, model and tuning values per capability.

    Fields are populated by pydantic-settings using the ``AITOOLKIT_`` prefix
    (environment variables, or a ``.env`` file read as UTF-8); keys that match
    no field are ignored. Every field has a default, so the model can be built
    with no configuration at all.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file_encoding="utf-8",
        env_file=".env",
        env_prefix="AITOOLKIT_",
    )

    # ---- LLM: vLLM server speaking the OpenAI-compatible API ----
    llm_base_url: str = "http://localhost:8000/v1"
    llm_api_key: str = _NO_AUTH
    # Deliberately empty: the served model id comes from AITOOLKIT_LLM_MODEL.
    llm_model: str = ""
    llm_temperature: float = 0.2
    llm_timeout: float = 60.0
    llm_max_retries: int = 2

    # ---- Embeddings: TEI server speaking the OpenAI-compatible API ----
    embeddings_base_url: str = "http://localhost:8001/v1"
    embeddings_api_key: str = _NO_AUTH
    # Deliberately empty: the served model id comes from
    # AITOOLKIT_EMBEDDINGS_MODEL.
    embeddings_model: str = ""
    # TEI only takes modest batches, so the default is conservative.
    embeddings_batch_size: int = 32
    embeddings_timeout: float = 60.0

    # ---- Speech-to-text: faster-whisper, OpenAI-compatible ----
    stt_base_url: str = "http://localhost:8003/v1"
    stt_api_key: str = _NO_AUTH
    stt_model: str = "whisper-1"
    stt_language: Optional[str] = None
    stt_timeout: float = 120.0

    # ---- Text-to-speech: custom /api/tts service ----
    tts_base_url: str = "http://localhost:8002"
    tts_default_voice: Optional[str] = None
    tts_timeout: float = 120.0

    # ---- Vector store: Qdrant ----
    qdrant_url: str = "http://localhost:6333"
    qdrant_collection: str = "documents"
    # Fixed vector size, if any; None means it is detected from the
    # embedding model.
    qdrant_vector_size: Optional[int] = None
    # qdrant-client emits a UserWarning and refuses a server whose minor
    # version is more than one away from its own. Self-hosted servers often
    # lag the client; set this to False to skip the check when the API
    # surface in use is stable.
    qdrant_check_compatibility: bool = True

    # ---- Retriever cache: Redis, optional ----
    redis_url: Optional[str] = None
    cache_ttl: int = 3600
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
_override: Optional[AIToolkitSettings] = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@lru_cache(maxsize=1)
def _env_settings() -> AIToolkitSettings:
    """Build the environment-backed fallback settings, once per process.

    The single-slot cache means the object is constructed on the first call
    and the same instance is returned afterwards.
    """
    env_backed = AIToolkitSettings()
    return env_backed
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_settings() -> AIToolkitSettings:
    """Return the settings object currently in effect for this process.

    An instance installed through :func:`configure` wins; when none has been
    installed, the cached environment-derived settings are returned instead.
    """
    if _override is None:
        return _env_settings()
    return _override
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def configure(settings: AIToolkitSettings) -> None:
    """Make ``settings`` the process-wide configuration.

    Intended for applications that want to own configuration themselves
    rather than depend on environment variables and the package defaults.
    Call it once during startup, before any client has been created.
    """
    # Rebind the module-level slot that get_settings() consults first.
    global _override
    _override = settings
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""Embeddings client backed by an OpenAI-compatible embeddings server (TEI).
|
|
2
|
+
|
|
3
|
+
Async-first, with synchronous convenience wrappers so it can also satisfy the
|
|
4
|
+
LangChain ``Embeddings`` interface. The embedding dimension is detected at
|
|
5
|
+
runtime — never hardcoded — so swapping the underlying model requires no code
|
|
6
|
+
change here (only a re-index in the vector store).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from functools import lru_cache
|
|
12
|
+
from typing import List, Optional
|
|
13
|
+
|
|
14
|
+
from loguru import logger
|
|
15
|
+
from openai import AsyncOpenAI, OpenAI
|
|
16
|
+
|
|
17
|
+
from aitoolkit.config import AIToolkitSettings, get_settings
|
|
18
|
+
from aitoolkit.exceptions import EmbeddingsError
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class EmbeddingsClient:
    """Create embeddings via an OpenAI-compatible ``/v1/embeddings`` endpoint.

    Explicit constructor arguments take precedence over ``settings`` (or over
    the process-wide settings when ``settings`` is not given). The async SDK
    client is created in the constructor; the sync one on first use.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        batch_size: Optional[int] = None,
        timeout: Optional[float] = None,
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        """Resolve connection parameters and build the async client.

        Args:
            base_url: Endpoint base URL; falls back to the settings value.
            api_key: API key; falls back to the settings value.
            model: Model id sent with each request; falls back to settings.
            batch_size: Maximum number of texts per request; a falsy value
                (``None`` or ``0``) falls back to the settings value.
            timeout: Request timeout; ``None`` falls back to settings.
            settings: Settings to read fallbacks from; defaults to
                :func:`get_settings`.

        Raises:
            ConfigurationError: If the resolved batch size is not positive.
        """
        settings = settings or get_settings()
        self.model = model or settings.embeddings_model
        self.batch_size = batch_size or settings.embeddings_batch_size
        if self.batch_size < 1:
            # A non-positive size would yield no batches (or break range()),
            # so fail loudly at construction instead of returning no vectors.
            from aitoolkit.exceptions import ConfigurationError

            raise ConfigurationError(
                f"embeddings batch size must be a positive integer, "
                f"got {self.batch_size}"
            )
        self._base_url = base_url or settings.embeddings_base_url
        self._api_key = api_key or settings.embeddings_api_key
        self._timeout = timeout if timeout is not None else settings.embeddings_timeout

        self._aclient = AsyncOpenAI(
            base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
        )
        self._sclient: Optional[OpenAI] = None
        self._dimension: Optional[int] = None
        logger.info(
            f"EmbeddingsClient ready (model={self.model}, base_url={self._base_url})"
        )

    @property
    def sync_client(self) -> OpenAI:
        """Synchronous SDK client, created on first access and then reused."""
        if self._sclient is None:
            self._sclient = OpenAI(
                base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
            )
        return self._sclient

    @property
    def dimension(self) -> Optional[int]:
        """The embedding dimension, known after the first call."""
        return self._dimension

    # --------------------------------------------------------------- helpers
    @staticmethod
    def _batched(items: List[str], size: int):
        """Yield consecutive slices of ``items`` holding at most ``size`` entries.

        Raises:
            ValueError: If ``size`` is not positive (raised when iteration
                starts, because this is a generator).
        """
        if size < 1:
            raise ValueError(f"batch size must be a positive integer, got {size}")
        for start in range(0, len(items), size):
            yield items[start : start + size]

    @staticmethod
    def _vectors_in_order(resp, expected: int) -> List[List[float]]:
        """Return the response's embeddings sorted by their ``index`` field.

        Raises:
            EmbeddingsError: If the response does not carry exactly
                ``expected`` entries; accepting it would misalign vectors
                with the input texts.
        """
        if len(resp.data) != expected:
            raise EmbeddingsError(
                f"embedding response has {len(resp.data)} vectors "
                f"for {expected} inputs"
            )
        return [item.embedding for item in sorted(resp.data, key=lambda d: d.index)]

    def _record_dim(self, vectors: List[List[float]]) -> None:
        """Remember the vector length the first time any vectors are seen."""
        if vectors and self._dimension is None:
            self._dimension = len(vectors[0])
            logger.debug(f"Detected embedding dimension: {self._dimension}")

    # ----------------------------------------------------------------- async
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed many documents, batching to respect server limits.

        Returns one vector per input text, in input order; an empty input
        returns an empty list without contacting the server.

        Raises:
            EmbeddingsError: If a request fails or a response is malformed.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        for batch in self._batched(texts, self.batch_size):
            try:
                resp = await self._aclient.embeddings.create(
                    model=self.model, input=batch
                )
            except Exception as exc:  # noqa: BLE001
                raise EmbeddingsError(f"embedding request failed: {exc}") from exc
            out.extend(self._vectors_in_order(resp, len(batch)))
        self._record_dim(out)
        return out

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        vectors = await self.aembed_documents([text])
        if not vectors:
            raise EmbeddingsError("empty embedding response for query")
        return vectors[0]

    # ------------------------------------------------------------------ sync
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous document embedding (LangChain ``Embeddings`` interface).

        Same contract as :meth:`aembed_documents`, using the sync client.

        Raises:
            EmbeddingsError: If a request fails or a response is malformed.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        for batch in self._batched(texts, self.batch_size):
            try:
                resp = self.sync_client.embeddings.create(
                    model=self.model, input=batch
                )
            except Exception as exc:  # noqa: BLE001
                raise EmbeddingsError(f"embedding request failed: {exc}") from exc
            out.extend(self._vectors_in_order(resp, len(batch)))
        self._record_dim(out)
        return out

    def embed_query(self, text: str) -> List[float]:
        """Synchronous single-query embedding."""
        vectors = self.embed_documents([text])
        if not vectors:
            raise EmbeddingsError("empty embedding response for query")
        return vectors[0]

    async def aclose(self) -> None:
        """Close the async client and, if it was ever created, the sync one."""
        await self._aclient.close()
        if self._sclient is not None:
            self._sclient.close()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@lru_cache(maxsize=1)
def get_embeddings_client() -> EmbeddingsClient:
    """Return the shared, default-configured embeddings client.

    The client is built on the first call and the same instance is handed
    back on every later call.
    """
    shared = EmbeddingsClient()
    return shared
|
aitoolkit/exceptions.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Exception hierarchy for aitoolkit.
|
|
2
|
+
|
|
3
|
+
A single base class lets callers catch every toolkit-originated error with one
|
|
4
|
+
``except AIToolkitError``, while specific subclasses allow fine-grained handling.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AIToolkitError(Exception):
    """Root of the aitoolkit exception hierarchy; catch this to catch them all."""


class ConfigurationError(AIToolkitError):
    """Configuration is absent or does not validate."""


class LLMError(AIToolkitError):
    """An LLM request failed, or its response could not be used."""


class EmbeddingsError(AIToolkitError):
    """An embeddings request failed."""


class STTError(AIToolkitError):
    """Speech-to-text transcription failed."""


class TTSError(AIToolkitError):
    """Text-to-speech synthesis failed."""


class VectorStoreError(AIToolkitError):
    """A vector-store operation failed."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Optional third-party integrations (each gated behind its own extra)."""
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""LangChain bridge (optional extra ``aitoolkit[langchain]``).
|
|
2
|
+
|
|
3
|
+
Provides:
|
|
4
|
+
|
|
5
|
+
* :func:`to_chat_model` — a LangChain ``BaseChatModel`` pointed at the toolkit's
|
|
6
|
+
LLM endpoint, so existing LangGraph graphs keep working unchanged.
|
|
7
|
+
* :class:`LangChainEmbeddings` — wraps :class:`EmbeddingsClient` as a LangChain
|
|
8
|
+
``Embeddings`` so vector stores / chains can consume toolkit embeddings.
|
|
9
|
+
|
|
10
|
+
Only this module imports LangChain; the toolkit core stays LangChain-free.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import List, Optional
|
|
16
|
+
|
|
17
|
+
try:  # pragma: no cover - import guard
    from langchain_core.embeddings import Embeddings
    from langchain_openai import ChatOpenAI
except ImportError as exc:  # pragma: no cover
    # The distribution is published as ``ff-aitoolkit``; ``aitoolkit`` is only
    # the import name, so the install hint must name the distribution.
    raise ImportError(
        "LangChain integration requires extra deps. "
        "Install with: pip install 'ff-aitoolkit[langchain]'"
    ) from exc
|
|
25
|
+
|
|
26
|
+
from aitoolkit.config import get_settings
|
|
27
|
+
from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def to_chat_model(
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    streaming: bool = False,
    **kwargs,
) -> ChatOpenAI:
    """Return a LangChain ``ChatOpenAI`` bound to the toolkit's LLM endpoint.

    Args:
        model: Model id; a falsy value falls back to ``llm_model`` from the
            process-wide settings.
        temperature: Sampling temperature; ``None`` falls back to
            ``llm_temperature`` from the settings (``0.0`` is respected).
        streaming: Forwarded to ``ChatOpenAI`` unchanged.
        **kwargs: Extra ``ChatOpenAI`` keyword arguments. A key that matches
            one of the settings-derived arguments (``base_url``, ``api_key``,
            ``timeout``, ``max_retries``) overrides that default instead of
            raising a duplicate-keyword ``TypeError``.
    """
    settings = get_settings()
    params = {
        "base_url": settings.llm_base_url,
        "api_key": settings.llm_api_key,
        "model": model or settings.llm_model,
        "temperature": (
            temperature if temperature is not None else settings.llm_temperature
        ),
        "streaming": streaming,
        "timeout": settings.llm_timeout,
        "max_retries": settings.llm_max_retries,
    }
    # Caller-supplied keywords win over the settings-derived defaults.
    params.update(kwargs)
    return ChatOpenAI(**params)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LangChainEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter that delegates to an :class:`EmbeddingsClient`."""

    def __init__(self, client: Optional[EmbeddingsClient] = None) -> None:
        """Wrap ``client``, or the shared embeddings client when none is given."""
        self._client = client if client else get_embeddings_client()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` synchronously through the wrapped client."""
        vectors = self._client.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed one query string synchronously through the wrapped client."""
        vector = self._client.embed_query(text)
        return vector

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` asynchronously through the wrapped client."""
        vectors = await self._client.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Embed one query string asynchronously through the wrapped client."""
        vector = await self._client.aembed_query(text)
        return vector
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
__all__ = ["to_chat_model", "LangChainEmbeddings"]
|
aitoolkit/llm/client.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""LLM client backed by any OpenAI-compatible server (e.g. vLLM).
|
|
2
|
+
|
|
3
|
+
Provides three primitives, all provider-agnostic:
|
|
4
|
+
|
|
5
|
+
* :meth:`LLMClient.chat` — single completion, returns text
|
|
6
|
+
* :meth:`LLMClient.stream` — async token stream
|
|
7
|
+
* :meth:`LLMClient.chat_structured` — completion validated into a Pydantic model
|
|
8
|
+
|
|
9
|
+
The client returns plain strings / Pydantic models and never leaks the underlying
|
|
10
|
+
``openai`` SDK types to callers.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
from functools import lru_cache
|
|
17
|
+
from typing import AsyncIterator, List, Optional, Type, TypeVar
|
|
18
|
+
|
|
19
|
+
from loguru import logger
|
|
20
|
+
from openai import AsyncOpenAI, OpenAI
|
|
21
|
+
from pydantic import BaseModel, ValidationError
|
|
22
|
+
|
|
23
|
+
from aitoolkit.config import AIToolkitSettings, get_settings
|
|
24
|
+
from aitoolkit.exceptions import LLMError
|
|
25
|
+
from aitoolkit.types import ChatMessage, as_messages
|
|
26
|
+
|
|
27
|
+
T = TypeVar("T", bound=BaseModel)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class LLMClient:
    """A thin, stable wrapper over an OpenAI-compatible chat endpoint.

    Explicit constructor arguments take precedence over ``settings`` (or over
    the process-wide settings when ``settings`` is not given). Failures of
    the underlying SDK call, and responses without a usable choice, are
    raised as :class:`LLMError`.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        """Resolve connection parameters and build the async client.

        ``model``, ``base_url`` and ``api_key`` fall back to the settings when
        falsy; ``temperature``, ``timeout`` and ``max_retries`` fall back only
        when ``None``, so explicit zeros are respected.
        """
        settings = settings or get_settings()
        self.model = model or settings.llm_model
        self.default_temperature = (
            temperature if temperature is not None else settings.llm_temperature
        )
        self._base_url = base_url or settings.llm_base_url
        self._api_key = api_key or settings.llm_api_key
        self._timeout = timeout if timeout is not None else settings.llm_timeout
        self._max_retries = (
            max_retries if max_retries is not None else settings.llm_max_retries
        )

        self._aclient = AsyncOpenAI(
            base_url=self._base_url,
            api_key=self._api_key,
            timeout=self._timeout,
            max_retries=self._max_retries,
        )
        self._sclient: Optional[OpenAI] = None  # created lazily for sync calls
        logger.info(f"LLMClient ready (model={self.model}, base_url={self._base_url})")

    @property
    def sync_client(self) -> OpenAI:
        """Lazily-created synchronous client (for non-async call sites)."""
        if self._sclient is None:
            self._sclient = OpenAI(
                base_url=self._base_url,
                api_key=self._api_key,
                timeout=self._timeout,
                max_retries=self._max_retries,
            )
        return self._sclient

    # ------------------------------------------------------------------ chat
    async def chat(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Return a single completion as text.

        The message list is built by ``as_messages`` from ``prompt``,
        ``system`` and ``messages``; extra ``kwargs`` are forwarded to the
        SDK call. A ``None`` message content is returned as ``""``.

        Raises:
            LLMError: If the request fails or the response has no choices.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            resp = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001 - surface as toolkit error
            raise LLMError(f"chat completion failed: {exc}") from exc

        return self._first_content(resp)

    def chat_sync(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Synchronous counterpart of :meth:`chat`.

        Raises:
            LLMError: If the request fails or the response has no choices.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            resp = self.sync_client.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"chat completion failed: {exc}") from exc
        return self._first_content(resp)

    # ---------------------------------------------------------------- stream
    async def stream(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Yield completion text deltas as they arrive.

        Chunks without choices and empty deltas are skipped.

        Raises:
            LLMError: If opening or consuming the stream fails.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            chunks = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )
            async for chunk in chunks:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"streaming completion failed: {exc}") from exc

    # ------------------------------------------------------------ structured
    async def chat_structured(
        self,
        response_model: Type[T],
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
        """Return a completion validated into ``response_model``.

        Uses the OpenAI ``response_format`` json_schema mechanism, which vLLM
        implements via guided decoding. The raw JSON is validated with
        Pydantic, so a malformed response raises :class:`LLMError` rather than
        passing silently.

        Raises:
            LLMError: If the request fails, the response has no choices, or
                the content does not validate against ``response_model``.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": response_model.__name__,
                "schema": response_model.model_json_schema(),
                "strict": strict,
            },
        }
        try:
            resp = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                response_format=response_format,  # type: ignore[arg-type]
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"structured completion request failed: {exc}") from exc

        content = self._first_content(resp)
        try:
            return response_model.model_validate_json(content)
        except (ValidationError, json.JSONDecodeError) as exc:
            raise LLMError(
                f"structured output did not match {response_model.__name__}: {exc}\n"
                f"raw content: {content[:500]}"
            ) from exc

    # ----------------------------------------------------------------- utils
    @staticmethod
    def _first_content(resp) -> str:
        """Return the first choice's message text, or ``""`` when it is ``None``.

        Raises:
            LLMError: If the response carries no choices, so callers never
                see a bare ``IndexError``/``TypeError`` from the SDK object.
        """
        choices = getattr(resp, "choices", None)
        if not choices:
            raise LLMError("completion response contained no choices")
        return choices[0].message.content or ""

    def _temp(self, temperature: Optional[float]) -> float:
        """Pick the per-call temperature, else the client default."""
        return temperature if temperature is not None else self.default_temperature

    async def aclose(self) -> None:
        """Close the async client and, if it was ever created, the sync one."""
        await self._aclient.close()
        if self._sclient is not None:
            self._sclient.close()
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@lru_cache(maxsize=8)
def _cached_client(model: Optional[str], temperature: Optional[float]) -> LLMClient:
    """Build an :class:`LLMClient` for this (model, temperature) pair, memoized.

    Up to eight distinct argument pairs are kept; repeated calls with the
    same pair return the same client instance while it is still cached.
    """
    return LLMClient(temperature=temperature, model=model)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def get_llm_client(
    model: Optional[str] = None,
    temperature: Optional[float] = None,
) -> LLMClient:
    """Fetch the memoized :class:`LLMClient` for a model/temperature pair.

    Reusing clients matches what the earlier ``llm.py`` did and keeps HTTP
    clients from being rebuilt for every request.
    """
    # Positional forwarding keeps one cache key per (model, temperature) pair.
    client = _cached_client(model, temperature)
    return client
|
aitoolkit/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""RAG capability (optional extra ``aitoolkit[rag]``).
|
|
2
|
+
|
|
3
|
+
Imports require ``qdrant-client``. Install with ``pip install 'ff-aitoolkit[rag]'``.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
try:  # pragma: no cover - import guard
    import qdrant_client  # noqa: F401
except ImportError as exc:  # pragma: no cover
    # The distribution is published as ``ff-aitoolkit``; ``aitoolkit`` is only
    # the import name, so the install hint must name the distribution.
    raise ImportError(
        "The RAG module requires qdrant-client. "
        "Install with: pip install 'ff-aitoolkit[rag]'"
    ) from exc
|
|
12
|
+
|
|
13
|
+
from aitoolkit.rag.agent import RAGAgent, get_rag_agent
|
|
14
|
+
from aitoolkit.rag.query_expansion import QueryExpander, get_query_expander
|
|
15
|
+
from aitoolkit.rag.retriever import RAGRetriever
|
|
16
|
+
from aitoolkit.rag.vector_store import UnifiedVectorStore
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"RAGAgent",
|
|
20
|
+
"get_rag_agent",
|
|
21
|
+
"RAGRetriever",
|
|
22
|
+
"UnifiedVectorStore",
|
|
23
|
+
"QueryExpander",
|
|
24
|
+
"get_query_expander",
|
|
25
|
+
]
|