saia-python 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- saia_python/__init__.py +253 -0
- saia_python/_http.py +71 -0
- saia_python/_streaming.py +88 -0
- saia_python/_util.py +29 -0
- saia_python/arcana.py +1061 -0
- saia_python/arcana_references.py +182 -0
- saia_python/auth.py +515 -0
- saia_python/chat.py +72 -0
- saia_python/client.py +239 -0
- saia_python/documents.py +145 -0
- saia_python/exceptions.py +68 -0
- saia_python/models.py +146 -0
- saia_python/openai_compat.py +70 -0
- saia_python/py.typed +0 -0
- saia_python/rate_limits.py +84 -0
- saia_python/responses.py +70 -0
- saia_python/voice.py +175 -0
- saia_python-0.4.1.dist-info/METADATA +190 -0
- saia_python-0.4.1.dist-info/RECORD +22 -0
- saia_python-0.4.1.dist-info/WHEEL +5 -0
- saia_python-0.4.1.dist-info/licenses/LICENSE +661 -0
- saia_python-0.4.1.dist-info/top_level.txt +1 -0
saia_python/chat.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Chat service — completions and streaming."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from ._http import post_chat_completion
|
|
8
|
+
from ._streaming import SSEStream
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ChatService:
    """Thin wrapper around the ``/chat/completions`` endpoint.

    Args:
        session: A :class:`requests.Session` that already carries the auth
            headers.
        base_url: Root URL of the SAIA API.
    """

    def __init__(self, session: requests.Session, base_url: str):
        self._base_url = base_url
        self._session = session

    def completions(
        self,
        model: str,
        messages: list[dict],
        *,
        temperature: float | None = None,
        top_p: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        **kwargs,
    ) -> dict | SSEStream:
        """Issue a chat completion request.

        Args:
            model: Model identifier (e.g. ``"meta-llama-3.1-8b-instruct"``).
            messages: Message dicts, each with ``"role"`` and ``"content"``.
            temperature: Sampling temperature (0–2). Left out of the request
                when ``None``.
            top_p: Nucleus sampling parameter (0–1). Left out of the request
                when ``None``.
            max_tokens: Upper bound on generated tokens. Left out of the
                request when ``None``.
            stream: Forwarded to the HTTP helper; selects the streaming
                return type described below.
            **kwargs: Extra fields copied into the request body as-is.

        Returns:
            With ``stream=False``, the API response dict plus a
            ``"_rate_limits"`` entry (a JSON-serializable dict of the current
            rate-limit headers, see :class:`~saia_python.RateLimitInfo`).
            With ``stream=True``, an ``SSEStream`` to iterate for response
            chunks; its ``rate_limits`` attribute holds the same dict and is
            available immediately from the response headers.
        """
        payload = {"model": model, "messages": messages, **kwargs}
        # Only explicitly supplied sampling options are sent, so the server
        # applies its own defaults for anything left as None.
        tunables = {
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        payload.update(
            {key: value for key, value in tunables.items() if value is not None}
        )

        endpoint = f"{self._base_url}/chat/completions"
        return post_chat_completion(self._session, endpoint, payload, stream=stream)

    def __repr__(self):
        return f"ChatService(base_url={self._base_url!r})"
|
saia_python/client.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Main SAIA client — composes all services."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import requests as _requests
|
|
6
|
+
|
|
7
|
+
from .arcana import ArcanaService
|
|
8
|
+
from .auth import resolve_credentials
|
|
9
|
+
from .chat import ChatService
|
|
10
|
+
from .documents import DocumentService
|
|
11
|
+
from .exceptions import raise_for_status
|
|
12
|
+
from .models import ModelsService
|
|
13
|
+
from .rate_limits import RateLimitInfo, parse_rate_limits
|
|
14
|
+
from .voice import VoiceService
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SAIAClient:
    """High-level client for the GWDG SAIA platform.

    Provides access to Chat, Voice AI, ARCANA, Documents, and model
    listing through a shared, authenticated HTTP session. An OpenAI-
    compatible client is available via the ``.openai`` property.

    Args:
        api_key: Your SAIA API key. If omitted, the key is resolved
            automatically — see :func:`~saia_python.load_api_key` for the
            resolution order.
        base_url: Base URL for the SAIA API. Resolution order:
            explicit parameter > ``[saia] base_url`` in ``config.toml`` >
            hardcoded default (``https://chat-ai.academiccloud.de/v1``).
        key_file: Explicit path to a ``.saia_api`` or ``.env`` file.
            Ignored when ``api_key`` is provided.

    Example::

        # All settings resolved automatically
        client = SAIAClient()

        # Native services
        client.chat.completions(model="...", messages=[...])

        # OpenAI-compatible client (requires pip install saia-python[openai])
        client.openai.chat.completions.create(model="...", messages=[...])
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        key_file: str | None = None,
    ):
        self._api_key, self._base_url = resolve_credentials(api_key, base_url, key_file)
        # One session is shared by every service so the auth headers are
        # configured in a single place.
        self._session = _requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self._api_key}",
                "Accept": "application/json",
            }
        )

        # Services are created lazily on first property access.
        self._chat: ChatService | None = None
        self._voice: VoiceService | None = None
        self._models: ModelsService | None = None
        self._arcana: ArcanaService | None = None
        self._documents: DocumentService | None = None
        self._openai = None
        self._openai_async = None

    @property
    def chat(self) -> ChatService:
        """Chat completions service."""
        if self._chat is None:
            self._chat = ChatService(self._session, self._base_url)
        return self._chat

    @property
    def voice(self) -> VoiceService:
        """Voice AI (transcription/translation) service."""
        if self._voice is None:
            self._voice = VoiceService(self._session, self._base_url)
        return self._voice

    @property
    def models(self) -> ModelsService:
        """Model listing service."""
        if self._models is None:
            self._models = ModelsService(self._session, self._base_url)
        return self._models

    @property
    def arcana(self) -> ArcanaService:
        """ARCANA/RAG service."""
        if self._arcana is None:
            self._arcana = ArcanaService(self._session, self._base_url, self._api_key)
        return self._arcana

    @property
    def documents(self) -> DocumentService:
        """Document conversion (Docling) service."""
        if self._documents is None:
            self._documents = DocumentService(self._session, self._base_url)
        return self._documents

    @property
    def openai(self):
        """OpenAI-compatible synchronous client.

        Returns an ``openai.OpenAI`` instance configured with the same
        API key and base URL as this client. Requires
        ``pip install saia-python[openai]``.

        Example::

            response = client.openai.chat.completions.create(
                model="llama-3.3-70b-instruct",
                messages=[{"role": "user", "content": "Hello!"}],
            )
        """
        if self._openai is None:
            # Imported on demand so the optional dependency is only touched
            # when this property is actually used.
            from .openai_compat import create_openai_client

            self._openai = create_openai_client(
                api_key=self._api_key,
                base_url=self._base_url,
            )
        return self._openai

    @property
    def openai_async(self):
        """OpenAI-compatible asynchronous client.

        Returns an ``openai.AsyncOpenAI`` instance. Requires
        ``pip install saia-python[openai]``.
        """
        if self._openai_async is None:
            from .openai_compat import create_openai_client

            self._openai_async = create_openai_client(
                api_key=self._api_key,
                base_url=self._base_url,
                async_client=True,
            )
        return self._openai_async

    def get_rate_limits(self) -> RateLimitInfo:
        """Fetch current rate-limit status by making a lightweight API call.

        Uses a GET to ``/chat/completions`` which returns 400 but includes
        rate-limit headers.

        Returns:
            Parsed :class:`RateLimitInfo`.

        Raises:
            AuthenticationError: If the API key is invalid or expired
                (401/403). Other non-2xx statuses (notably the expected
                400) are tolerated since they still carry the headers.
        """
        resp = self._session.get(f"{self._base_url}/chat/completions")
        if resp.status_code in (401, 403):
            # The probe is *expected* to 400 (missing request body) but still
            # carries rate-limit headers. A 401/403 means the key is bad, so
            # surface it instead of silently returning an empty RateLimitInfo.
            raise_for_status(resp)
        return parse_rate_limits(resp.headers)

    def arcana_version(self) -> str:
        """Return the ARCANA API version string.

        Thin delegate to :meth:`ArcanaService.version`
        (``client.arcana.version()``), which owns the ARCANA URL path and
        auth scheme.

        Returns:
            The version string (e.g. ``"0.4.16"``).
        """
        return self.arcana.version()

    def arcana_heartbeat(self) -> bool:
        """Check if the ARCANA service is alive.

        Thin delegate to :meth:`ArcanaService.heartbeat`
        (``client.arcana.heartbeat()``). Returns ``True`` if the service
        responds with 204, ``False`` otherwise.

        Returns:
            ``True`` if the service is reachable.
        """
        return self.arcana.heartbeat()

    def health_check(self, *, verbose: bool = False) -> bool | dict:
        """Verify that the client can reach the API and authenticate.

        Combines two cheap GETs:

        - ``GET /models`` (authenticated) — confirms the API key resolves
          and the chat backend is reachable.
        - ``GET /arcanas/api/v1/heartbeat`` (cheap 204) — confirms the
          ARCANA backend is reachable.

        Neither leg lets an exception escape: a failure in either one is
        recorded in the result instead, so this method always returns.

        Args:
            verbose: If ``True``, return a diagnostic dict instead of a
                bool. Useful in onboarding / setup scripts where you
                want to surface *which* leg failed.

        Returns:
            ``True`` if both legs succeed, ``False`` otherwise. With
            ``verbose=True``, a dict::

                {
                    "ok": <bool>,
                    "base_url": <str>,
                    "models_ok": <bool>,
                    "model_count": <int>,  # 0 if models leg failed
                    "arcana_ok": <bool>,
                    "error": <str|None>,   # first leg that failed
                }
        """
        details: dict = {
            "base_url": self._base_url,
            "models_ok": False,
            "model_count": 0,
            "arcana_ok": False,
            "error": None,
        }

        # Leg 1: authenticated model listing.
        try:
            model_ids = self.models.list_ids()
        except Exception as exc:
            details["error"] = f"models: {exc}"
        else:
            details["models_ok"] = True
            details["model_count"] = len(model_ids)

        # Leg 2: ARCANA heartbeat. Guarded like leg 1 so an exception raised
        # by the heartbeat call is reported rather than propagated.
        try:
            details["arcana_ok"] = self.arcana_heartbeat()
        except Exception as exc:
            # "error" names the *first* failing leg, so keep an earlier one.
            if details["error"] is None:
                details["error"] = f"arcana: {exc}"
        else:
            if not details["arcana_ok"] and details["error"] is None:
                details["error"] = "arcana heartbeat returned non-204"

        details["ok"] = details["models_ok"] and details["arcana_ok"]
        return details if verbose else bool(details["ok"])

    def __repr__(self):
        return f"SAIAClient(base_url={self._base_url!r})"
|
saia_python/documents.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Document conversion service (Docling) — convert PDFs and documents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from .exceptions import raise_for_status
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
import requests
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class ConversionResult:
    """Outcome of one document conversion performed by the Docling service.

    Attributes:
        filename: Name of the source document.
        response_type: Output format (``markdown``, ``html``, ``json``, or
            ``tokens``).
        content: Converted document content as a string.
        images: Extracted image dicts, each with ``type``, ``filename``, and
            ``data`` (base64-encoded) keys. Empty by default.
    """

    filename: str
    response_type: str
    content: str
    images: list[dict] = field(default_factory=list)

    def save(self, path: str | Path) -> Path:
        """Write the converted content to *path* as UTF-8 text.

        Args:
            path: Destination file path.

        Returns:
            The destination as a :class:`~pathlib.Path`.
        """
        target = Path(path)
        target.write_text(self.content, encoding="utf-8")
        return target

    def __str__(self):
        # Summary line followed by at most the first 200 characters of the
        # content; an ellipsis marks a truncated preview.
        total = len(self.content)
        excerpt = self.content[:200]
        if total > 200:
            excerpt += "..."
        summary = (
            f"ConversionResult({self.filename!r}, {self.response_type}, "
            f"{len(self.images)} images, {total} chars)"
        )
        return f"{summary}\n{excerpt}"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DocumentService:
    """Client for the ``/documents/convert`` endpoint (Docling).

    Turns PDFs and other document formats into markdown, HTML, JSON, or
    token representations.

    Args:
        session: A :class:`requests.Session` that already carries the auth
            headers.
        base_url: Root URL of the SAIA API.
    """

    def __init__(self, session: requests.Session, base_url: str):
        self._base_url = base_url
        self._session = session

    def convert(
        self,
        file_path: str | Path,
        *,
        response_type: str = "markdown",
        extract_tables_as_images: bool | None = None,
        image_resolution_scale: int | None = None,
    ) -> ConversionResult:
        """Upload a document and return its converted form.

        Args:
            file_path: Location of the document to convert (PDF, etc.).
            response_type: Desired output format — ``"markdown"`` (default),
                ``"html"``, ``"json"``, or ``"tokens"``.
            extract_tables_as_images: If ``True``, render tables as images
                instead of structured text. Omitted from the request when
                ``None``.
            image_resolution_scale: Image quality multiplier (1–4). Omitted
                from the request when ``None``.

        Returns:
            A :class:`ConversionResult` holding the converted content and
            any extracted images.
        """
        source = Path(file_path)

        query: dict = {"response_type": response_type}
        if extract_tables_as_images is not None:
            # Sent as a lowercase string ("true"/"false") in the query string.
            query["extract_tables_as_images"] = str(extract_tables_as_images).lower()
        if image_resolution_scale is not None:
            query["image_resolution_scale"] = image_resolution_scale

        endpoint = f"{self._base_url}/documents/convert"
        with source.open("rb") as handle:
            resp = self._session.post(
                endpoint,
                params=query,
                files={"document": (source.name, handle)},
            )
        raise_for_status(resp)
        payload = resp.json()

        # Prefer the field named after the requested format; fall back to the
        # "markdown" field, then to an empty string.
        fallback = payload.get("markdown", "")
        body = payload.get(response_type, fallback)
        if isinstance(body, (list, dict)):
            # Structured output is stored as pretty-printed JSON text.
            body = json.dumps(body, indent=2)

        return ConversionResult(
            filename=payload.get("filename", source.name),
            response_type=payload.get("response_type", response_type),
            content=body,
            images=payload.get("images", []),
        )

    def convert_to_markdown(self, file_path: str | Path, **kwargs) -> str:
        """Convert a document and return only its markdown text.

        Args:
            file_path: Location of the document to convert.
            **kwargs: Extra options passed through to :meth:`convert`.

        Returns:
            The markdown content as a string.
        """
        result = self.convert(file_path, response_type="markdown", **kwargs)
        return result.content

    def convert_to_html(self, file_path: str | Path, **kwargs) -> str:
        """Convert a document and return only its HTML text.

        Args:
            file_path: Location of the document to convert.
            **kwargs: Extra options passed through to :meth:`convert`.

        Returns:
            The HTML content as a string.
        """
        result = self.convert(file_path, response_type="html", **kwargs)
        return result.content

    def __repr__(self):
        return f"DocumentService(base_url={self._base_url!r})"
|
saia_python/exceptions.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Custom exceptions and shared HTTP error handling for the SAIA Python wrapper."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json as _json
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from .rate_limits import parse_rate_limits
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SAIAError(Exception):
    """Root of the SAIA exception hierarchy; every SAIA API error derives from it."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuthenticationError(SAIAError):
    """The API answered 401/403 — the API key is invalid or missing."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RateLimitError(SAIAError):
    """The API answered 429 — the rate limit has been exceeded.

    Attributes:
        rate_limits: Whatever rate-limit information the raiser supplied, or
            ``None`` when none was given.
    """

    def __init__(self, message, rate_limits=None):
        self.rate_limits = rate_limits
        super().__init__(message)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class APIError(SAIAError):
    """An HTTP error that has no more specific SAIA exception type.

    Attributes:
        status_code: HTTP status code of the failing response, if supplied.
        response_body: Raw body text of the failing response, if supplied.
    """

    def __init__(self, message, status_code=None, response_body=None):
        self.response_body = response_body
        self.status_code = status_code
        super().__init__(message)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _extract_detail(resp: requests.Response) -> str:
    """Pull a human-readable message out of an error response.

    The SAIA API typically returns ``{"detail": "..."}`` on errors. When
    ``detail`` is structured (a list or dict rather than a string) it is
    serialised to JSON text so the declared ``str`` return type always
    holds. Anything else falls back to the raw response text.

    Args:
        resp: The error response to inspect.

    Returns:
        The ``detail`` message as a string, or ``resp.text`` when the body is
        not JSON, is not an object, or has no usable ``detail`` entry.
    """
    try:
        body = resp.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so this one clause
        # covers every "body is not valid JSON" case.
        return resp.text

    if isinstance(body, dict):
        detail = body.get("detail")
        if isinstance(detail, str):
            return detail
        if detail is not None:
            # Structured detail: keep the information but honour the
            # ``-> str`` contract.
            return _json.dumps(detail, ensure_ascii=False)
    return resp.text
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def raise_for_status(resp: requests.Response) -> None:
    """Translate an HTTP error response into the matching SAIA exception.

    This is the single implementation shared by all service modules.

    Args:
        resp: The response to check. Nothing happens when ``resp.ok`` is true.

    Raises:
        AuthenticationError: On a 401 or 403 status.
        RateLimitError: On a 429 status; carries the parsed rate-limit
            headers.
        APIError: On any other error status; carries the status code and the
            raw body text.
    """
    if not resp.ok:
        message = _extract_detail(resp)
        code = resp.status_code
        if code == 429:
            limits = parse_rate_limits(resp.headers)
            raise RateLimitError(message, rate_limits=limits)
        if code in (401, 403):
            raise AuthenticationError(message)
        # Every remaining non-ok status maps to the generic error type.
        raise APIError(message, status_code=code, response_body=resp.text)
|
saia_python/models.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Models service — list available SAIA models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from ._util import progress_iter
|
|
8
|
+
from .exceptions import raise_for_status
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import requests
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelsService:
    """Client for the ``/models`` endpoint.

    Args:
        session: A :class:`requests.Session` that already carries the auth
            headers.
        base_url: Root URL of the SAIA API
            (e.g. ``https://chat-ai.academiccloud.de/v1``).
    """

    def __init__(self, session: requests.Session, base_url: str):
        self._base_url = base_url
        self._session = session

    def list_raw(self) -> dict:
        """Fetch ``/models`` and return the parsed JSON exactly as received.

        In contrast to :meth:`list`, the OpenAI-style
        ``{"object": "list", "data": [...]}`` envelope is **not** unwrapped.
        This suits callers that need the complete OpenAI-compatible payload,
        such as an adapter re-serving SAIA's models from its own
        ``GET /v1/models`` endpoint::

            return client.models.list_raw()  # already the OpenAI envelope

        Returns:
            The parsed JSON response. For the SAIA / OpenAI-compatible API
            this is a dict of the form
            ``{"object": "list", "data": [...]}``.
        """
        endpoint = f"{self._base_url}/models"
        resp = self._session.get(endpoint)
        raise_for_status(resp)
        return resp.json()

    def list(self) -> list[dict]:
        """Return the models reported by the API as a flat list.

        Returns:
            A list of model dicts, each containing at least an ``"id"`` key.
        """
        payload = self.list_raw()
        if isinstance(payload, dict):
            # Unwrap the envelope when present; a bare dict is treated as a
            # single model entry.
            return payload["data"] if "data" in payload else [payload]
        return payload if isinstance(payload, list) else [payload]

    def list_ids(self) -> list[str]:
        """Return the model ID strings, without duplicates, in API order."""
        # Keys tried in order; the first one holding a truthy value wins.
        id_keys = ("id", "modelId", "model_id", "model", "name", "model_name")
        already = set()
        unique_ids = []
        for entry in self.list():
            model_id = next((val for val in map(entry.get, id_keys) if val), None)
            if not model_id or model_id in already:
                continue
            already.add(model_id)
            unique_ids.append(model_id)
        return unique_ids

    def list_tool_capable(self, *, verbose: bool = False) -> list[str]:
        """Find the models that answer a probe request with a tool call.

        Each available model receives one minimal tool-calling request, and
        the response is checked for a ``tool_calls`` field. Probing is done
        by trial and error because the SAIA API does not expose tool support
        as model metadata.

        Args:
            verbose: If ``True``, print one result line per model while
                probing (the progress display is disabled in that case).

        Returns:
            The IDs of the models whose response contained a tool call.

        Note:
            This method consumes API quota (one request per model) and may
            take several minutes depending on the number of available models.
        """
        candidates = self.list_ids()

        probe_tools = [
            {
                "type": "function",
                "function": {
                    "name": "probe",
                    "description": "Test probe.",
                    "parameters": {
                        "type": "object",
                        "properties": {"x": {"type": "string"}},
                        "required": ["x"],
                    },
                },
            }
        ]
        probe_messages = [
            {"role": "user", "content": "Call the probe tool with x='test'."}
        ]
        endpoint = f"{self._base_url}/chat/completions"

        supported = []
        iterator = progress_iter(
            candidates, desc="Probing models", unit="model", enabled=not verbose
        )
        for model_id in iterator:
            # Best effort: any failure for one model (transport, bad JSON,
            # unexpected shape) skips that model and probing continues.
            try:
                request_body = {
                    "model": model_id,
                    "messages": probe_messages,
                    "tools": probe_tools,
                    "max_tokens": 50,
                }
                resp = self._session.post(endpoint, json=request_body, timeout=30)
                first_choice = resp.json().get("choices", [{}])[0]
                message = first_choice.get("message", {})
                calls_tool = bool(message.get("tool_calls"))
                if calls_tool:
                    supported.append(model_id)
                if verbose:
                    status = "tool_calls" if calls_tool else "no tools"
                    print(f" {model_id:<45} {status}")
            except Exception as exc:
                if verbose:
                    print(f" {model_id:<45} error: {exc}")

        return supported

    def __repr__(self):
        return f"ModelsService(base_url={self._base_url!r})"
|