ff-aitoolkit 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aitoolkit/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ """aitoolkit — centralized AI clients for self-hosted OpenAI-compatible services.
2
+
3
+ Core capabilities (LLM, embeddings, STT, TTS) are always available. RAG and the
4
+ LangChain bridge live behind extras and are imported from their own subpackages
5
+ (``aitoolkit.rag``, ``aitoolkit.integrations.langchain``) so that importing the
6
+ top-level package never forces an optional dependency.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from aitoolkit.config import AIToolkitSettings, configure, get_settings
12
+ from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
13
+ from aitoolkit.exceptions import (
14
+ AIToolkitError,
15
+ ConfigurationError,
16
+ EmbeddingsError,
17
+ LLMError,
18
+ STTError,
19
+ TTSError,
20
+ VectorStoreError,
21
+ )
22
+ from aitoolkit.llm import LLMClient, get_llm_client
23
+ from aitoolkit.stt import STTClient, get_stt_client
24
+ from aitoolkit.tts import TTSClient, concat_wav, get_tts_client
25
+ from aitoolkit.types import (
26
+ ChatMessage,
27
+ DialogueTurn,
28
+ RetrievedChunk,
29
+ TranscriptionResult,
30
+ )
31
+
32
+ __version__ = "0.2.0"
33
+
34
+ __all__ = [
35
+ "__version__",
36
+ # config
37
+ "AIToolkitSettings",
38
+ "configure",
39
+ "get_settings",
40
+ # llm
41
+ "LLMClient",
42
+ "get_llm_client",
43
+ # embeddings
44
+ "EmbeddingsClient",
45
+ "get_embeddings_client",
46
+ # stt
47
+ "STTClient",
48
+ "get_stt_client",
49
+ # tts
50
+ "TTSClient",
51
+ "get_tts_client",
52
+ "concat_wav",
53
+ # types
54
+ "ChatMessage",
55
+ "DialogueTurn",
56
+ "RetrievedChunk",
57
+ "TranscriptionResult",
58
+ # exceptions
59
+ "AIToolkitError",
60
+ "ConfigurationError",
61
+ "LLMError",
62
+ "EmbeddingsError",
63
+ "STTError",
64
+ "TTSError",
65
+ "VectorStoreError",
66
+ ]
aitoolkit/config.py ADDED
@@ -0,0 +1,107 @@
1
+ """Central configuration for aitoolkit.
2
+
3
+ All settings are read from ``AITOOLKIT_*`` environment variables (or a ``.env``
4
+ file) but every client also accepts explicit overrides, so the package can be
5
+ used with zero environment configuration.
6
+
7
+ Defaults intentionally point at ``localhost`` — they are NOT specific to any one
8
+ deployment. Production endpoints are supplied by the consuming application via
9
+ environment variables (see ``.env.example``).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from functools import lru_cache
15
+ from typing import Optional
16
+
17
+ from pydantic import Field
18
+ from pydantic_settings import BaseSettings, SettingsConfigDict
19
+
20
+ # A placeholder key for OpenAI-compatible servers that perform no app-layer auth
21
+ # (our GPU services are firewalled, not key-gated). The openai SDK requires a
22
+ # non-empty key, so we provide one.
23
+ _NO_AUTH = "no-auth"
24
+
25
+
26
class AIToolkitSettings(BaseSettings):
    """Runtime configuration for every aitoolkit capability.

    Values are read from ``AITOOLKIT_``-prefixed environment variables and
    from a ``.env`` file; unrecognised keys are ignored. Numeric knobs that
    would otherwise only fail at request time (batch size, timeouts, retry
    count) are range-checked when the settings object is built, so a bad
    value surfaces as a validation error at startup.
    """

    model_config = SettingsConfigDict(
        env_prefix="AITOOLKIT_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # --- LLM (vLLM, OpenAI-compatible) ---
    llm_base_url: str = Field(default="http://localhost:8000/v1")
    llm_api_key: str = Field(default=_NO_AUTH)
    # No default model id — set AITOOLKIT_LLM_MODEL to your served model.
    llm_model: str = Field(default="")
    llm_temperature: float = Field(default=0.2)
    # Request timeout in seconds; must be positive.
    llm_timeout: float = Field(default=60.0, gt=0)
    # Zero disables retries; a negative count is meaningless.
    llm_max_retries: int = Field(default=2, ge=0)

    # --- Embeddings (TEI, OpenAI-compatible) ---
    embeddings_base_url: str = Field(default="http://localhost:8001/v1")
    embeddings_api_key: str = Field(default=_NO_AUTH)
    # No default model id — set AITOOLKIT_EMBEDDINGS_MODEL to your served model.
    embeddings_model: str = Field(default="")
    # TEI accepts modest batches; keep conservative and configurable.
    # Must be at least 1: the client slices its input in steps of this size.
    embeddings_batch_size: int = Field(default=32, ge=1)
    embeddings_timeout: float = Field(default=60.0, gt=0)

    # --- Speech-to-Text (faster-whisper, OpenAI-compatible) ---
    stt_base_url: str = Field(default="http://localhost:8003/v1")
    stt_api_key: str = Field(default=_NO_AUTH)
    stt_model: str = Field(default="whisper-1")
    stt_language: Optional[str] = Field(default=None)
    stt_timeout: float = Field(default=120.0, gt=0)

    # --- Text-to-Speech (custom /api/tts) ---
    tts_base_url: str = Field(default="http://localhost:8002")
    tts_default_voice: Optional[str] = Field(default=None)
    tts_timeout: float = Field(default=120.0, gt=0)

    # --- Vector store (Qdrant) ---
    qdrant_url: str = Field(default="http://localhost:6333")
    qdrant_collection: str = Field(default="documents")
    # Optional fixed vector size. When None, it is detected from the embedding model.
    qdrant_vector_size: Optional[int] = Field(default=None)
    # The qdrant-client compares its own version with the server's and emits a
    # UserWarning when they are too far apart. Self-hosted servers often lag the
    # client; set False to silence the check when the API surface we use is stable.
    qdrant_check_compatibility: bool = Field(default=True)

    # --- Retriever cache (Redis, optional) ---
    redis_url: Optional[str] = Field(default=None)
    cache_ttl: int = Field(default=3600)
79
+
80
+
81
+ _override: Optional[AIToolkitSettings] = None
82
+
83
+
84
@lru_cache(maxsize=1)
def _env_settings() -> AIToolkitSettings:
    """Build the environment-derived settings once and reuse them.

    The ``lru_cache`` wrapper means ``AIToolkitSettings()`` is constructed on
    the first call only; later calls return that same instance.
    """
    env_backed = AIToolkitSettings()
    return env_backed
88
+
89
+
90
def get_settings() -> AIToolkitSettings:
    """Return the settings object currently in effect for this process.

    An instance installed through :func:`configure` wins; when none has been
    installed, the cached environment-derived settings are returned instead.
    """
    if _override is None:
        return _env_settings()
    return _override
97
+
98
+
99
def configure(settings: AIToolkitSettings) -> None:
    """Make ``settings`` the process-wide configuration.

    After this call :func:`get_settings` returns ``settings`` rather than the
    environment-derived defaults. Intended to be called once at application
    startup, before any client is created, so that the application owns its
    configuration explicitly.
    """
    # Rebind the module-level slot consulted by get_settings().
    global _override
    _override = settings
@@ -0,0 +1,5 @@
1
+ """Embeddings capability — OpenAI-compatible embeddings (e.g. TEI)."""
2
+
3
+ from aitoolkit.embeddings.client import EmbeddingsClient, get_embeddings_client
4
+
5
+ __all__ = ["EmbeddingsClient", "get_embeddings_client"]
@@ -0,0 +1,133 @@
1
+ """Embeddings client backed by an OpenAI-compatible embeddings server (TEI).
2
+
3
+ Async-first, with synchronous convenience wrappers so it can also satisfy the
4
+ LangChain ``Embeddings`` interface. The embedding dimension is detected at
5
+ runtime — never hardcoded — so swapping the underlying model requires no code
6
+ change here (only a re-index in the vector store).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from functools import lru_cache
12
+ from typing import List, Optional
13
+
14
+ from loguru import logger
15
+ from openai import AsyncOpenAI, OpenAI
16
+
17
+ from aitoolkit.config import AIToolkitSettings, get_settings
18
+ from aitoolkit.exceptions import EmbeddingsError
19
+
20
+
21
+ class EmbeddingsClient:
22
+ """Create embeddings via an OpenAI-compatible ``/v1/embeddings`` endpoint."""
23
+
24
+ def __init__(
25
+ self,
26
+ base_url: Optional[str] = None,
27
+ api_key: Optional[str] = None,
28
+ model: Optional[str] = None,
29
+ batch_size: Optional[int] = None,
30
+ timeout: Optional[float] = None,
31
+ settings: Optional[AIToolkitSettings] = None,
32
+ ) -> None:
33
+ settings = settings or get_settings()
34
+ self.model = model or settings.embeddings_model
35
+ self.batch_size = batch_size or settings.embeddings_batch_size
36
+ self._base_url = base_url or settings.embeddings_base_url
37
+ self._api_key = api_key or settings.embeddings_api_key
38
+ self._timeout = timeout if timeout is not None else settings.embeddings_timeout
39
+
40
+ self._aclient = AsyncOpenAI(
41
+ base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
42
+ )
43
+ self._sclient: Optional[OpenAI] = None
44
+ self._dimension: Optional[int] = None
45
+ logger.info(
46
+ f"EmbeddingsClient ready (model={self.model}, base_url={self._base_url})"
47
+ )
48
+
49
+ @property
50
+ def sync_client(self) -> OpenAI:
51
+ if self._sclient is None:
52
+ self._sclient = OpenAI(
53
+ base_url=self._base_url, api_key=self._api_key, timeout=self._timeout
54
+ )
55
+ return self._sclient
56
+
57
+ @property
58
+ def dimension(self) -> Optional[int]:
59
+ """The embedding dimension, known after the first call."""
60
+ return self._dimension
61
+
62
+ # --------------------------------------------------------------- helpers
63
+ @staticmethod
64
+ def _batched(items: List[str], size: int):
65
+ for i in range(0, len(items), size):
66
+ yield items[i : i + size]
67
+
68
+ def _record_dim(self, vectors: List[List[float]]) -> None:
69
+ if vectors and self._dimension is None:
70
+ self._dimension = len(vectors[0])
71
+ logger.debug(f"Detected embedding dimension: {self._dimension}")
72
+
73
+ # ----------------------------------------------------------------- async
74
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
75
+ """Embed many documents, batching to respect server limits."""
76
+ if not texts:
77
+ return []
78
+ out: List[List[float]] = []
79
+ for batch in self._batched(texts, self.batch_size):
80
+ try:
81
+ resp = await self._aclient.embeddings.create(
82
+ model=self.model, input=batch
83
+ )
84
+ except Exception as exc: # noqa: BLE001
85
+ raise EmbeddingsError(f"embedding request failed: {exc}") from exc
86
+ # Preserve input order via the returned index.
87
+ ordered = sorted(resp.data, key=lambda d: d.index)
88
+ out.extend([d.embedding for d in ordered])
89
+ self._record_dim(out)
90
+ return out
91
+
92
+ async def aembed_query(self, text: str) -> List[float]:
93
+ """Embed a single query string."""
94
+ vectors = await self.aembed_documents([text])
95
+ if not vectors:
96
+ raise EmbeddingsError("empty embedding response for query")
97
+ return vectors[0]
98
+
99
+ # ------------------------------------------------------------------ sync
100
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
101
+ """Synchronous document embedding (LangChain ``Embeddings`` interface)."""
102
+ if not texts:
103
+ return []
104
+ out: List[List[float]] = []
105
+ for batch in self._batched(texts, self.batch_size):
106
+ try:
107
+ resp = self.sync_client.embeddings.create(
108
+ model=self.model, input=batch
109
+ )
110
+ except Exception as exc: # noqa: BLE001
111
+ raise EmbeddingsError(f"embedding request failed: {exc}") from exc
112
+ ordered = sorted(resp.data, key=lambda d: d.index)
113
+ out.extend([d.embedding for d in ordered])
114
+ self._record_dim(out)
115
+ return out
116
+
117
+ def embed_query(self, text: str) -> List[float]:
118
+ """Synchronous single-query embedding."""
119
+ vectors = self.embed_documents([text])
120
+ if not vectors:
121
+ raise EmbeddingsError("empty embedding response for query")
122
+ return vectors[0]
123
+
124
+ async def aclose(self) -> None:
125
+ await self._aclient.close()
126
+ if self._sclient is not None:
127
+ self._sclient.close()
128
+
129
+
130
@lru_cache(maxsize=1)
def get_embeddings_client() -> EmbeddingsClient:
    """Return the shared :class:`EmbeddingsClient` for this process.

    The client is built with default arguments on the first call and the
    cached instance is handed back on every later call.
    """
    shared = EmbeddingsClient()
    return shared
@@ -0,0 +1,35 @@
1
+ """Exception hierarchy for aitoolkit.
2
+
3
+ A single base class lets callers catch every toolkit-originated error with one
4
+ ``except AIToolkitError``, while specific subclasses allow fine-grained handling.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+
10
class AIToolkitError(Exception):
    """Root of the aitoolkit exception hierarchy.

    Catching this type catches every error the toolkit raises itself.
    """
12
+
13
+
14
class ConfigurationError(AIToolkitError):
    """Signals that required configuration is absent or has an invalid value."""
16
+
17
+
18
class LLMError(AIToolkitError):
    """Signals a failed LLM request or a response that cannot be used."""
20
+
21
+
22
class EmbeddingsError(AIToolkitError):
    """Signals that an embeddings request did not succeed."""
24
+
25
+
26
class STTError(AIToolkitError):
    """Signals that a speech-to-text transcription did not succeed."""
28
+
29
+
30
class TTSError(AIToolkitError):
    """Signals that a text-to-speech synthesis did not succeed."""
32
+
33
+
34
class VectorStoreError(AIToolkitError):
    """Signals that an operation against the vector store did not succeed."""
@@ -0,0 +1 @@
1
+ """Optional third-party integrations (each gated behind its own extra)."""
@@ -0,0 +1,69 @@
1
+ """LangChain bridge (optional extra ``aitoolkit[langchain]``).
2
+
3
+ Provides:
4
+
5
+ * :func:`to_chat_model` — a LangChain ``BaseChatModel`` pointed at the toolkit's
6
+ LLM endpoint, so existing LangGraph graphs keep working unchanged.
7
+ * :class:`LangChainEmbeddings` — wraps :class:`EmbeddingsClient` as a LangChain
8
+ ``Embeddings`` so vector stores / chains can consume toolkit embeddings.
9
+
10
+ Only this module imports LangChain; the toolkit core stays LangChain-free.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import List, Optional
16
+
17
+ try: # pragma: no cover - import guard
18
+ from langchain_core.embeddings import Embeddings
19
+ from langchain_openai import ChatOpenAI
20
+ except ImportError as exc: # pragma: no cover
21
+ raise ImportError(
22
+ "LangChain integration requires extra deps. "
23
+ "Install with: pip install 'aitoolkit[langchain]'"
24
+ ) from exc
25
+
26
+ from aitoolkit.config import get_settings
27
+ from aitoolkit.embeddings import EmbeddingsClient, get_embeddings_client
28
+
29
+
30
def to_chat_model(
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    streaming: bool = False,
    **kwargs,
) -> ChatOpenAI:
    """Return a LangChain ``ChatOpenAI`` bound to the toolkit's LLM endpoint.

    Args:
        model: Model id; falls back to ``settings.llm_model`` when falsy.
        temperature: Sampling temperature; falls back to
            ``settings.llm_temperature`` when ``None``.
        streaming: Forwarded to ``ChatOpenAI`` unchanged.
        **kwargs: Extra ``ChatOpenAI`` arguments. ``base_url``, ``api_key``,
            ``timeout`` and ``max_retries`` given here override the values
            taken from the toolkit settings.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    settings = get_settings()
    # Settings-derived defaults go in a dict so caller kwargs can replace them;
    # passing both as keywords would raise "got multiple values for keyword
    # argument".
    options = {
        "base_url": settings.llm_base_url,
        "api_key": settings.llm_api_key,
        "timeout": settings.llm_timeout,
        "max_retries": settings.llm_max_retries,
    }
    options.update(kwargs)
    return ChatOpenAI(
        model=model or settings.llm_model,
        temperature=temperature if temperature is not None else settings.llm_temperature,
        streaming=streaming,
        **options,
    )
48
+
49
+
50
class LangChainEmbeddings(Embeddings):
    """LangChain ``Embeddings`` implementation that delegates to an :class:`EmbeddingsClient`."""

    def __init__(self, client: Optional[EmbeddingsClient] = None) -> None:
        """Wrap ``client``, or the shared toolkit client when none is given."""
        if not client:
            client = get_embeddings_client()
        self._client = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` through the wrapped client's sync API."""
        vectors = self._client.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed one query string through the wrapped client's sync API."""
        vector = self._client.embed_query(text)
        return vector

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` through the wrapped client's async API."""
        vectors = await self._client.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Embed one query string through the wrapped client's async API."""
        vector = await self._client.aembed_query(text)
        return vector
67
+
68
+
69
+ __all__ = ["to_chat_model", "LangChainEmbeddings"]
@@ -0,0 +1,5 @@
1
+ """LLM capability — OpenAI-compatible chat, streaming, and structured output."""
2
+
3
+ from aitoolkit.llm.client import LLMClient, get_llm_client
4
+
5
+ __all__ = ["LLMClient", "get_llm_client"]
@@ -0,0 +1,230 @@
1
+ """LLM client backed by any OpenAI-compatible server (e.g. vLLM).
2
+
3
+ Provides three primitives, all provider-agnostic:
4
+
5
+ * :meth:`LLMClient.chat` — single completion, returns text
6
+ * :meth:`LLMClient.stream` — async token stream
7
+ * :meth:`LLMClient.chat_structured` — completion validated into a Pydantic model
8
+
9
+ The client returns plain strings / Pydantic models and never leaks the underlying
10
+ ``openai`` SDK types to callers.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ from functools import lru_cache
17
+ from typing import AsyncIterator, List, Optional, Type, TypeVar
18
+
19
+ from loguru import logger
20
+ from openai import AsyncOpenAI, OpenAI
21
+ from pydantic import BaseModel, ValidationError
22
+
23
+ from aitoolkit.config import AIToolkitSettings, get_settings
24
+ from aitoolkit.exceptions import LLMError
25
+ from aitoolkit.types import ChatMessage, as_messages
26
+
27
+ T = TypeVar("T", bound=BaseModel)
28
+
29
+
30
class LLMClient:
    """A thin, stable wrapper over an OpenAI-compatible chat endpoint.

    Explicit constructor arguments take precedence over the values found in
    ``settings`` (or the process-wide settings when ``settings`` is omitted).
    Every request failure, and every response without a usable choice, is
    reported as :class:`LLMError`.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        settings: Optional[AIToolkitSettings] = None,
    ) -> None:
        """Resolve configuration and build the async HTTP client."""
        settings = settings or get_settings()
        self.model = model or settings.llm_model
        # ``is not None`` checks keep explicit zeros (temperature=0.0,
        # max_retries=0) from being replaced by the configured defaults.
        self.default_temperature = (
            temperature if temperature is not None else settings.llm_temperature
        )
        self._base_url = base_url or settings.llm_base_url
        self._api_key = api_key or settings.llm_api_key
        self._timeout = timeout if timeout is not None else settings.llm_timeout
        self._max_retries = (
            max_retries if max_retries is not None else settings.llm_max_retries
        )

        self._aclient = AsyncOpenAI(
            base_url=self._base_url,
            api_key=self._api_key,
            timeout=self._timeout,
            max_retries=self._max_retries,
        )
        self._sclient: Optional[OpenAI] = None  # created lazily for sync calls
        logger.info(f"LLMClient ready (model={self.model}, base_url={self._base_url})")

    @property
    def sync_client(self) -> OpenAI:
        """Lazily-created synchronous client (for non-async call sites)."""
        if self._sclient is None:
            self._sclient = OpenAI(
                base_url=self._base_url,
                api_key=self._api_key,
                timeout=self._timeout,
                max_retries=self._max_retries,
            )
        return self._sclient

    # ------------------------------------------------------------------ chat
    async def chat(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Return a single completion as text.

        Extra ``kwargs`` are forwarded to the completions endpoint.

        Raises:
            LLMError: if the request fails or the response has no choices.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            resp = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001 - surface as toolkit error
            raise LLMError(f"chat completion failed: {exc}") from exc

        return self._first_content(resp)

    def chat_sync(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Synchronous counterpart of :meth:`chat`.

        Raises:
            LLMError: if the request fails or the response has no choices.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            resp = self.sync_client.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"chat completion failed: {exc}") from exc
        return self._first_content(resp)

    # ---------------------------------------------------------------- stream
    async def stream(
        self,
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """Yield completion text deltas as they arrive.

        Chunks that carry no choices or an empty delta are skipped.

        Raises:
            LLMError: if opening or reading the stream fails.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        try:
            stream = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                max_tokens=max_tokens,
                stream=True,
                **kwargs,
            )
            async for chunk in stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"streaming completion failed: {exc}") from exc

    # ------------------------------------------------------------ structured
    async def chat_structured(
        self,
        response_model: Type[T],
        prompt: Optional[str] = None,
        *,
        system: Optional[str] = None,
        messages: Optional[List[ChatMessage]] = None,
        temperature: Optional[float] = None,
        strict: bool = False,
        **kwargs,
    ) -> T:
        """Return a completion validated into ``response_model``.

        Uses the OpenAI ``response_format`` json_schema mechanism, which vLLM
        implements via guided decoding. The raw JSON is validated with Pydantic,
        so a malformed response raises :class:`LLMError` rather than passing
        silently.

        Raises:
            LLMError: if the request fails, the response has no choices, or the
                content does not validate against ``response_model``.
        """
        msgs = as_messages(prompt, system=system, messages=messages)
        schema = response_model.model_json_schema()
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": response_model.__name__,
                "schema": schema,
                "strict": strict,
            },
        }
        try:
            resp = await self._aclient.chat.completions.create(
                model=self.model,
                messages=msgs,  # type: ignore[arg-type]
                temperature=self._temp(temperature),
                response_format=response_format,  # type: ignore[arg-type]
                **kwargs,
            )
        except Exception as exc:  # noqa: BLE001
            raise LLMError(f"structured completion request failed: {exc}") from exc

        content = self._first_content(resp)
        try:
            return response_model.model_validate_json(content)
        except (ValidationError, json.JSONDecodeError) as exc:
            # Include a bounded slice of the raw text to aid debugging.
            raise LLMError(
                f"structured output did not match {response_model.__name__}: {exc}\n"
                f"raw content: {content[:500]}"
            ) from exc

    # ----------------------------------------------------------------- utils
    @staticmethod
    def _first_content(resp) -> str:
        """Return the first choice's message text, or ``""`` when it is empty.

        Raises:
            LLMError: if the response carries no choices. Indexing an empty
                list would otherwise leak a bare ``IndexError`` to callers.
        """
        if not getattr(resp, "choices", None):
            raise LLMError("completion response contained no choices")
        return resp.choices[0].message.content or ""

    def _temp(self, temperature: Optional[float]) -> float:
        """Pick the per-call temperature, falling back to the client default."""
        return temperature if temperature is not None else self.default_temperature

    async def aclose(self) -> None:
        """Close the async client and, if it was ever created, the sync client."""
        await self._aclient.close()
        if self._sclient is not None:
            self._sclient.close()
214
+
215
+
216
@lru_cache(maxsize=8)
def _cached_client(model: Optional[str], temperature: Optional[float]) -> LLMClient:
    """Build an :class:`LLMClient` for one (model, temperature) pair.

    The ``lru_cache`` wrapper keeps up to eight distinct pairs alive and
    returns the existing client when the same pair is requested again.
    """
    client = LLMClient(model=model, temperature=temperature)
    return client
219
+
220
+
221
def get_llm_client(
    model: Optional[str] = None,
    temperature: Optional[float] = None,
) -> LLMClient:
    """Return a cached :class:`LLMClient` for the given model/temperature.

    Repeated calls with the same arguments reuse one client, which mirrors the
    previous ``llm.py`` behaviour and avoids re-creating HTTP clients on every
    request.
    """
    client = _cached_client(model, temperature)
    return client
aitoolkit/py.typed ADDED
File without changes
@@ -0,0 +1,25 @@
1
+ """RAG capability (optional extra ``aitoolkit[rag]``).
2
+
3
+ Imports require ``qdrant-client``. Install with ``pip install 'aitoolkit[rag]'``.
4
+ """
5
+
6
+ try: # pragma: no cover - import guard
7
+ import qdrant_client # noqa: F401
8
+ except ImportError as exc: # pragma: no cover
9
+ raise ImportError(
10
+ "The RAG module requires qdrant-client. Install with: pip install 'aitoolkit[rag]'"
11
+ ) from exc
12
+
13
+ from aitoolkit.rag.agent import RAGAgent, get_rag_agent
14
+ from aitoolkit.rag.query_expansion import QueryExpander, get_query_expander
15
+ from aitoolkit.rag.retriever import RAGRetriever
16
+ from aitoolkit.rag.vector_store import UnifiedVectorStore
17
+
18
+ __all__ = [
19
+ "RAGAgent",
20
+ "get_rag_agent",
21
+ "RAGRetriever",
22
+ "UnifiedVectorStore",
23
+ "QueryExpander",
24
+ "get_query_expander",
25
+ ]