inferencesh 0.2.31__py3-none-any.whl → 0.4.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inferencesh/__init__.py +5 -0
- inferencesh/client.py +1081 -0
- inferencesh/models/base.py +81 -3
- inferencesh/models/file.py +120 -21
- inferencesh/models/llm.py +251 -77
- inferencesh/utils/download.py +15 -7
- inferencesh-0.4.29.dist-info/METADATA +196 -0
- inferencesh-0.4.29.dist-info/RECORD +15 -0
- inferencesh-0.2.31.dist-info/METADATA +0 -105
- inferencesh-0.2.31.dist-info/RECORD +0 -14
- {inferencesh-0.2.31.dist-info → inferencesh-0.4.29.dist-info}/WHEEL +0 -0
- {inferencesh-0.2.31.dist-info → inferencesh-0.4.29.dist-info}/entry_points.txt +0 -0
- {inferencesh-0.2.31.dist-info → inferencesh-0.4.29.dist-info}/licenses/LICENSE +0 -0
- {inferencesh-0.2.31.dist-info → inferencesh-0.4.29.dist-info}/top_level.txt +0 -0
inferencesh/client.py
ADDED
@@ -0,0 +1,1081 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Callable, Generator, Union, Iterator
from dataclasses import dataclass
from enum import IntEnum
import json
import re
import time
import mimetypes
import os
from contextlib import AbstractContextManager
from typing import Protocol, runtime_checkable


class TaskStream(AbstractContextManager['TaskStream']):
    """A context manager for streaming task updates.

    This class provides a Pythonic interface for handling streaming updates from a task.
    It can be used either as a context manager or as an iterator.

    Example:
        ```python
        # As a context manager
        with client.stream_task(task_id) as stream:
            for update in stream:
                print(f"Update: {update}")

        # As an iterator
        for update in client.stream_task(task_id):
            print(f"Update: {update}")
        ```
    """
    def __init__(
        self,
        task: Dict[str, Any],
        client: Any,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ):
        self.task = task
        self.client = client
        self.task_id = task["id"]
        self.auto_reconnect = auto_reconnect
        self.max_reconnects = max_reconnects
        self.reconnect_delay_ms = reconnect_delay_ms
        self._final_task: Optional[Dict[str, Any]] = None
        self._error: Optional[Exception] = None

    def __enter__(self) -> 'TaskStream':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self.stream()

    @property
    def result(self) -> Optional[Dict[str, Any]]:
        """The final task result if completed, None otherwise."""
        return self._final_task

    @property
    def error(self) -> Optional[Exception]:
        """The error that occurred during streaming, if any."""
        return self._error

    def stream(self) -> Iterator[Dict[str, Any]]:
        """Stream updates for this task.

        Yields:
            Dict[str, Any]: Task update events

        Raises:
            RuntimeError: If the task fails or is cancelled
        """
        try:
            for update in self.client._stream_updates(
                self.task_id,
                self.task,
            ):
                if isinstance(update, Exception):
                    self._error = update
                    raise update
                if update.get("status") == TaskStatus.COMPLETED:
                    self._final_task = update
                yield update
        except Exception as exc:
            self._error = exc
            raise


@runtime_checkable
class TaskCallback(Protocol):
    """Protocol for task streaming callbacks."""
    def on_update(self, data: Dict[str, Any]) -> None:
        """Called when a task update is received."""
        ...

    def on_error(self, error: Exception) -> None:
        """Called when an error occurs during task execution."""
        ...

    def on_complete(self, task: Dict[str, Any]) -> None:
        """Called when a task completes successfully."""
        ...
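
Because `result` and `error` are populated as the stream is consumed, a caller can iterate and then inspect the outcome without keeping its own state. A minimal usage sketch (the API key and task id are hypothetical placeholders):

```python
from inferencesh.client import Inference

client = Inference(api_key="sk-...")            # hypothetical key

with client.stream_task("task_123") as stream:  # hypothetical task id
    try:
        for update in stream:
            print("status:", update.get("status"))
    except RuntimeError as exc:
        # FAILED / CANCELLED events surface as RuntimeError;
        # the same exception is retained on stream.error.
        print("stream error:", stream.error or exc)
    else:
        # Set only if a COMPLETED event was observed.
        print("output:", (stream.result or {}).get("output"))
```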

# Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
def _require_requests():
    try:
        import requests  # type: ignore
        return requests
    except Exception as exc:  # pragma: no cover - dependency hint
        raise RuntimeError(
            "The 'requests' package is required for synchronous HTTP calls. Install with: pip install requests"
        ) from exc


async def _require_aiohttp():
    try:
        import aiohttp  # type: ignore
        return aiohttp
    except Exception as exc:  # pragma: no cover - dependency hint
        raise RuntimeError(
            "The 'aiohttp' package is required for async HTTP calls. Install with: pip install aiohttp"
        ) from exc


class TaskStatus(IntEnum):
    RECEIVED = 1
    QUEUED = 2
    SCHEDULED = 3
    PREPARING = 4
    SERVING = 5
    SETTING_UP = 6
    RUNNING = 7
    UPLOADING = 8
    COMPLETED = 9
    FAILED = 10
    CANCELLED = 11


Base64_RE = re.compile(r"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$")


@dataclass
class UploadFileOptions:
    filename: Optional[str] = None
    content_type: Optional[str] = None
    path: Optional[str] = None
    public: Optional[bool] = None
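
Since `TaskStatus` subclasses `IntEnum`, the bare integers that arrive in task payloads compare (and hash) equal to the enum members, which is what lets the status checks later in this file compare raw JSON values against the enum directly. A small illustration:

```python
from inferencesh.client import TaskStatus

TERMINAL = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}

def is_terminal(update: dict) -> bool:
    # SSE payloads carry plain ints (e.g. 9); IntEnum makes 9 == TaskStatus.COMPLETED.
    return update.get("status") in TERMINAL

assert TaskStatus.COMPLETED == 9
assert is_terminal({"status": 10})  # FAILED arrives as a bare int
```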

class StreamManager:
    """Simple SSE stream manager with optional auto-reconnect."""

    def __init__(
        self,
        *,
        create_event_source: Callable[[], Any],
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
        on_error: Optional[Callable[[Exception], None]] = None,
        on_start: Optional[Callable[[], None]] = None,
        on_stop: Optional[Callable[[], None]] = None,
        on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
        on_partial_data: Optional[Callable[[Dict[str, Any], list[str]], None]] = None,
    ) -> None:
        self._create_event_source = create_event_source
        self._auto_reconnect = auto_reconnect
        self._max_reconnects = max_reconnects
        self._reconnect_delay_ms = reconnect_delay_ms
        self._on_error = on_error
        self._on_start = on_start
        self._on_stop = on_stop
        self._on_data = on_data
        self._on_partial_data = on_partial_data

        self._stopped = False
        self._reconnect_attempts = 0
        self._had_successful_connection = False

    def stop(self) -> None:
        self._stopped = True
        if self._on_stop:
            self._on_stop()

    def connect(self) -> None:
        self._stopped = False
        self._reconnect_attempts = 0
        while not self._stopped:
            try:
                if self._on_start:
                    self._on_start()
                event_source = self._create_event_source()
                try:
                    for data in event_source:
                        if self._stopped:
                            break
                        self._had_successful_connection = True

                        # Handle generic messages through the on_data callback.
                        # Try parsing as a {data: T, fields: []} structure first.
                        if (
                            isinstance(data, dict)
                            and "data" in data
                            and "fields" in data
                            and isinstance(data.get("fields"), list)
                        ):
                            # Partial data structure detected
                            if self._on_partial_data:
                                self._on_partial_data(data["data"], data["fields"])
                            elif self._on_data:
                                # Fall back to on_data with just the data if on_partial_data not provided
                                self._on_data(data["data"])
                        elif self._on_data:
                            # Otherwise treat the whole thing as data
                            self._on_data(data)

                        # Check again after processing in case callbacks stopped us
                        if self._stopped:
                            break
                finally:
                    # Clean up the event source if it has a close method
                    try:
                        if hasattr(event_source, 'close'):
                            event_source.close()
                    except Exception:
                        pass  # best-effort cleanup; do not mask an in-flight error

                # If we're stopped or don't want to auto-reconnect, break immediately
                if self._stopped or not self._auto_reconnect:
                    break
            except Exception as exc:  # noqa: BLE001
                if self._on_error:
                    self._on_error(exc)
                if self._stopped:
                    break
                # If we never connected and exceeded the attempt budget, stop
                if not self._had_successful_connection:
                    self._reconnect_attempts += 1
                    if self._reconnect_attempts > self._max_reconnects:
                        break
                time.sleep(self._reconnect_delay_ms / 1000.0)
            else:
                # Completed without exception - reconnect only if auto-reconnect is enabled
                if not self._auto_reconnect:
                    break
                time.sleep(self._reconnect_delay_ms / 1000.0)
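
`StreamManager` is agnostic about its transport: `create_event_source` may return any iterable of parsed JSON dicts, and a fresh source is created on every (re)connect. A wiring sketch with a hypothetical in-memory source (no network involved):

```python
from inferencesh.client import StreamManager

def fake_event_source():
    # Hypothetical stand-in for a real SSE connection.
    yield {"status": 7}
    yield {"data": {"status": 9}, "fields": ["status"]}  # partial-update envelope

manager = StreamManager(
    create_event_source=fake_event_source,
    auto_reconnect=False,                    # one pass, no retry loop
    on_data=lambda evt: print("full:", evt),
    on_partial_data=lambda data, fields: print("partial:", data, fields),
)
manager.connect()  # blocks until the source is exhausted or stop() is called
```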

class Inference:
    """Synchronous client for inference.sh API, mirroring the JS SDK behavior.

    Args:
        api_key (str): The API key for authentication
        base_url (Optional[str]): Override the default API base URL
        sse_chunk_size (Optional[int]): Chunk size for SSE reading (default: 8192 bytes)
        sse_mode (Optional[str]): SSE reading mode ('iter_lines' or 'raw', default: 'iter_lines')

    The client supports performance tuning for SSE (Server-Sent Events) through:
    1. sse_chunk_size: Controls the buffer size for reading SSE data (default: 8KB)
       - Larger values may improve performance but use more memory
       - Can also be set via the INFERENCE_SSE_READ_BYTES environment variable
    2. sse_mode: Controls how SSE data is read ('iter_lines' or 'raw')
       - 'iter_lines': Uses requests' built-in line iteration (default)
       - 'raw': Uses lower-level socket reading
       - Can also be set via the INFERENCE_SSE_MODE environment variable
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: Optional[str] = None,
        sse_chunk_size: Optional[int] = None,
        sse_mode: Optional[str] = None,
    ) -> None:
        self._api_key = api_key
        self._base_url = base_url or "https://api.inference.sh"

        # SSE configuration with environment variable fallbacks
        self._sse_mode = sse_mode or os.getenv("INFERENCE_SSE_MODE") or "iter_lines"
        self._sse_mode = self._sse_mode.lower()

        # Default to 8KB chunks; can be overridden by parameter or env var
        try:
            env_chunk_size = os.getenv("INFERENCE_SSE_READ_BYTES")
            if sse_chunk_size is not None:
                self._sse_read_bytes = sse_chunk_size
            elif env_chunk_size is not None:
                self._sse_read_bytes = int(env_chunk_size)
            else:
                self._sse_read_bytes = 8192  # 8KB default
        except Exception:
            self._sse_read_bytes = 8192  # Default to 8KB chunks on error

    # --------------- HTTP helpers ---------------
    def _headers(self) -> Dict[str, str]:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
        timeout: Optional[float] = None,
    ) -> Any:
        requests = _require_requests()
        url = f"{self._base_url}{endpoint}"
        merged_headers = {**self._headers(), **(headers or {})}
        resp = requests.request(
            method=method.upper(),
            url=url,
            params=params,
            data=json.dumps(data) if data is not None else None,
            headers=merged_headers,
            stream=stream,
            timeout=timeout or 30,
        )
        if stream:
            return resp
        resp.raise_for_status()
        payload = resp.json()
        if not isinstance(payload, dict) or not payload.get("success", False):
            message = None
            if isinstance(payload, dict) and payload.get("error"):
                err = payload["error"]
                if isinstance(err, dict):
                    message = err.get("message")
                else:
                    message = str(err)
            raise RuntimeError(message or "Request failed")
        return payload.get("data")
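
`_request` assumes every JSON endpoint wraps its reply in a success envelope and unwraps it before returning, so the rest of the client only ever sees the `data` payload. Illustratively (field values hypothetical):

```python
# Unwrapped: _request(...) returns {"id": "task_123", "status": 2}
ok = {"success": True, "data": {"id": "task_123", "status": 2}}

# Raises RuntimeError("quota exceeded"): the message is pulled out of the error object
err = {"success": False, "error": {"message": "quota exceeded"}}
```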

    # --------------- Public API ---------------
    def run(
        self,
        params: Dict[str, Any],
        *,
        wait: bool = True,
        stream: bool = False,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
        """Run a task with optional streaming updates.

        By default, this method waits for the task to complete and returns the final result.
        You can set wait=False to get just the task info, or stream=True to get an iterator
        of status updates.

        Args:
            params: Task parameters to pass to the API
            wait: Whether to wait for task completion (default: True)
            stream: Whether to return an iterator of updates (default: False)
            auto_reconnect: Whether to automatically reconnect on connection loss
            max_reconnects: Maximum number of reconnection attempts
            reconnect_delay_ms: Delay between reconnection attempts in milliseconds

        Returns:
            Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
                - If wait=True and stream=False: The completed task data
                - If wait=False: The created task info
                - If stream=True: An iterator of task updates

        Example:
            ```python
            # Simple usage - wait for result (default)
            result = client.run(params)
            print(f"Output: {result['output']}")

            # Get task info without waiting
            task = client.run(params, wait=False)
            task_id = task["id"]

            # Stream updates
            stream = client.run(params, stream=True)
            for update in stream:
                print(f"Status: {update.get('status')}")
                if update.get('status') == TaskStatus.COMPLETED:
                    print(f"Result: {update.get('output')}")
            ```
        """
        # Create the task
        processed_input = self._process_input_data(params.get("input"))
        task = self._request("post", "/run", data={**params, "input": processed_input})

        # Return immediately if not waiting
        if not wait and not stream:
            return _strip_task(task)

        # Return stream if requested
        if stream:
            task_stream = TaskStream(
                task=task,
                client=self,
                auto_reconnect=auto_reconnect,
                max_reconnects=max_reconnects,
                reconnect_delay_ms=reconnect_delay_ms,
            )
            return task_stream

        # Otherwise wait for completion
        return self.wait_for_completion(task["id"])

    def cancel(self, task_id: str) -> None:
        self._request("post", f"/tasks/{task_id}/cancel")

    def get_task(self, task_id: str) -> Dict[str, Any]:
        """Get the current state of a task.

        Args:
            task_id: The ID of the task to get

        Returns:
            Dict[str, Any]: The current task state
        """
        return self._request("get", f"/tasks/{task_id}")

    def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
        """Wait for a task to complete and return its final state.

        This method follows the task's update stream until it reaches a terminal
        state (completed, failed, or cancelled).

        Args:
            task_id: The ID of the task to wait for

        Returns:
            Dict[str, Any]: The final task state

        Raises:
            RuntimeError: If the task fails or is cancelled
        """
        with self.stream_task(task_id) as stream:
            for update in stream:
                if update.get("status") == TaskStatus.COMPLETED:
                    return update
                elif update.get("status") == TaskStatus.FAILED:
                    raise RuntimeError(update.get("error") or "Task failed")
                elif update.get("status") == TaskStatus.CANCELLED:
                    raise RuntimeError("Task cancelled")
        raise RuntimeError("Stream ended without completion")

    # --------------- File upload ---------------
    def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
        options = options or UploadFileOptions()
        content_type = options.content_type
        raw_bytes: bytes
        if isinstance(data, bytes):
            raw_bytes = data
            if not content_type:
                content_type = "application/octet-stream"
        else:
            # Prefer a local filesystem path if it exists
            if os.path.exists(data):
                path = data
                guessed = mimetypes.guess_type(path)[0]
                content_type = content_type or guessed or "application/octet-stream"
                with open(path, "rb") as f:
                    raw_bytes = f.read()
                if not options.filename:
                    options.filename = os.path.basename(path)
            elif data.startswith("data:"):
                # data URI
                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
                if not match:
                    raise ValueError("Invalid base64 data URI format")
                content_type = content_type or match.group(1)
                raw_bytes = _b64_to_bytes(match.group(2))
            elif _looks_like_base64(data):
                raw_bytes = _b64_to_bytes(data)
                content_type = content_type or "application/octet-stream"
            else:
                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")

        file_req = {
            "files": [
                {
                    "uri": "",
                    "filename": options.filename,
                    "content_type": content_type,
                    "path": options.path,
                    "size": len(raw_bytes),
                    "public": options.public,
                }
            ]
        }

        created = self._request("post", "/files", data=file_req)
        file_obj = created[0]
        upload_url = file_obj.get("upload_url")
        if not upload_url:
            raise RuntimeError("No upload URL provided by the server")

        # Upload to the S3 (or compatible) signed URL
        requests = _require_requests()
        put_resp = requests.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type})
        if not (200 <= put_resp.status_code < 300):
            raise RuntimeError(f"Failed to upload file content: {put_resp.reason}")
        return file_obj
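
The upload is a two-step handshake: `POST /files` registers the metadata and returns a file record whose `upload_url` is a pre-signed PUT target, and the raw bytes then go straight to that URL rather than through the API. A usage sketch, assuming a local file and hypothetical values:

```python
from inferencesh.client import Inference, UploadFileOptions

client = Inference(api_key="sk-...")  # hypothetical key

file_obj = client.upload_file(
    "./photo.png",                    # an existing path wins over base64 detection
    UploadFileOptions(content_type="image/png", public=True),
)
print(file_obj.get("uri"))            # the URI later substituted into task inputs
```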

    # --------------- Helpers ---------------
    def stream_task(
        self,
        task_id: str,
        *,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ) -> TaskStream:
        """Create a TaskStream for getting streaming updates from a task.

        This provides a more Pythonic interface for handling task updates compared to callbacks.
        The returned TaskStream can be used either as a context manager or as an iterator.

        Args:
            task_id: The ID of the task to stream
            auto_reconnect: Whether to automatically reconnect on connection loss
            max_reconnects: Maximum number of reconnection attempts
            reconnect_delay_ms: Delay between reconnection attempts in milliseconds

        Returns:
            TaskStream: A stream interface for the task

        Example:
            ```python
            # Create a task without waiting for it
            task = client.run(params, wait=False)

            # Stream updates using a context manager
            with client.stream_task(task["id"]) as stream:
                for update in stream:
                    print(f"Status: {update.get('status')}")
                    if update.get("status") == TaskStatus.COMPLETED:
                        print(f"Result: {update.get('output')}")

            # Or use as a simple iterator
            for update in client.stream_task(task["id"]):
                print(f"Update: {update}")
            ```
        """
        task = self.get_task(task_id)
        return TaskStream(
            task=task,
            client=self,
            auto_reconnect=auto_reconnect,
            max_reconnects=max_reconnects,
            reconnect_delay_ms=reconnect_delay_ms,
        )

    def _stream_updates(
        self,
        task_id: str,
        task: Dict[str, Any],
    ) -> Generator[Union[Dict[str, Any], Exception], None, None]:
        """Internal method to stream task updates."""
        url = f"/tasks/{task_id}/stream"
        resp = self._request(
            "get",
            url,
            headers={
                "Accept": "text/event-stream",
                "Cache-Control": "no-cache",
                "Accept-Encoding": "identity",
                "Connection": "keep-alive",
            },
            stream=True,
            timeout=60,
        )
        try:
            for evt in self._iter_sse(resp):
                try:
                    # Handle generic messages - try parsing as a {data: T, fields: []} structure first
                    if (
                        isinstance(evt, dict)
                        and "data" in evt
                        and "fields" in evt
                        and isinstance(evt.get("fields"), list)
                    ):
                        # Partial data structure detected - extract just the data part
                        evt = evt["data"]

                    # Process the event to check for completion/errors
                    result = _process_stream_event(
                        evt,
                        task=task,
                        stopper=None,  # stopping is handled by the iterator
                    )
                    if result is not None:
                        yield result
                        break
                    yield _strip_task(evt)
                except Exception as exc:
                    yield exc
                    raise
        finally:
            # Best-effort cleanup: force close the underlying socket, then the response
            try:
                raw = getattr(resp, 'raw', None)
                if raw is not None:
                    raw.close()
            except Exception:
                pass
            try:
                resp.close()
            except Exception:
                pass

    def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
        """Iterate JSON events from an SSE response."""
        # Mode 1: raw socket readline (can reduce buffering in some environments)
        if self._sse_mode == "raw":
            raw = getattr(resp, "raw", None)
            if raw is not None:
                try:
                    # Avoid urllib3 decompression buffering
                    raw.decode_content = False  # type: ignore[attr-defined]
                except Exception:
                    pass  # best-effort; not every transport exposes this attribute
                buf = bytearray()
                read_size = max(1, int(self._sse_read_bytes))
                while True:
                    # Check if we've been asked to stop before reading more data
                    if stream_manager is not None and getattr(stream_manager, "_stopped", False):
                        break

                    chunk = raw.read(read_size)
                    if not chunk:
                        break
                    for b in chunk:
                        if b == 10:  # '\n'
                            try:
                                line = buf.decode(errors="ignore").rstrip("\r")
                            except Exception:
                                line = ""
                            buf.clear()
                            if not line:
                                continue
                            if line.startswith(":"):
                                continue
                            if line.startswith("data:"):
                                data_str = line[5:].lstrip()
                                if not data_str:
                                    continue
                                try:
                                    yield json.loads(data_str)
                                except json.JSONDecodeError:
                                    continue
                        else:
                            buf.append(b)
                return
        # Mode 2: default iter_lines with a reasonable chunk size (8KB)
        for line in resp.iter_lines(decode_unicode=True, chunk_size=8192):
            # Check if we've been asked to stop before processing any more lines
            if stream_manager is not None and getattr(stream_manager, "_stopped", False):
                break

            if not line:
                continue
            if line.startswith(":"):
                continue
            if line.startswith("data:"):
                data_str = line[5:].lstrip()
                if not data_str:
                    continue
                try:
                    yield json.loads(data_str)
                except json.JSONDecodeError:
                    continue
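
Both reading modes reduce to the same SSE line protocol: lines beginning with `:` are comments, `data:` lines carry one JSON event each, and anything unparsable is skipped. Since the parser never issues requests itself, it can be exercised against a canned response object (a hypothetical stand-in for `requests.Response`):

```python
from inferencesh.client import Inference

class FakeResponse:
    """Stand-in carrying canned SSE lines."""
    def iter_lines(self, decode_unicode=True, chunk_size=8192):
        yield ": keep-alive"            # comment -> skipped
        yield ""                        # event separator -> skipped
        yield 'data: {"status": 7}'     # parsed and yielded
        yield "data: not-json"          # JSONDecodeError -> skipped

client = Inference(api_key="sk-...")    # hypothetical key; nothing is sent
print(list(client._iter_sse(FakeResponse())))  # [{'status': 7}]
```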

    def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
        if input_value is None:
            return input_value

        # Handle lists
        if isinstance(input_value, list):
            return [self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]

        # Handle dicts
        if isinstance(input_value, dict):
            processed: Dict[str, Any] = {}
            for key, value in input_value.items():
                processed[key] = self._process_input_data(value, f"{path}.{key}")
            return processed

        # Handle strings that are filesystem paths, data URIs, or base64
        if isinstance(input_value, str):
            # Prefer existing local file paths first to avoid misclassifying plain strings
            if os.path.exists(input_value):
                file_obj = self.upload_file(input_value)
                return file_obj.get("uri")
            if input_value.startswith("data:") or _looks_like_base64(input_value):
                file_obj = self.upload_file(input_value)
                return file_obj.get("uri")
            return input_value

        # Handle File-like objects from our models
        try:
            from .models.file import File as SDKFile  # local import to avoid a cycle
            if isinstance(input_value, SDKFile):
                # Prefer the local path if present, else the uri
                src = input_value.path or input_value.uri
                if not src:
                    return input_value
                file_obj = self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
                return file_obj.get("uri")
        except Exception:
            pass  # fall through and return the value unchanged if the model import fails

        return input_value
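
The string branch above is deliberately order-sensitive: the existence check runs before base64 detection so that a file whose name happens to be valid base64 is still treated as a path, and `_looks_like_base64` itself applies a length gate so ordinary words pass through untouched. A few illustrative checks against the helper (defined near the end of this file):

```python
from inferencesh.client import _looks_like_base64

assert not _looks_like_base64("hi")                    # too short (< 16 chars)
assert not _looks_like_base64("hello world plus!")     # space/! fail the charset check
assert _looks_like_base64("SGVsbG8sIGluZmVyZW5jZSE=")  # decodes; would be uploaded
```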

class AsyncInference:
    """Async client for inference.sh API, mirroring the JS SDK behavior."""

    def __init__(self, *, api_key: str, base_url: Optional[str] = None) -> None:
        self._api_key = api_key
        self._base_url = base_url or "https://api.inference.sh"

    # --------------- HTTP helpers ---------------
    def _headers(self) -> Dict[str, str]:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }

    async def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        expect_stream: bool = False,
    ) -> Any:
        aiohttp = await _require_aiohttp()
        url = f"{self._base_url}{endpoint}"
        merged_headers = {**self._headers(), **(headers or {})}
        timeout_cfg = aiohttp.ClientTimeout(total=timeout or 30)
        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
            async with session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=data,
                headers=merged_headers,
            ) as resp:
                if expect_stream:
                    return resp
                payload = await resp.json()
                if not isinstance(payload, dict) or not payload.get("success", False):
                    message = None
                    if isinstance(payload, dict) and payload.get("error"):
                        err = payload["error"]
                        if isinstance(err, dict):
                            message = err.get("message")
                        else:
                            message = str(err)
                    raise RuntimeError(message or "Request failed")
                return payload.get("data")

    # --------------- Public API ---------------
    async def run(
        self,
        params: Dict[str, Any],
        *,
        wait: bool = True,
        stream: bool = False,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
        """Run a task with optional streaming updates.

        By default, this method waits for the task to complete and returns the final result.
        You can set wait=False to get just the task info, or stream=True to get an iterator
        of status updates.

        Args:
            params: Task parameters to pass to the API
            wait: Whether to wait for task completion (default: True)
            stream: Whether to return an iterator of updates (default: False)
            auto_reconnect: Whether to automatically reconnect on connection loss
            max_reconnects: Maximum number of reconnection attempts
            reconnect_delay_ms: Delay between reconnection attempts in milliseconds

        Returns:
            Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
                - If wait=True and stream=False: The completed task data
                - If wait=False: The created task info
                - If stream=True: An iterator of task updates

        Example:
            ```python
            # Simple usage - wait for result (default)
            result = await client.run(params)
            print(f"Output: {result['output']}")

            # Get task info without waiting
            task = await client.run(params, wait=False)
            task_id = task["id"]

            # Stream updates
            stream = await client.run(params, stream=True)
            for update in stream:
                print(f"Status: {update.get('status')}")
                if update.get('status') == TaskStatus.COMPLETED:
                    print(f"Result: {update.get('output')}")
            ```
        """
        # Create the task
        processed_input = await self._process_input_data(params.get("input"))
        task = await self._request("post", "/run", data={**params, "input": processed_input})

        # Return immediately if not waiting
        if not wait and not stream:
            return task

        # Return stream if requested
        if stream:
            task_stream = TaskStream(
                task=task,
                client=self,
                auto_reconnect=auto_reconnect,
                max_reconnects=max_reconnects,
                reconnect_delay_ms=reconnect_delay_ms,
            )
            return task_stream

        # Otherwise wait for completion
        return await self.wait_for_completion(task["id"])

    async def cancel(self, task_id: str) -> None:
        await self._request("post", f"/tasks/{task_id}/cancel")

    async def get_task(self, task_id: str) -> Dict[str, Any]:
        """Get the current state of a task.

        Args:
            task_id: The ID of the task to get

        Returns:
            Dict[str, Any]: The current task state
        """
        return await self._request("get", f"/tasks/{task_id}")

    async def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
        """Wait for a task to complete and return its final state.

        This method follows the task's update stream until it reaches a terminal
        state (completed, failed, or cancelled).

        Args:
            task_id: The ID of the task to wait for

        Returns:
            Dict[str, Any]: The final task state

        Raises:
            RuntimeError: If the task fails or is cancelled
        """
        # NOTE: this relies on stream_task, which only Inference defines in this
        # module; AsyncInference has no such method here.
        with self.stream_task(task_id) as stream:
            async for update in stream:
                if update.get("status") == TaskStatus.COMPLETED:
                    return update
                elif update.get("status") == TaskStatus.FAILED:
                    raise RuntimeError(update.get("error") or "Task failed")
                elif update.get("status") == TaskStatus.CANCELLED:
                    raise RuntimeError("Task cancelled")
        raise RuntimeError("Stream ended without completion")

    # --------------- File upload ---------------
    async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
        options = options or UploadFileOptions()
        content_type = options.content_type
        raw_bytes: bytes
        if isinstance(data, bytes):
            raw_bytes = data
            if not content_type:
                content_type = "application/octet-stream"
        else:
            if os.path.exists(data):
                path = data
                guessed = mimetypes.guess_type(path)[0]
                content_type = content_type or guessed or "application/octet-stream"
                async with await _aio_open_file(path, "rb") as f:
                    raw_bytes = await f.read()  # type: ignore[attr-defined]
                if not options.filename:
                    options.filename = os.path.basename(path)
            elif data.startswith("data:"):
                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
                if not match:
                    raise ValueError("Invalid base64 data URI format")
                content_type = content_type or match.group(1)
                raw_bytes = _b64_to_bytes(match.group(2))
            elif _looks_like_base64(data):
                raw_bytes = _b64_to_bytes(data)
                content_type = content_type or "application/octet-stream"
            else:
                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")

        file_req = {
            "files": [
                {
                    "uri": "",
                    "filename": options.filename,
                    "content_type": content_type,
                    "path": options.path,
                    "size": len(raw_bytes),
                    "public": options.public,
                }
            ]
        }

        created = await self._request("post", "/files", data=file_req)
        file_obj = created[0]
        upload_url = file_obj.get("upload_url")
        if not upload_url:
            raise RuntimeError("No upload URL provided by the server")

        aiohttp = await _require_aiohttp()
        timeout_cfg = aiohttp.ClientTimeout(total=60)
        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
            async with session.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type}) as resp:
                if resp.status // 100 != 2:
                    raise RuntimeError(f"Failed to upload file content: {resp.reason}")
        return file_obj

    # --------------- Helpers ---------------
    async def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
        if input_value is None:
            return input_value

        if isinstance(input_value, list):
            return [await self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]

        if isinstance(input_value, dict):
            processed: Dict[str, Any] = {}
            for key, value in input_value.items():
                processed[key] = await self._process_input_data(value, f"{path}.{key}")
            return processed

        if isinstance(input_value, str):
            if os.path.exists(input_value):
                file_obj = await self.upload_file(input_value)
                return file_obj.get("uri")
            if input_value.startswith("data:") or _looks_like_base64(input_value):
                file_obj = await self.upload_file(input_value)
                return file_obj.get("uri")
            return input_value

        try:
            from .models.file import File as SDKFile  # local import
            if isinstance(input_value, SDKFile):
                src = input_value.path or input_value.uri
                if not src:
                    return input_value
                file_obj = await self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
                return file_obj.get("uri")
        except Exception:
            pass  # fall through and return the value unchanged if the model import fails

        return input_value

    async def _aiter_sse(self, resp: Any) -> Generator[Dict[str, Any], None, None]:
        async for raw_line in resp.content:  # type: ignore[attr-defined]
            try:
                line = raw_line.decode().rstrip("\n")
            except Exception:
                continue
            if not line:
                continue
            if line.startswith(":"):
                continue
            if line.startswith("data:"):
                data_str = line[5:].lstrip()
                if not data_str:
                    continue
                try:
                    yield json.loads(data_str)
                except json.JSONDecodeError:
                    continue


# --------------- small async utilities ---------------
async def _async_sleep(seconds: float) -> None:
    import asyncio

    await asyncio.sleep(seconds)


def _b64_to_bytes(b64: str) -> bytes:
    import base64

    return base64.b64decode(b64)


async def _aio_open_file(path: str, mode: str):
    import aiofiles  # type: ignore

    return await aiofiles.open(path, mode)


def _looks_like_base64(value: str) -> bool:
    # Reject very short strings to avoid matching normal words like "hi"
    if len(value) < 16:
        return False
    # Quick charset check
    if not Base64_RE.match(value):
        return False
    # Length must be divisible by 4
    if len(value) % 4 != 0:
        return False
    # Try decoding to be sure
    try:
        _ = _b64_to_bytes(value)
        return True
    except Exception:
        return False


def _strip_task(task: Dict[str, Any]) -> Dict[str, Any]:
    """Strip a task down to its essential fields."""
    return {
        "id": task.get("id"),
        "created_at": task.get("created_at"),
        "updated_at": task.get("updated_at"),
        "input": task.get("input"),
        "output": task.get("output"),
        "logs": task.get("logs"),
        "status": task.get("status"),
    }


def _process_stream_event(
    data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
) -> Optional[Dict[str, Any]]:
    """Shared handler for SSE task events. Returns the final task dict when completed, else None.

    If a stopper is provided, it is called on terminal events to end streaming.
    """
    status = data.get("status")

    if status == TaskStatus.COMPLETED:
        result = _strip_task(data)
        if stopper:
            stopper()
        return result
    if status == TaskStatus.FAILED:
        if stopper:
            stopper()
        raise RuntimeError(data.get("error") or "task failed")
    if status == TaskStatus.CANCELLED:
        if stopper:
            stopper()
        raise RuntimeError("task cancelled")
    return None
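
Putting the pieces together, a typical synchronous round trip looks like the sketch below; the `app` identifier and input schema are hypothetical, while the `run`/streaming semantics are as implemented above:

```python
from inferencesh.client import Inference, TaskStatus

client = Inference(api_key="sk-...")  # hypothetical key

params = {"app": "some-app", "input": {"prompt": "a red fox"}}  # hypothetical schema

# Blocking call: local paths / base64 strings in `input` are uploaded first,
# then the client follows the task's SSE stream to a terminal state.
result = client.run(params)
print(result.get("output"))

# Streaming variant: observe intermediate statuses as they arrive.
for update in client.run(params, stream=True):
    print(update.get("status"))
    if update.get("status") == TaskStatus.COMPLETED:
        print(update.get("output"))
```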