yarn-au 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,388 @@
1
+ # Yarn Python SDK — Design Specification
2
+
3
+ ## Overview
4
+
5
+ Thin Python client for Yarn's REST API. Handles auth, session lifecycle, job submission, storage, and billing queries. Dependencies: `requests` only.
6
+
7
+ Install: `pip install yarn-au` (PyPI), `pip install -e apps/yarn/sdk/` (monorepo), or `pip install git+ssh://git@github.com/prosodylabs/infrastructure.git#subdirectory=apps/yarn/sdk`
8
+
9
+ ## Package structure
10
+
11
+ ```
12
+ apps/yarn/sdk/
13
+ ├── pyproject.toml
14
+ ├── src/
15
+ │ └── yarn/
16
+ │ ├── __init__.py # Re-export top-level convenience functions
17
+ │ ├── client.py # YarnClient — auth, base URL, request helpers
18
+ │ ├── jobs.py # Job submission, status, logs, cancel
19
+ │ ├── sessions.py # Interactive GPU sessions (context manager)
20
+ │ ├── notebooks.py # JupyterLab lifecycle
21
+ │ ├── storage.py # Upload, download, list, delete
22
+ │ ├── billing.py # Balance, usage, cost estimation
23
+ │ ├── secrets.py # Secret CRUD
24
+ │ └── exceptions.py # YarnError, InsufficientCredits, etc.
25
+ ```
26
+
27
+ ## Authentication
28
+
29
+ API key is the only auth mechanism for the SDK. The gateway/Kong must resolve API keys to user IDs + roles for `/v1/research/*` and `/v1/data/*` routes (backend item #15 — prerequisite).
30
+
31
+ ```python
32
+ client = yarn.Client(api_key="yarn_...")
33
+ # or
34
+ client = yarn.Client() # reads YARN_API_KEY env var
35
+ ```
36
+
37
+ All requests send `Authorization: Bearer yarn_...` header. No JWT management, no OIDC, no token refresh.
38
+
39
+ ## Base URL
40
+
41
+ Default: `https://api.au.yarn.prosodylabs.com.au`
42
+ Override: `yarn.Client(api_key="...", base_url="...")`
43
+
44
+ All SDK methods hit this base URL. Research endpoints at `/v1/research/*`, data at `/v1/data/*`, billing at `/v1/billing/*`, inference at `/v1/chat/completions`.
45
+
46
+ ## Top-level convenience API
47
+
48
+ ```python
49
+ import yarn
50
+
51
+ # Configure once
52
+ yarn.api_key = "yarn_..."
53
+ # or
54
+ client = yarn.Client(api_key="yarn_...")
55
+ ```
56
+
57
+ ### Sessions (interactive GPU)
58
+
59
+ ```python
60
+ # Context manager — creates session, waits for ready, tears down on exit
61
+ with yarn.session(gpu="rtx-4090", name="experiment-1") as s:
62
+ print(s.ray_address) # ray://experiment-1.research.prosodylabs.com.au:443
63
+ print(s.dashboard_url) # https://experiment-1-dash.research.prosodylabs.com.au
64
+ print(s.status) # "running"
65
+
66
+ import ray
67
+ ray.init(address=s.ray_address)
68
+ # ... Ray code ...
69
+
70
+ # Session is deleted when context exits (including on exception)
71
+
72
+ # Manual lifecycle
73
+ s = client.sessions.create(name="exp-1", gpu="rtx-4090", idle_timeout_minutes=120)
74
+ s.wait_until_ready(timeout=300) # polls every 5s, raises TimeoutError
75
+ s.refresh() # update status from API
76
+ s.stop() # DELETE
77
+ ```
78
+
79
+ **Session.create() parameters:**
80
+ - `name: str` — K8s-safe, 1-63 chars
81
+ - `gpu: str` — GPU type (default: "rtx-4090")
82
+ - `gpu_count: int` — default 1
83
+ - `cpu: str` — K8s resource (default: "4")
84
+ - `memory: str` — K8s resource (default: "8Gi")
85
+ - `image: str` — container image (default: ray-base)
86
+ - `idle_timeout_minutes: int` — auto-shutdown (default: 60)
87
+ - `env_vars: dict` — environment variables
88
+
89
+ **Session properties:**
90
+ - `id: str`
91
+ - `name: str`
92
+ - `status: str` — "creating", "running", "failed", "stopped"
93
+ - `ray_address: str` — `ray://{name}.research.prosodylabs.com.au:443`
94
+ - `dashboard_url: str`
95
+ - `gpu_type: str`
96
+ - `created_at: str`
97
+
98
+ ### Jobs (fire-and-forget training)
99
+
100
+ ```python
101
+ # From a script file
102
+ job = yarn.submit_job(
103
+ script="train.py", # local file path — SDK reads + base64 encodes
104
+ name="kairos-batch32",
105
+ gpu="rtx-4090",
106
+ pip_packages=["torch", "wandb"],
107
+ env_vars={"WANDB_PROJECT": "kairos"},
108
+ secret_names=["WANDB_API_KEY"],
109
+ max_runtime_hours=4,
110
+ )
111
+
112
+ # From a string
113
+ job = yarn.submit_job(
114
+ code="import torch; print(torch.cuda.is_available())",
115
+ name="gpu-test",
116
+ )
117
+
118
+ # Block until complete
119
+ job.wait(timeout=3600, poll_interval=10) # polls every 10s
120
+
121
+ # Stream logs
122
+ for line in job.stream_logs(tail=100):
123
+ print(line)
124
+
125
+ # Check status
126
+ print(job.status) # "SUCCEEDED", "FAILED", "RUNNING", "PENDING"
127
+ print(job.logs()) # full log output
128
+ print(job.cost) # actual cost in AUD (after completion)
129
+
130
+ # Cancel
131
+ job.cancel()
132
+
133
+ # List all jobs
134
+ jobs = client.jobs.list()
135
+
136
+ # Cost estimation (before submitting)
137
+ estimate = client.jobs.estimate(gpu="rtx-4090", max_runtime_hours=4)
138
+ print(f"${estimate.low:.2f} - ${estimate.high:.2f} AUD")
139
+ ```
140
+
141
+ **submit_job() parameters:**
142
+ - `script: str` — local file path (SDK reads, base64 encodes)
143
+ - `code: str` — inline Python code (alternative to script)
144
+ - `name: str` — job name (K8s-safe)
145
+ - `gpu: str` — GPU type (default: "rtx-4090")
146
+ - `gpu_count: int` — default 1
147
+ - `cpu: str` — default "2"
148
+ - `memory: str` — default "4Gi"
149
+ - `image: str` — container image (default: ray-base)
150
+ - `entrypoint: str` — default "python /home/ray/code/main.py"
151
+ - `pip_packages: list[str]` — installed at runtime
152
+ - `env_vars: dict` — environment variables
153
+ - `secret_names: list[str]` — K8s secrets to inject
154
+ - `max_runtime_hours: float` — 0 = no limit
155
+ - `description: str` — optional
156
+
157
+ Exactly one of `script` or `code` must be provided. If `script`, SDK reads the file and base64-encodes it. If `code`, SDK base64-encodes the string directly.
158
+
159
+ **Job properties:**
160
+ - `id: str`
161
+ - `name: str`
162
+ - `status: str` — "PENDING", "RUNNING", "SUCCEEDED", "FAILED", "STOPPED"
163
+ - `logs() -> str` — fetch latest logs
164
+ - `cost: float | None` — AUD, available after completion
165
+ - `created_at, start_time, end_time: str`
166
+ - `error: str` — error message if failed
167
+
168
+ ### Notebooks (JupyterLab)
169
+
170
+ ```python
171
+ nb = client.notebooks.create(
172
+ name="analysis",
173
+ gpu="rtx-4090",
174
+ idle_timeout_minutes=120,
175
+ data_mounts=["datasets/kairos"],
176
+ )
177
+ nb.wait_until_ready()
178
+ print(nb.jupyter_url) # opens in browser
179
+ # ...
180
+ nb.stop()
181
+ ```
182
+
183
+ ### Storage
184
+
185
+ ```python
186
+ # Upload
187
+ yarn.upload("datasets/train.csv", local_path="./data/train.csv")
188
+ yarn.upload("models/checkpoint.pt", local_path="./best.pt")
189
+
190
+ # Download
191
+ yarn.download("results/metrics.json", local_path="./metrics.json")
192
+
193
+ # List
194
+ objects = client.storage.list(prefix="datasets/")
195
+ for obj in objects:
196
+ print(f"{obj.key} {obj.size} bytes")
197
+
198
+ # Delete
199
+ client.storage.delete("datasets/old.csv")
200
+
201
+ # Usage
202
+ usage = client.storage.usage()
203
+ print(f"{usage.total_bytes / 1e9:.1f} GB used, {usage.percent_used:.0f}% of quota")
204
+ ```
205
+
206
+ ### Billing
207
+
208
+ ```python
209
+ balance = client.billing.balance()
210
+ print(f"${balance.available:.2f} available, ${balance.held:.2f} held")
211
+
212
+ usage = client.billing.usage(period="daily", days=7)
213
+ for day in usage.data:
214
+ print(f"{day.date}: ${day.spend:.2f}")
215
+ ```
216
+
217
+ ### Inference (chat completions)
218
+
219
+ ```python
220
+ # Non-streaming
221
+ response = client.chat.completions(
222
+ model="mistral-7b-instruct-v0.2",
223
+ messages=[{"role": "user", "content": "Hello"}],
224
+ )
225
+ print(response.choices[0].message.content)
226
+
227
+ # Streaming
228
+ for chunk in client.chat.completions(
229
+ model="mistral-7b-instruct-v0.2",
230
+ messages=[{"role": "user", "content": "Hello"}],
231
+ stream=True,
232
+ ):
233
+ print(chunk.delta.content, end="")
234
+ ```
235
+
236
+ ### Secrets
237
+
238
+ ```python
239
+ client.secrets.create(name="WANDB_API_KEY", value="...", description="W&B tracking")
240
+ secrets = client.secrets.list() # names + descriptions only, never values
241
+ client.secrets.delete("WANDB_API_KEY")
242
+ ```
243
+
244
+ ## Error handling
245
+
246
+ ```python
247
+ from yarn.exceptions import (
248
+ YarnError, # base
249
+ AuthenticationError, # 401 — invalid API key
250
+ PermissionError, # 403 — missing researcher role
251
+ InsufficientCredits, # 402 — not enough balance
252
+ NotFoundError, # 404
253
+ QuotaExceeded, # 413 — storage quota
254
+ RateLimitError, # 429
255
+ ServiceUnavailable, # 502/503
256
+ )
257
+
258
+ try:
259
+ job = yarn.submit_job(script="train.py")
260
+ except InsufficientCredits as e:
261
+ print(f"Need ${e.required:.2f}, have ${e.available:.2f}")
262
+ except YarnError as e:
263
+ print(f"Error: {e.message} (HTTP {e.status_code})")
264
+ ```
265
+
266
+ All API errors raise typed exceptions with `status_code`, `message`, and the raw response body.
267
+
268
+ ## API endpoint mapping
269
+
270
+ | SDK method | HTTP | Endpoint |
271
+ |---|---|---|
272
+ | `client.sessions.create()` | POST | `/v1/research/sessions` |
273
+ | `client.sessions.get(id)` | GET | `/v1/research/sessions/{id}` |
274
+ | `client.sessions.list()` | GET | `/v1/research/sessions` |
275
+ | `client.sessions.delete(id)` | DELETE | `/v1/research/sessions/{id}` |
276
+ | `client.jobs.submit()` | POST | `/v1/research/jobs` |
277
+ | `client.jobs.get(id)` | GET | `/v1/research/jobs/{id}` |
278
+ | `client.jobs.list()` | GET | `/v1/research/jobs` |
279
+ | `client.jobs.cancel(id)` | DELETE | `/v1/research/jobs/{id}` |
280
+ | `client.jobs.logs(id)` | GET | `/v1/research/jobs/{id}/logs` |
281
+ | `client.jobs.estimate()` | POST | `/v1/research/jobs/estimate` |
282
+ | `client.notebooks.create()` | POST | `/v1/research/notebooks` |
283
+ | `client.notebooks.get(id)` | GET | `/v1/research/notebooks/{id}` |
284
+ | `client.notebooks.list()` | GET | `/v1/research/notebooks` |
285
+ | `client.notebooks.delete(id)` | DELETE | `/v1/research/notebooks/{id}` |
286
+ | `client.secrets.create()` | POST | `/v1/research/secrets` |
287
+ | `client.secrets.list()` | GET | `/v1/research/secrets` |
288
+ | `client.secrets.delete(name)` | DELETE | `/v1/research/secrets/{name}` |
289
+ | `client.storage.upload()` | POST | `/v1/data/objects` |
290
+ | `client.storage.download(key)` | GET | `/v1/data/objects/{key}` |
291
+ | `client.storage.list()` | GET | `/v1/data/objects` |
292
+ | `client.storage.delete(key)` | DELETE | `/v1/data/objects/{key}` |
293
+ | `client.storage.usage()` | GET | `/v1/data/usage` |
294
+ | `client.storage.quota()` | GET | `/v1/data/quota` |
295
+ | `client.storage.datasets()` | GET | `/v1/data/datasets` |
296
+ | `client.billing.balance()` | GET | `/v1/billing/credits/balance` |
297
+ | `client.billing.usage()` | GET | `/v1/billing/credits/usage` |
298
+ | `client.billing.packs()` | GET | `/v1/billing/credits/packs` |
299
+ | `client.chat.completions()` | POST | `/v1/chat/completions` |
300
+ | `client.chat.models()` | GET | `/v1/models` |
301
+
302
+ ## Request/response schemas
303
+
304
+ Full schemas for every endpoint are documented in the exploration agent's output, which is not included in this package. The key ones for implementation are:
305
+
306
+ ### Job submission request
307
+ ```json
308
+ {
309
+ "name": "str (1-63, K8s-safe)",
310
+ "code": "str (base64-encoded)",
311
+ "entrypoint": "str (default: python /home/ray/code/main.py)",
312
+ "image": "str (allowlist: ghcr.io/jordanlochhill/ray-base:*)",
313
+ "gpu_count": "int (0-8, default: 1)",
314
+ "cpu": "str (default: 2)",
315
+ "memory": "str (default: 4Gi)",
316
+ "env_vars": "dict[str, str]",
317
+ "description": "str (max 500)",
318
+ "max_runtime_hours": "float (0-24)",
319
+ "pip_packages": "list[str]",
320
+ "secret_names": "list[str]"
321
+ }
322
+ ```
323
+
324
+ ### Session response
325
+ ```json
326
+ {
327
+ "id": "str",
328
+ "session_name": "str",
329
+ "status": "created | running | failed",
330
+ "gpu_type": "rtx-4090",
331
+ "gpu_count": 1,
332
+ "ray_endpoint": "ray://{name}.research.prosodylabs.com.au:443",
333
+ "dashboard_url": "https://{name}-dash.research.prosodylabs.com.au",
334
+ "created_at": "ISO datetime",
335
+ "idle_timeout_minutes": 60
336
+ }
337
+ ```
338
+
339
+ ### Credit balance response
340
+ ```json
341
+ {
342
+ "balance": 142.50,
343
+ "available": 97.50,
344
+ "held": 45.00,
345
+ "total_credited": 500.00,
346
+ "total_debited": 357.50,
347
+ "transaction_count": 23,
348
+ "currency": "AUD",
349
+ "active_holds": [],
350
+ "recent_transactions": []
351
+ }
352
+ ```
353
+
354
+ ## Implementation notes
355
+
356
+ 1. **`gpu` parameter mapping:** The SDK accepts friendly names like `"rtx-4090"`, `"a100"`, `"h100"`. Map to the exact strings the API expects. Default is `"rtx-4090"`.
357
+
358
+ 2. **Polling in wait():** Use exponential backoff starting at 2s, capping at 30s. Log status transitions to stderr if verbose mode is on.
359
+
360
+ 3. **Session context manager:** `__enter__` calls `create()` then `wait_until_ready()`. `__exit__` calls `stop()` regardless of exception. If creation fails, don't call stop.
361
+
362
+ 4. **File reading in submit_job:** Read the file, base64-encode, and send as the `code` field. The SDK does the encoding — the researcher passes a file path.
363
+
364
+ 5. **Streaming logs:** `stream_logs()` returns a generator that polls `/jobs/{id}/logs?tail=N` every 2s and yields new lines. Stops when job status is terminal.
365
+
366
+ 6. **Storage upload:** Use `multipart/form-data` with the file. Set `key` query param to the remote path.
367
+
368
+ 7. **Thread safety:** `Client` instances should be thread-safe (each request is independent). Don't share mutable state.
369
+
370
+ 8. **Timeouts:** Default request timeout 30s for normal calls, 300s for uploads/downloads, configurable via `client.timeout`.
371
+
372
+ ## pyproject.toml
373
+
374
+ ```toml
375
+ [project]
376
+ name = "yarn-au"
377
+ version = "0.1.0"
378
+ description = "Python SDK for Yarn — sovereign GPU compute"
379
+ requires-python = ">=3.10"
380
+ dependencies = ["requests>=2.28"]
381
+
382
+ [build-system]
383
+ requires = ["hatchling"]
384
+ build-backend = "hatchling.build"
385
+
386
+ [tool.hatch.build.targets.wheel]
387
+ packages = ["src/yarn"]
388
+ ```
yarn_au-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Prosody Labs Pty Ltd
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
yarn_au-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.4
2
+ Name: yarn-au
3
+ Version: 0.1.0
4
+ Summary: Python SDK for Yarn — sovereign GPU compute for Australian researchers
5
+ Project-URL: Homepage, https://prosodylabs.com.au
6
+ Project-URL: Repository, https://github.com/prosodylabs/yarn
7
+ Author-email: Prosody Labs <jordan@prosodylabs.com.au>
8
+ License: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: >=3.10
15
+ Requires-Dist: requests>=2.28
@@ -0,0 +1,60 @@
1
+ # Yarn
2
+
3
+ Python SDK for [Yarn](https://prosodylabs.com.au) — sovereign GPU compute for Australian researchers.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install yarn-au
9
+ ```
10
+
11
+ ## Quick start
12
+
13
+ ```python
14
+ import yarn
15
+
16
+ client = yarn.Client(api_key="yarn_...")
17
+
18
+ # Submit a training job
19
+ job = client.jobs.submit(
20
+ script="train.py",
21
+ name="my-experiment",
22
+ gpu="rtx-4090",
23
+ pip_packages=["torch", "wandb"],
24
+ )
25
+ job.wait()
26
+ print(job.logs())
27
+
28
+ # Interactive GPU session
29
+ with client.sessions.session("exp-1", gpu="rtx-4090") as s:
30
+ print(s.ray_address) # Connect with Ray
31
+
32
+ # Check balance
33
+ balance = client.billing.balance()
34
+ print(f"${balance['available']:.2f} available")
35
+ ```
36
+
37
+ ## Authentication
38
+
39
+ Set your API key as an environment variable:
40
+
41
+ ```bash
42
+ export YARN_API_KEY=yarn_...
43
+ ```
44
+
45
+ Or pass it directly:
46
+
47
+ ```python
48
+ client = yarn.Client(api_key="yarn_...")
49
+ ```
50
+
51
+ Create API keys at [account.yarn.prosodylabs.com.au](https://account.yarn.prosodylabs.com.au).
52
+
53
+ ## Documentation
54
+
55
+ - [API Reference](https://prosodylabs.com.au)
56
+ - [Design Specification](./DESIGN.md)
57
+
58
+ ## License
59
+
60
+ MIT — Prosody Labs Pty Ltd
@@ -0,0 +1,22 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "yarn-au"
7
+ version = "0.1.0"
8
+ description = "Python SDK for Yarn — sovereign GPU compute for Australian researchers"
9
+ requires-python = ">=3.10"
10
+ dependencies = ["requests>=2.28"]
11
+ license = {text = "MIT"}
12
+ authors = [{name = "Prosody Labs", email = "jordan@prosodylabs.com.au"}]
13
+ urls = {Homepage = "https://prosodylabs.com.au", Repository = "https://github.com/prosodylabs/yarn"}
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Science/Research",
17
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
18
+ "Programming Language :: Python :: 3",
19
+ ]
20
+
21
+ [tool.hatch.build.targets.wheel]
22
+ packages = ["src/yarn"]
@@ -0,0 +1,26 @@
1
+ """Yarn Python SDK."""
2
+ from .client import Client
3
+ from .exceptions import (
4
+ AuthenticationError,
5
+ InsufficientCredits,
6
+ NotFoundError,
7
+ PermissionError,
8
+ QuotaExceeded,
9
+ RateLimitError,
10
+ ServiceUnavailable,
11
+ YarnError,
12
+ )
13
+
14
__version__ = "0.1.0"

# Names exported by ``from yarn import *``: the Client entry point followed by
# the exception types imported from ``.exceptions`` above.
# NOTE(review): exporting "PermissionError" shadows Python's builtin
# PermissionError in any module that star-imports this package.
__all__ = [
    "Client",
    # Exceptions
    "YarnError",
    "AuthenticationError",
    "PermissionError",
    "InsufficientCredits",
    "NotFoundError",
    "QuotaExceeded",
    "RateLimitError",
    "ServiceUnavailable",
]
@@ -0,0 +1,34 @@
1
+ """Billing and credit operations."""
2
+ from __future__ import annotations
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING:
7
+ from .client import Client
8
+
9
+
10
class Billing:
    """Credit-related queries under ``/v1/billing/credits``.

    Every method is a single GET issued through the owning client; the
    decoded response is returned unchanged.
    """

    def __init__(self, client: Client):
        self._client = client

    def balance(self) -> dict:
        """Return the current credit balance."""
        endpoint = "/v1/billing/credits/balance"
        return self._client.get(endpoint)

    def usage(
        self,
        period: str = "daily",
        from_date: str = "",
        to_date: str = "",
    ) -> dict:
        """Return credit usage over a time period.

        ``period`` is always sent. ``from_date`` and ``to_date`` are sent as
        the ``from`` / ``to`` query parameters, and only when non-empty.
        """
        bounds = {"from": from_date, "to": to_date}
        query: dict[str, str] = {"period": period}
        # Empty strings mean "not specified" and are left out of the query.
        query.update({name: value for name, value in bounds.items() if value})
        return self._client.get("/v1/billing/credits/usage", params=query)

    def packs(self) -> list[dict]:
        """Return the list of available credit packs."""
        endpoint = "/v1/billing/credits/packs"
        return self._client.get(endpoint)
@@ -0,0 +1,64 @@
1
+ """Chat completions (inference)."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from typing import TYPE_CHECKING, Generator
6
+
7
+ if TYPE_CHECKING:
8
+ from .client import Client
9
+
10
+
11
class Chat:
    """Inference calls: ``/v1/chat/completions`` and ``/v1/models``."""

    def __init__(self, client: Client):
        self._client = client

    def completions(
        self,
        model: str,
        messages: list[dict],
        *,
        stream: bool = False,
        temperature: float | None = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> dict | Generator[dict, None, None]:
        """Send a chat completion request.

        Args:
            model: Model identifier, sent as the ``model`` body field.
            messages: Chat messages, sent unchanged as ``messages``.
            stream: When True, return a generator of parsed SSE chunks
                instead of a single response.
            temperature: Sampling temperature; omitted from the body if None.
            max_tokens: Completion token cap; omitted from the body if None.
            **kwargs: Extra body fields, merged in last.

        Returns:
            The client's ``post`` result when ``stream`` is False. Otherwise a
            generator yielding one decoded JSON object per SSE ``data:``
            event. The streaming request is only sent when the generator is
            first iterated, so HTTP errors surface at that point.
        """
        body: dict = {"model": model, "messages": messages, "stream": stream}
        if temperature is not None:
            body["temperature"] = temperature
        if max_tokens is not None:
            body["max_tokens"] = max_tokens
        body.update(kwargs)

        if not stream:
            return self._client.post("/v1/chat/completions", json=body)

        return self._stream(body)

    def _stream(self, body: dict) -> Generator[dict, None, None]:
        """Post ``body`` with streaming enabled and yield parsed SSE events.

        Blank lines, comments and non-``data`` fields are skipped, as are
        ``data`` payloads that are not valid JSON. Iteration stops at the
        ``[DONE]`` sentinel or at the end of the stream. The HTTP response is
        always closed, including when the consumer stops iterating early.
        """
        url = f"{self._client.base_url}/v1/chat/completions"
        resp = self._client._session.post(
            url, json=body, stream=True, timeout=self._client.timeout
        )
        try:
            if resp.status_code >= 400:
                resp.close()
                # Replay through the client's normal request path so callers
                # get the same typed error as for non-streaming calls.
                # NOTE(review): this sends the request a second time.
                self._client._request("POST", "/v1/chat/completions", json=body)
                # If the replay did not raise (e.g. a transient 5xx that has
                # since recovered), never go on to parse the error body.
                resp.raise_for_status()

            # SSE streams are UTF-8 by definition. Without a charset, requests
            # assumes ISO-8859-1 for text/* types and would garble non-ASCII.
            resp.encoding = "utf-8"

            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data:"):
                    continue
                # Per the SSE spec, one optional space may follow the colon.
                data = line[len("data:"):].removeprefix(" ")
                if data.strip() == "[DONE]":
                    return
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    # Best-effort: skip malformed events rather than abort.
                    continue
                yield chunk
        finally:
            # Release the pooled connection whether we finished, failed, or
            # the consumer closed the generator early.
            resp.close()

    def models(self) -> list[dict]:
        """List available models."""
        return self._client.get("/v1/models")