saia-python 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
saia_python/chat.py ADDED
@@ -0,0 +1,72 @@
1
+ """Chat service — completions and streaming."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from ._http import post_chat_completion
8
+ from ._streaming import SSEStream
9
+
10
+ if TYPE_CHECKING:
11
+ import requests
12
+
13
+
14
class ChatService:
    """Thin wrapper around the SAIA ``/chat/completions`` endpoint.

    Args:
        session: A :class:`requests.Session` that already carries the auth
            headers.
        base_url: The SAIA API base URL; ``/chat/completions`` is appended.
    """

    def __init__(self, session: requests.Session, base_url: str):
        self._session = session
        self._base_url = base_url

    def completions(
        self,
        model: str,
        messages: list[dict],
        *,
        temperature: float | None = None,
        top_p: float | None = None,
        max_tokens: int | None = None,
        stream: bool = False,
        **kwargs,
    ) -> dict | SSEStream:
        """Request a chat completion from the given model.

        Args:
            model: Model identifier (e.g. ``"meta-llama-3.1-8b-instruct"``).
            messages: Message dicts, each with ``"role"`` and ``"content"``.
            temperature: Sampling temperature (0–2). Omitted when ``None``.
            top_p: Nucleus sampling parameter (0–1). Omitted when ``None``.
            max_tokens: Generation length cap. Omitted when ``None``.
            stream: When ``True``, return an ``SSEStream`` instead of a dict.
            **kwargs: Extra request-body fields passed through verbatim.

        Returns:
            With ``stream=False``: the API response dict, carrying an extra
            ``"_rate_limits"`` key — a JSON-serializable dict of the current
            rate-limit headers (see :class:`~saia_python.RateLimitInfo`).
            With ``stream=True``: an ``SSEStream``; iterate it for the
            response chunks. Its ``rate_limits`` attribute holds the same
            dict, available immediately from the response headers.
        """
        # Sampling options are only sent when the caller set them, so the
        # server-side defaults apply otherwise. Insertion order matches the
        # keyword order: temperature, top_p, max_tokens.
        optional_fields = {
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        payload = {"model": model, "messages": messages, **kwargs}
        payload.update(
            (name, value) for name, value in optional_fields.items() if value is not None
        )

        endpoint = f"{self._base_url}/chat/completions"
        return post_chat_completion(self._session, endpoint, payload, stream=stream)

    def __repr__(self):
        return f"ChatService(base_url={self._base_url!r})"
saia_python/client.py ADDED
@@ -0,0 +1,239 @@
1
+ """Main SAIA client — composes all services."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import requests as _requests
6
+
7
+ from .arcana import ArcanaService
8
+ from .auth import resolve_credentials
9
+ from .chat import ChatService
10
+ from .documents import DocumentService
11
+ from .exceptions import raise_for_status
12
+ from .models import ModelsService
13
+ from .rate_limits import RateLimitInfo, parse_rate_limits
14
+ from .voice import VoiceService
15
+
16
+
17
class SAIAClient:
    """High-level client for the GWDG SAIA platform.

    Provides access to Chat, Voice AI, ARCANA, Documents, and model
    listing through a shared, authenticated HTTP session. An OpenAI-
    compatible client is available via the ``.openai`` property.

    The client owns a :class:`requests.Session`. Call :meth:`close` when you
    are done with it, or use the client as a context manager, so that the
    session's pooled connections are released deterministically.

    Args:
        api_key: Your SAIA API key. If omitted, the key is resolved
            automatically — see :func:`~saia_python.load_api_key` for the
            resolution order.
        base_url: Base URL for the SAIA API. Resolution order:
            explicit parameter > ``[saia] base_url`` in ``config.toml`` >
            hardcoded default (``https://chat-ai.academiccloud.de/v1``).
        key_file: Explicit path to a ``.saia_api`` or ``.env`` file.
            Ignored when ``api_key`` is provided.

    Example::

        # All settings resolved automatically
        client = SAIAClient()

        # Native services
        client.chat.completions(model="...", messages=[...])

        # OpenAI-compatible client (requires pip install saia-python[openai])
        client.openai.chat.completions.create(model="...", messages=[...])

        # Scoped usage: the HTTP session is closed on exit
        with SAIAClient() as client:
            client.models.list_ids()
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        key_file: str | None = None,
    ):
        self._api_key, self._base_url = resolve_credentials(api_key, base_url, key_file)
        self._session = _requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self._api_key}",
                "Accept": "application/json",
            }
        )

        # Service objects are created lazily by the properties below, on
        # first access, and then reused.
        self._chat: ChatService | None = None
        self._voice: VoiceService | None = None
        self._models: ModelsService | None = None
        self._arcana: ArcanaService | None = None
        self._documents: DocumentService | None = None
        self._openai = None
        self._openai_async = None

    @property
    def chat(self) -> ChatService:
        """Chat completions service."""
        if self._chat is None:
            self._chat = ChatService(self._session, self._base_url)
        return self._chat

    @property
    def voice(self) -> VoiceService:
        """Voice AI (transcription/translation) service."""
        if self._voice is None:
            self._voice = VoiceService(self._session, self._base_url)
        return self._voice

    @property
    def models(self) -> ModelsService:
        """Model listing service."""
        if self._models is None:
            self._models = ModelsService(self._session, self._base_url)
        return self._models

    @property
    def arcana(self) -> ArcanaService:
        """ARCANA/RAG service."""
        if self._arcana is None:
            self._arcana = ArcanaService(self._session, self._base_url, self._api_key)
        return self._arcana

    @property
    def documents(self) -> DocumentService:
        """Document conversion (Docling) service."""
        if self._documents is None:
            self._documents = DocumentService(self._session, self._base_url)
        return self._documents

    @property
    def openai(self):
        """OpenAI-compatible synchronous client.

        Returns an ``openai.OpenAI`` instance configured with the same
        API key and base URL as this client. Requires
        ``pip install saia-python[openai]``.

        Example::

            response = client.openai.chat.completions.create(
                model="llama-3.3-70b-instruct",
                messages=[{"role": "user", "content": "Hello!"}],
            )
        """
        if self._openai is None:
            from .openai_compat import create_openai_client

            self._openai = create_openai_client(
                api_key=self._api_key,
                base_url=self._base_url,
            )
        return self._openai

    @property
    def openai_async(self):
        """OpenAI-compatible asynchronous client.

        Returns an ``openai.AsyncOpenAI`` instance. Requires
        ``pip install saia-python[openai]``.
        """
        if self._openai_async is None:
            from .openai_compat import create_openai_client

            self._openai_async = create_openai_client(
                api_key=self._api_key,
                base_url=self._base_url,
                async_client=True,
            )
        return self._openai_async

    def get_rate_limits(self) -> RateLimitInfo:
        """Fetch current rate-limit status by making a lightweight API call.

        Uses a GET to ``/chat/completions`` which returns 400 but includes
        rate-limit headers.

        Returns:
            Parsed :class:`RateLimitInfo`.

        Raises:
            AuthenticationError: If the API key is invalid or expired
                (401/403). Other non-2xx statuses (notably the expected
                400) are tolerated since they still carry the headers.
        """
        resp = self._session.get(f"{self._base_url}/chat/completions")
        if resp.status_code in (401, 403):
            # The probe is *expected* to 400 (missing request body) but still
            # carries rate-limit headers. A 401/403 means the key is bad, so
            # surface it instead of silently returning an empty RateLimitInfo.
            raise_for_status(resp)
        return parse_rate_limits(resp.headers)

    def arcana_version(self) -> str:
        """Return the ARCANA API version string.

        Thin delegate to :meth:`ArcanaService.version`
        (``client.arcana.version()``), which owns the ARCANA URL path and
        auth scheme.

        Returns:
            The version string (e.g. ``"0.4.16"``).
        """
        return self.arcana.version()

    def arcana_heartbeat(self) -> bool:
        """Check if the ARCANA service is alive.

        Thin delegate to :meth:`ArcanaService.heartbeat`
        (``client.arcana.heartbeat()``). Returns ``True`` if the service
        responds with 204, ``False`` otherwise.

        Returns:
            ``True`` if the service is reachable.
        """
        return self.arcana.heartbeat()

    def health_check(self, *, verbose: bool = False) -> bool | dict:
        """Verify that the client can reach the API and authenticate.

        Combines two cheap GETs:

        - ``GET /models`` (authenticated) — confirms the API key resolves
          and the chat backend is reachable.
        - ``GET /arcanas/api/v1/heartbeat`` (cheap 204) — confirms the
          ARCANA backend is reachable.

        Exceptions raised by either leg (network errors, auth errors, ...)
        are caught and reported as a failed leg; this method does not raise
        for them.

        Args:
            verbose: If ``True``, return a diagnostic dict instead of a
                bool. Useful in onboarding / setup scripts where you
                want to surface *which* leg failed.

        Returns:
            ``True`` if both legs succeed, ``False`` otherwise. With
            ``verbose=True``, a dict::

                {
                    "ok": <bool>,
                    "base_url": <str>,
                    "models_ok": <bool>,
                    "model_count": <int>,  # 0 if models leg failed
                    "arcana_ok": <bool>,
                    "error": <str|None>,   # first leg that failed
                }
        """
        details: dict = {
            "base_url": self._base_url,
            "models_ok": False,
            "model_count": 0,
            "arcana_ok": False,
            "error": None,
        }

        try:
            model_ids = self.models.list_ids()
        except Exception as exc:  # best-effort probe: report, don't raise
            details["error"] = f"models: {exc}"
        else:
            details["models_ok"] = True
            details["model_count"] = len(model_ids)

        try:
            arcana_ok = bool(self.arcana_heartbeat())
        except Exception as exc:
            # A connection failure on the heartbeat must be reported like
            # the models leg, not escape: the contract is True/False.
            arcana_ok = False
            if details["error"] is None:
                details["error"] = f"arcana: {exc}"
        else:
            if not arcana_ok and details["error"] is None:
                details["error"] = "arcana heartbeat returned non-204"
        details["arcana_ok"] = arcana_ok

        details["ok"] = details["models_ok"] and details["arcana_ok"]
        return details if verbose else bool(details["ok"])

    def close(self) -> None:
        """Close the underlying HTTP session, releasing pooled connections.

        Service objects obtained from this client share the session and
        should not be used after closing.
        """
        self._session.close()

    def __enter__(self) -> SAIAClient:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the session; exceptions from the body propagate.
        self.close()

    def __repr__(self):
        return f"SAIAClient(base_url={self._base_url!r})"
@@ -0,0 +1,145 @@
1
+ """Document conversion service (Docling) — convert PDFs and documents."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ from .exceptions import raise_for_status
11
+
12
+ if TYPE_CHECKING:
13
+ import requests
14
+
15
+
16
@dataclass
class ConversionResult:
    """Outcome of converting one document through the Docling service.

    Attributes:
        filename: Name of the source document.
        response_type: Output format of ``content`` (``markdown``, ``html``,
            ``json``, or ``tokens``).
        content: The converted document as text.
        images: Extracted images; each dict has ``type``, ``filename`` and
            base64-encoded ``data`` keys.
    """

    filename: str
    response_type: str
    content: str
    images: list[dict] = field(default_factory=list)

    def save(self, path: str | Path) -> Path:
        """Write ``content`` to *path* as UTF-8 text.

        Args:
            path: Destination file path.

        Returns:
            The destination as a :class:`~pathlib.Path`.
        """
        target = Path(path)
        target.write_text(self.content, encoding="utf-8")
        return target

    def __str__(self):
        # Summary header followed by the first 200 characters of content,
        # with an ellipsis when the content was truncated.
        limit = 200
        text = self.content
        header = (
            f"ConversionResult({self.filename!r}, {self.response_type}, "
            f"{len(self.images)} images, {len(text)} chars)"
        )
        ellipsis = "..." if len(text) > limit else ""
        return f"{header}\n{text[:limit]}{ellipsis}"
54
+
55
+
56
class DocumentService:
    """Access the ``/documents/convert`` endpoint (Docling).

    Converts PDF and other document formats to markdown, HTML, JSON,
    or token representations.

    Args:
        session: A :class:`requests.Session` with auth headers configured.
        base_url: The SAIA API base URL.
    """

    def __init__(self, session: requests.Session, base_url: str):
        self._session = session
        self._base_url = base_url

    def convert(
        self,
        file_path: str | Path,
        *,
        response_type: str = "markdown",
        extract_tables_as_images: bool | None = None,
        image_resolution_scale: int | None = None,
    ) -> ConversionResult:
        """Convert a document using the Docling service.

        Args:
            file_path: Path to the document file (PDF, etc.).
            response_type: Output format — ``"markdown"`` (default),
                ``"html"``, ``"json"``, or ``"tokens"``.
            extract_tables_as_images: If ``True``, render tables as images
                instead of structured text.
            image_resolution_scale: Image quality multiplier (1–4).

        Returns:
            A :class:`ConversionResult` with the converted content and
            any extracted images. ``content`` is always a string and
            ``images`` always a list, even if the server sends ``null``.

        Raises:
            OSError: If ``file_path`` cannot be opened.
            SAIAError: (via ``raise_for_status``) on an HTTP error response.
        """
        file_path = Path(file_path)
        params: dict = {"response_type": response_type}
        if extract_tables_as_images is not None:
            # Sent as a query string, so use JSON-style "true"/"false".
            params["extract_tables_as_images"] = str(extract_tables_as_images).lower()
        if image_resolution_scale is not None:
            params["image_resolution_scale"] = image_resolution_scale

        with open(file_path, "rb") as f:
            resp = self._session.post(
                f"{self._base_url}/documents/convert",
                params=params,
                files={"document": (file_path.name, f)},
            )
        raise_for_status(resp)
        data = resp.json()

        return ConversionResult(
            filename=data.get("filename", file_path.name),
            response_type=data.get("response_type", response_type),
            content=self._extract_content(data, response_type),
            # Guard against an explicit null so len(result.images) works.
            images=data.get("images") or [],
        )

    @staticmethod
    def _extract_content(data: dict, response_type: str) -> str:
        """Pick the converted content for *response_type* out of a response body.

        Falls back to the ``"markdown"`` field when the requested format is
        missing or ``null``, and to ``""`` when neither is usable. Structured
        payloads (lists/dicts, e.g. for ``json``/``tokens``) are rendered as
        indented JSON text.
        """
        content = data.get(response_type)
        if content is None:
            content = data.get("markdown")
        if content is None:
            return ""
        if isinstance(content, (list, dict)):
            return json.dumps(content, indent=2)
        return content

    def convert_to_markdown(self, file_path: str | Path, **kwargs) -> str:
        """Convert a document to markdown (convenience method).

        Args:
            file_path: Path to the document file.
            **kwargs: Additional parameters passed to :meth:`convert`.

        Returns:
            The markdown content as a string.
        """
        return self.convert(file_path, response_type="markdown", **kwargs).content

    def convert_to_html(self, file_path: str | Path, **kwargs) -> str:
        """Convert a document to HTML (convenience method).

        Args:
            file_path: Path to the document file.
            **kwargs: Additional parameters passed to :meth:`convert`.

        Returns:
            The HTML content as a string.
        """
        return self.convert(file_path, response_type="html", **kwargs).content

    def __repr__(self):
        return f"DocumentService(base_url={self._base_url!r})"
@@ -0,0 +1,68 @@
1
+ """Custom exceptions and shared HTTP error handling for the SAIA Python wrapper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json as _json
6
+ from typing import TYPE_CHECKING
7
+
8
+ from .rate_limits import parse_rate_limits
9
+
10
+ if TYPE_CHECKING:
11
+ import requests
12
+
13
+
14
class SAIAError(Exception):
    """Root of the SAIA exception hierarchy; catch it to handle any SAIA API failure."""


class AuthenticationError(SAIAError):
    """The API answered 401/403 — the API key is invalid or missing."""


class RateLimitError(SAIAError):
    """The API answered 429 — a rate limit was exceeded.

    Attributes:
        rate_limits: Rate-limit information taken from the response
            headers, or ``None`` when none was supplied.
    """

    def __init__(self, message, rate_limits=None):
        self.rate_limits = rate_limits
        super().__init__(message)


class APIError(SAIAError):
    """An HTTP error response not covered by a more specific subclass.

    Attributes:
        status_code: The HTTP status code, if known.
        response_body: The raw response body text, if known.
    """

    def __init__(self, message, status_code=None, response_body=None):
        self.status_code, self.response_body = status_code, response_body
        super().__init__(message)
37
+
38
+
39
def _extract_detail(resp: requests.Response) -> str:
    """Return a human-readable error message for an error response.

    The SAIA API typically returns ``{"detail": ...}`` on errors.

    - A string ``detail`` is returned as-is.
    - Any other non-null ``detail`` (e.g. a list of validation errors) is
      serialized to JSON, so the result is always a ``str`` as annotated.
    - Otherwise — the body is not JSON, not an object, or has no usable
      ``detail`` — the raw response text is returned.
    """
    try:
        body = resp.json()
    except (_json.JSONDecodeError, ValueError):
        return resp.text

    detail = body.get("detail") if isinstance(body, dict) else None
    if isinstance(detail, str):
        return detail
    if detail is not None:
        # Structured details came from JSON, so they are always serializable.
        return _json.dumps(detail, ensure_ascii=False)
    return resp.text
52
+
53
+
54
def raise_for_status(resp: requests.Response) -> None:
    """Translate an HTTP error response into the matching SAIA exception.

    Every service module routes its responses through here, so the
    status-to-exception mapping lives in exactly one place. Successful
    responses (``resp.ok``) pass through untouched.

    Raises:
        AuthenticationError: On 401 or 403.
        RateLimitError: On 429, with the parsed rate-limit headers attached.
        APIError: On any other error status, carrying the status code and
            raw response body.
    """
    if not resp.ok:
        message = _extract_detail(resp)
        code = resp.status_code
        if code == 401 or code == 403:
            raise AuthenticationError(message)
        if code == 429:
            raise RateLimitError(message, rate_limits=parse_rate_limits(resp.headers))
        raise APIError(message, status_code=code, response_body=resp.text)
saia_python/models.py ADDED
@@ -0,0 +1,146 @@
1
+ """Models service — list available SAIA models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from ._util import progress_iter
8
+ from .exceptions import raise_for_status
9
+
10
+ if TYPE_CHECKING:
11
+ import requests
12
+
13
+
14
class ModelsService:
    """Access the ``/models`` endpoint.

    Args:
        session: A :class:`requests.Session` with auth headers configured.
        base_url: The SAIA API base URL (e.g. ``https://chat-ai.academiccloud.de/v1``).
    """

    # Keys checked, in priority order, when pulling an ID out of a model dict.
    _ID_KEYS = ("id", "modelId", "model_id", "model", "name", "model_name")

    # Minimal tool schema and prompt used by list_tool_capable(). Kept at
    # class level so they are built once; they are only serialized, never
    # mutated.
    _PROBE_TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "probe",
                "description": "Test probe.",
                "parameters": {
                    "type": "object",
                    "properties": {"x": {"type": "string"}},
                    "required": ["x"],
                },
            },
        }
    ]
    _PROBE_MESSAGES = [
        {"role": "user", "content": "Call the probe tool with x='test'."}
    ]

    def __init__(self, session: requests.Session, base_url: str):
        self._session = session
        self._base_url = base_url

    def list_raw(self) -> dict:
        """Return the raw ``/models`` response envelope, as the API sent it.

        Unlike :meth:`list`, this does **not** unwrap the OpenAI-style
        ``{"object": "list", "data": [...]}`` envelope — it returns the
        parsed JSON verbatim. Use it when you need the full
        OpenAI-compatible payload, e.g. an adapter that re-serves SAIA's
        models at its own ``GET /v1/models`` endpoint::

            return client.models.list_raw()  # already the OpenAI envelope

        Returns:
            The parsed JSON response. For the SAIA / OpenAI-compatible API
            this is a dict of the form
            ``{"object": "list", "data": [...]}``.
        """
        resp = self._session.get(f"{self._base_url}/models")
        raise_for_status(resp)
        return resp.json()

    def list(self) -> list[dict]:
        """Return the full model list as returned by the API.

        A bare JSON list is returned as-is; a ``{"data": [...]}`` envelope is
        unwrapped; any other payload is wrapped in a one-element list.

        Returns:
            A list of model entries, normally dicts with at least an
            ``"id"`` key.
        """
        data = self.list_raw()
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def list_ids(self) -> list[str]:
        """Return a deduplicated list of model ID strings, in API order.

        For dict entries the ID is the first truthy value among
        ``_ID_KEYS``. Bare string entries are used as the ID directly.
        Entries of any other type, or without an ID, are skipped.
        """
        seen = set()
        ids = []
        for entry in self.list():
            mid = self._model_id(entry)
            if mid and mid not in seen:
                seen.add(mid)
                ids.append(mid)
        return ids

    @classmethod
    def _model_id(cls, entry):
        """Extract the model ID from one ``/models`` entry, or ``None``."""
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            for key in cls._ID_KEYS:
                value = entry.get(key)
                if value:
                    return value
        return None

    def list_tool_capable(self, *, verbose: bool = False) -> list[str]:
        """Identify models that support tool calling by probing each one.

        Sends a minimal tool-calling request to each available model and
        checks whether the response contains a ``tool_calls`` field. This
        is a trial-and-error approach because the SAIA API does not expose
        tool support as model metadata. A model whose probe fails (HTTP
        error status, network error, or an unparseable body) is counted as
        not tool-capable.

        Args:
            verbose: If ``True``, print per-model results during probing
                (and disable the progress bar).

        Returns:
            A list of model ID strings that responded with a tool call.

        Note:
            This method consumes API quota (one request per model) and may
            take several minutes depending on the number of available models.
        """
        model_ids = self.list_ids()

        capable = []
        for mid in progress_iter(
            model_ids, desc="Probing models", unit="model", enabled=not verbose
        ):
            try:
                resp = self._session.post(
                    f"{self._base_url}/chat/completions",
                    json={
                        "model": mid,
                        "messages": self._PROBE_MESSAGES,
                        "tools": self._PROBE_TOOLS,
                        "max_tokens": 50,
                    },
                    timeout=30,
                )
                if not resp.ok:
                    # Error bodies (e.g. 400 for unsupported tools, 429 when
                    # rate-limited) never carry tool_calls; surface the status
                    # instead of misreporting the model as "no tools".
                    if verbose:
                        print(f" {mid:<45} error: HTTP {resp.status_code}")
                    continue
                data = resp.json()
                # `or` also covers an explicit empty list / null.
                choices = data.get("choices") or [{}]
                msg = choices[0].get("message") or {}
                has_tools = bool(msg.get("tool_calls"))
            except Exception as e:  # best-effort probe: one bad model must not abort the scan
                if verbose:
                    print(f" {mid:<45} error: {e}")
                continue

            if has_tools:
                capable.append(mid)
            if verbose:
                status = "tool_calls" if has_tools else "no tools"
                print(f" {mid:<45} {status}")

        return capable

    def __repr__(self):
        return f"ModelsService(base_url={self._base_url!r})"