inferencesh 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of inferencesh might be problematic.
- inferencesh/__init__.py +37 -1
- inferencesh/client.py +830 -0
- inferencesh/models/__init__.py +29 -0
- inferencesh/models/base.py +94 -0
- inferencesh/models/file.py +252 -0
- inferencesh/models/llm.py +729 -0
- inferencesh/utils/__init__.py +6 -0
- inferencesh/utils/download.py +59 -0
- inferencesh/utils/storage.py +16 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.4.0.dist-info}/METADATA +6 -1
- inferencesh-0.4.0.dist-info/RECORD +15 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.4.0.dist-info}/WHEEL +1 -1
- inferencesh/sdk.py +0 -363
- inferencesh-0.3.0.dist-info/RECORD +0 -8
- {inferencesh-0.3.0.dist-info → inferencesh-0.4.0.dist-info}/entry_points.txt +0 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {inferencesh-0.3.0.dist-info → inferencesh-0.4.0.dist-info}/top_level.txt +0 -0
inferencesh/client.py
ADDED
@@ -0,0 +1,830 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Callable, Generator, Union
from dataclasses import dataclass
from enum import IntEnum
import json
import re
import time
import mimetypes
import os


# Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
def _require_requests():
    try:
        import requests  # type: ignore
        return requests
    except Exception as exc:  # pragma: no cover - dependency hint
        raise RuntimeError(
            "The 'requests' package is required for synchronous HTTP calls. Install with: pip install requests"
        ) from exc


async def _require_aiohttp():
    try:
        import aiohttp  # type: ignore
        return aiohttp
    except Exception as exc:  # pragma: no cover - dependency hint
        raise RuntimeError(
            "The 'aiohttp' package is required for async HTTP calls. Install with: pip install aiohttp"
        ) from exc


class TaskStatus(IntEnum):
    RECEIVED = 1
    QUEUED = 2
    SCHEDULED = 3
    PREPARING = 4
    SERVING = 5
    SETTING_UP = 6
    RUNNING = 7
    UPLOADING = 8
    COMPLETED = 9
    FAILED = 10
    CANCELLED = 11


Base64_RE = re.compile(r"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$")


@dataclass
class UploadFileOptions:
    filename: Optional[str] = None
    content_type: Optional[str] = None
    path: Optional[str] = None
    public: Optional[bool] = None


class StreamManager:
    """Simple SSE stream manager with optional auto-reconnect."""

    def __init__(
        self,
        *,
        create_event_source: Callable[[], Any],
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
        on_error: Optional[Callable[[Exception], None]] = None,
        on_start: Optional[Callable[[], None]] = None,
        on_stop: Optional[Callable[[], None]] = None,
        on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> None:
        self._create_event_source = create_event_source
        self._auto_reconnect = auto_reconnect
        self._max_reconnects = max_reconnects
        self._reconnect_delay_ms = reconnect_delay_ms
        self._on_error = on_error
        self._on_start = on_start
        self._on_stop = on_stop
        self._on_data = on_data

        self._stopped = False
        self._reconnect_attempts = 0
        self._had_successful_connection = False

    def stop(self) -> None:
        self._stopped = True
        if self._on_stop:
            self._on_stop()

    def connect(self) -> None:
        self._stopped = False
        self._reconnect_attempts = 0
        while not self._stopped:
            try:
                if self._on_start:
                    self._on_start()
                event_source = self._create_event_source()
                try:
                    for data in event_source:
                        if self._stopped:
                            break
                        self._had_successful_connection = True
                        if self._on_data:
                            self._on_data(data)
                        # Check again after processing in case on_data stopped us
                        if self._stopped:
                            break
                finally:
                    # Clean up the event source if it has a close method
                    try:
                        if hasattr(event_source, 'close'):
                            event_source.close()
                    except Exception:
                        raise

                # If we're stopped or don't want to auto-reconnect, break immediately
                if self._stopped or not self._auto_reconnect:
                    break
            except Exception as exc:  # noqa: BLE001
                if self._on_error:
                    self._on_error(exc)
                if self._stopped:
                    break
                # If never connected and exceeded attempts, stop
                if not self._had_successful_connection:
                    self._reconnect_attempts += 1
                    if self._reconnect_attempts > self._max_reconnects:
                        break
                time.sleep(self._reconnect_delay_ms / 1000.0)
            else:
                # Completed without exception - if we want to auto-reconnect only after success
                if not self._auto_reconnect:
                    break
                time.sleep(self._reconnect_delay_ms / 1000.0)


class Inference:
    """Synchronous client for inference.sh API, mirroring the JS SDK behavior.

    Args:
        api_key (str): The API key for authentication
        base_url (Optional[str]): Override the default API base URL
        sse_chunk_size (Optional[int]): Chunk size for SSE reading (default: 8192 bytes)
        sse_mode (Optional[str]): SSE reading mode ('iter_lines' or 'raw', default: 'iter_lines')

    The client supports performance tuning for SSE (Server-Sent Events) through:
    1. sse_chunk_size: Controls the buffer size for reading SSE data (default: 8KB)
       - Larger values may improve performance but use more memory
       - Can also be set via INFERENCE_SSE_READ_BYTES environment variable
    2. sse_mode: Controls how SSE data is read ('iter_lines' or 'raw')
       - 'iter_lines': Uses requests' built-in line iteration (default)
       - 'raw': Uses lower-level socket reading
       - Can also be set via INFERENCE_SSE_MODE environment variable
    """

    def __init__(
        self,
        *,
        api_key: str,
        base_url: Optional[str] = None,
        sse_chunk_size: Optional[int] = None,
        sse_mode: Optional[str] = None,
    ) -> None:
        self._api_key = api_key
        self._base_url = base_url or "https://api.inference.sh"

        # SSE configuration with environment variable fallbacks
        self._sse_mode = sse_mode or os.getenv("INFERENCE_SSE_MODE") or "iter_lines"
        self._sse_mode = self._sse_mode.lower()

        # Default to 8KB chunks, can be overridden by parameter or env var
        try:
            env_chunk_size = os.getenv("INFERENCE_SSE_READ_BYTES")
            if sse_chunk_size is not None:
                self._sse_read_bytes = sse_chunk_size
            elif env_chunk_size is not None:
                self._sse_read_bytes = int(env_chunk_size)
            else:
                self._sse_read_bytes = 8192  # 8KB default
        except Exception:
            self._sse_read_bytes = 8192  # Default to 8KB chunks on error

    # --------------- HTTP helpers ---------------
    def _headers(self) -> Dict[str, str]:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
        timeout: Optional[float] = None,
    ) -> Any:
        requests = _require_requests()
        url = f"{self._base_url}{endpoint}"
        merged_headers = {**self._headers(), **(headers or {})}
        resp = requests.request(
            method=method.upper(),
            url=url,
            params=params,
            data=json.dumps(data) if data is not None else None,
            headers=merged_headers,
            stream=stream,
            timeout=timeout or 30,
        )
        if stream:
            return resp
        resp.raise_for_status()
        payload = resp.json()
        if not isinstance(payload, dict) or not payload.get("success", False):
            message = None
            if isinstance(payload, dict) and payload.get("error"):
                err = payload["error"]
                if isinstance(err, dict):
                    message = err.get("message")
                else:
                    message = str(err)
            raise RuntimeError(message or "Request failed")
        return payload.get("data")

    # --------------- Public API ---------------
    def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
        processed_input = self._process_input_data(params.get("input"))
        task = self._request("post", "/run", data={**params, "input": processed_input})
        return task

    def run_sync(
        self,
        params: Dict[str, Any],
        *,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ) -> Dict[str, Any]:
        processed_input = self._process_input_data(params.get("input"))
        task = self._request("post", "/run", data={**params, "input": processed_input})
        task_id = task["id"]

        final_task: Optional[Dict[str, Any]] = None

        def on_data(data: Dict[str, Any]) -> None:
            nonlocal final_task
            try:
                result = _process_stream_event(
                    data,
                    task=task,
                    stopper=lambda: manager.stop(),
                )
                if result is not None:
                    final_task = result
            except Exception as exc:
                raise

        def on_error(exc: Exception) -> None:
            raise exc

        def on_start() -> None:
            pass

        def on_stop() -> None:
            pass

        manager = StreamManager(
            create_event_source=None,  # We'll set this after defining it
            auto_reconnect=auto_reconnect,
            max_reconnects=max_reconnects,
            reconnect_delay_ms=reconnect_delay_ms,
            on_data=on_data,
            on_error=on_error,
            on_start=on_start,
            on_stop=on_stop,
        )

        def create_event_source() -> Generator[Dict[str, Any], None, None]:
            url = f"/tasks/{task_id}/stream"
            resp = self._request(
                "get",
                url,
                headers={
                    "Accept": "text/event-stream",
                    "Cache-Control": "no-cache",
                    "Accept-Encoding": "identity",
                    "Connection": "keep-alive",
                },
                stream=True,
                timeout=60,
            )

            try:
                last_event_at = time.perf_counter()
                for evt in self._iter_sse(resp, stream_manager=manager):
                    yield evt
            finally:
                try:
                    # Force close the underlying socket if possible
                    try:
                        raw = getattr(resp, 'raw', None)
                        if raw is not None:
                            raw.close()
                    except Exception:
                        raise
                    # Close the response
                    resp.close()
                except Exception:
                    raise

        # Update the create_event_source function in the manager
        manager._create_event_source = create_event_source

        # Connect and wait for completion
        manager.connect()

        # At this point, we should have a final task state
        if final_task is not None:
            return final_task

        # Try to fetch the latest state as a fallback
        try:
            latest = self.get_task(task_id)
            status = latest.get("status")
            if status == TaskStatus.COMPLETED:
                return latest
            if status == TaskStatus.FAILED:
                raise RuntimeError(latest.get("error") or "task failed")
            if status == TaskStatus.CANCELLED:
                raise RuntimeError("task cancelled")
        except Exception as exc:
            raise

        raise RuntimeError("Stream ended without completion")

    def cancel(self, task_id: str) -> None:
        self._request("post", f"/tasks/{task_id}/cancel")

    def get_task(self, task_id: str) -> Dict[str, Any]:
        return self._request("get", f"/tasks/{task_id}")

    # --------------- File upload ---------------
    def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
        options = options or UploadFileOptions()
        content_type = options.content_type
        raw_bytes: bytes
        if isinstance(data, bytes):
            raw_bytes = data
            if not content_type:
                content_type = "application/octet-stream"
        else:
            # Prefer local filesystem path if it exists
            if os.path.exists(data):
                path = data
                guessed = mimetypes.guess_type(path)[0]
                content_type = content_type or guessed or "application/octet-stream"
                with open(path, "rb") as f:
                    raw_bytes = f.read()
                if not options.filename:
                    options.filename = os.path.basename(path)
            elif data.startswith("data:"):
                # data URI
                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
                if not match:
                    raise ValueError("Invalid base64 data URI format")
                content_type = content_type or match.group(1)
                raw_bytes = _b64_to_bytes(match.group(2))
            elif _looks_like_base64(data):
                raw_bytes = _b64_to_bytes(data)
                content_type = content_type or "application/octet-stream"
            else:
                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")

        file_req = {
            "files": [
                {
                    "uri": "",
                    "filename": options.filename,
                    "content_type": content_type,
                    "path": options.path,
                    "size": len(raw_bytes),
                    "public": options.public,
                }
            ]
        }

        created = self._request("post", "/files", data=file_req)
        file_obj = created[0]
        upload_url = file_obj.get("upload_url")
        if not upload_url:
            raise RuntimeError("No upload URL provided by the server")

        # Upload to S3 (or compatible) signed URL
        requests = _require_requests()
        put_resp = requests.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type})
        if not (200 <= put_resp.status_code < 300):
            raise RuntimeError(f"Failed to upload file content: {put_resp.reason}")
        return file_obj

    # --------------- Helpers ---------------
    def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
        """Iterate JSON events from an SSE response."""
        # Mode 1: raw socket readline (can reduce buffering in some environments)
        if self._sse_mode == "raw":
            raw = getattr(resp, "raw", None)
            if raw is not None:
                try:
                    # Avoid urllib3 decompression buffering
                    raw.decode_content = False  # type: ignore[attr-defined]
                except Exception:
                    raise
                buf = bytearray()
                read_size = max(1, int(self._sse_read_bytes))
                while True:
                    # Check if we've been asked to stop before reading more data
                    try:
                        if stream_manager and stream_manager._stopped:  # type: ignore[attr-defined]
                            break
                    except Exception:
                        raise

                    chunk = raw.read(read_size)
                    if not chunk:
                        break
                    for b in chunk:
                        if b == 10:  # '\n'
                            try:
                                line = buf.decode(errors="ignore").rstrip("\r")
                            except Exception:
                                line = ""
                            buf.clear()
                            if not line:
                                continue
                            if line.startswith(":"):
                                continue
                            if line.startswith("data:"):
                                data_str = line[5:].lstrip()
                                if not data_str:
                                    continue
                                try:
                                    yield json.loads(data_str)
                                except json.JSONDecodeError:
                                    continue
                        else:
                            buf.append(b)
                return
        # Mode 2: default iter_lines with reasonable chunk size (8KB)
        for line in resp.iter_lines(decode_unicode=True, chunk_size=8192):
            # Check if we've been asked to stop before processing any more lines
            try:
                if stream_manager and stream_manager._stopped:  # type: ignore[attr-defined]
                    break
            except Exception:
                raise

            if not line:
                continue
            if line.startswith(":"):
                continue
            if line.startswith("data:"):
                data_str = line[5:].lstrip()
                if not data_str:
                    continue
                try:
                    yield json.loads(data_str)
                except json.JSONDecodeError:
                    continue

    def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
        if input_value is None:
            return input_value

        # Handle lists
        if isinstance(input_value, list):
            return [self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]

        # Handle dicts
        if isinstance(input_value, dict):
            processed: Dict[str, Any] = {}
            for key, value in input_value.items():
                processed[key] = self._process_input_data(value, f"{path}.{key}")
            return processed

        # Handle strings that are filesystem paths, data URIs, or base64
        if isinstance(input_value, str):
            # Prefer existing local file paths first to avoid misclassifying plain strings
            if os.path.exists(input_value):
                file_obj = self.upload_file(input_value)
                return file_obj.get("uri")
            if input_value.startswith("data:") or _looks_like_base64(input_value):
                file_obj = self.upload_file(input_value)
                return file_obj.get("uri")
            return input_value

        # Handle File-like objects from our models
        try:
            from .models.file import File as SDKFile  # local import to avoid cycle
            if isinstance(input_value, SDKFile):
                # Prefer local path if present, else uri
                src = input_value.path or input_value.uri
                if not src:
                    return input_value
                file_obj = self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
                return file_obj.get("uri")
        except Exception:
            raise

        return input_value


class AsyncInference:
    """Async client for inference.sh API, mirroring the JS SDK behavior."""

    def __init__(self, *, api_key: str, base_url: Optional[str] = None) -> None:
        self._api_key = api_key
        self._base_url = base_url or "https://api.inference.sh"

    # --------------- HTTP helpers ---------------
    def _headers(self) -> Dict[str, str]:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._api_key}",
        }

    async def _request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        expect_stream: bool = False,
    ) -> Any:
        aiohttp = await _require_aiohttp()
        url = f"{self._base_url}{endpoint}"
        merged_headers = {**self._headers(), **(headers or {})}
        timeout_cfg = aiohttp.ClientTimeout(total=timeout or 30)
        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
            async with session.request(
                method=method.upper(),
                url=url,
                params=params,
                json=data,
                headers=merged_headers,
            ) as resp:
                if expect_stream:
                    return resp
                payload = await resp.json()
                if not isinstance(payload, dict) or not payload.get("success", False):
                    message = None
                    if isinstance(payload, dict) and payload.get("error"):
                        err = payload["error"]
                        if isinstance(err, dict):
                            message = err.get("message")
                        else:
                            message = str(err)
                    raise RuntimeError(message or "Request failed")
                return payload.get("data")

    # --------------- Public API ---------------
    async def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
        processed_input = await self._process_input_data(params.get("input"))
        task = await self._request("post", "/run", data={**params, "input": processed_input})
        return task

    async def run_sync(
        self,
        params: Dict[str, Any],
        *,
        auto_reconnect: bool = True,
        max_reconnects: int = 5,
        reconnect_delay_ms: int = 1000,
    ) -> Dict[str, Any]:
        processed_input = await self._process_input_data(params.get("input"))
        task = await self._request("post", "/run", data={**params, "input": processed_input})
        task_id = task["id"]

        final_task: Optional[Dict[str, Any]] = None
        reconnect_attempts = 0
        had_success = False

        while True:
            try:
                resp = await self._request(
                    "get",
                    f"/tasks/{task_id}/stream",
                    headers={
                        "Accept": "text/event-stream",
                        "Cache-Control": "no-cache",
                        "Accept-Encoding": "identity",
                        "Connection": "keep-alive",
                    },
                    timeout=60,
                    expect_stream=True,
                )
                had_success = True
                async for data in self._aiter_sse(resp):
                    result = _process_stream_event(
                        data,
                        task=task,
                        stopper=None,
                    )
                    if result is not None:
                        final_task = result
                        break
                if final_task is not None:
                    break
            except Exception as exc:  # noqa: BLE001
                if not auto_reconnect:
                    raise
                if not had_success:
                    reconnect_attempts += 1
                    if reconnect_attempts > max_reconnects:
                        raise
                await _async_sleep(reconnect_delay_ms / 1000.0)
            else:
                if not auto_reconnect:
                    break
                await _async_sleep(reconnect_delay_ms / 1000.0)

        if final_task is None:
            # Fallback: fetch latest task state in case stream ended without a terminal event
            try:
                latest = await self.get_task(task_id)
                status = latest.get("status")
                if status == TaskStatus.COMPLETED:
                    return latest
                if status == TaskStatus.FAILED:
                    raise RuntimeError(latest.get("error") or "task failed")
                if status == TaskStatus.CANCELLED:
                    raise RuntimeError("task cancelled")
            except Exception:
                raise
            raise RuntimeError("Stream ended without completion")
        return final_task

    async def cancel(self, task_id: str) -> None:
        await self._request("post", f"/tasks/{task_id}/cancel")

    async def get_task(self, task_id: str) -> Dict[str, Any]:
        return await self._request("get", f"/tasks/{task_id}")

    # --------------- File upload ---------------
    async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
        options = options or UploadFileOptions()
        content_type = options.content_type
        raw_bytes: bytes
        if isinstance(data, bytes):
            raw_bytes = data
            if not content_type:
                content_type = "application/octet-stream"
        else:
            if os.path.exists(data):
                path = data
                guessed = mimetypes.guess_type(path)[0]
                content_type = content_type or guessed or "application/octet-stream"
                async with await _aio_open_file(path, "rb") as f:
                    raw_bytes = await f.read()  # type: ignore[attr-defined]
                if not options.filename:
                    options.filename = os.path.basename(path)
            elif data.startswith("data:"):
                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
                if not match:
                    raise ValueError("Invalid base64 data URI format")
                content_type = content_type or match.group(1)
                raw_bytes = _b64_to_bytes(match.group(2))
            elif _looks_like_base64(data):
                raw_bytes = _b64_to_bytes(data)
                content_type = content_type or "application/octet-stream"
            else:
                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")

        file_req = {
            "files": [
                {
                    "uri": "",
                    "filename": options.filename,
                    "content_type": content_type,
                    "path": options.path,
                    "size": len(raw_bytes),
                    "public": options.public,
                }
            ]
        }

        created = await self._request("post", "/files", data=file_req)
        file_obj = created[0]
        upload_url = file_obj.get("upload_url")
        if not upload_url:
            raise RuntimeError("No upload URL provided by the server")

        aiohttp = await _require_aiohttp()
        timeout_cfg = aiohttp.ClientTimeout(total=60)
        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
            async with session.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type}) as resp:
                if resp.status // 100 != 2:
                    raise RuntimeError(f"Failed to upload file content: {resp.reason}")
        return file_obj

    # --------------- Helpers ---------------
    async def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
        if input_value is None:
            return input_value

        if isinstance(input_value, list):
            return [await self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]

        if isinstance(input_value, dict):
            processed: Dict[str, Any] = {}
            for key, value in input_value.items():
                processed[key] = await self._process_input_data(value, f"{path}.{key}")
            return processed

        if isinstance(input_value, str):
            if os.path.exists(input_value):
                file_obj = await self.upload_file(input_value)
                return file_obj.get("uri")
            if input_value.startswith("data:") or _looks_like_base64(input_value):
                file_obj = await self.upload_file(input_value)
                return file_obj.get("uri")
            return input_value

        try:
            from .models.file import File as SDKFile  # local import
            if isinstance(input_value, SDKFile):
                src = input_value.path or input_value.uri
                if not src:
                    return input_value
                file_obj = await self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
                return file_obj.get("uri")
        except Exception:
            raise

        return input_value

    async def _aiter_sse(self, resp: Any) -> Generator[Dict[str, Any], None, None]:
        async for raw_line in resp.content:  # type: ignore[attr-defined]
            try:
                line = raw_line.decode().rstrip("\n")
            except Exception:
                continue
            if not line:
                continue
            if line.startswith(":"):
                continue
            if line.startswith("data:"):
                data_str = line[5:].lstrip()
                if not data_str:
                    continue
                try:
                    yield json.loads(data_str)
                except json.JSONDecodeError:
                    continue


# --------------- small async utilities ---------------
async def _async_sleep(seconds: float) -> None:
    import asyncio

    await asyncio.sleep(seconds)


def _b64_to_bytes(b64: str) -> bytes:
    import base64

    return base64.b64decode(b64)


async def _aio_open_file(path: str, mode: str):
    import aiofiles  # type: ignore

    return await aiofiles.open(path, mode)


def _looks_like_base64(value: str) -> bool:
    # Reject very short strings to avoid matching normal words like "hi"
    if len(value) < 16:
        return False
    # Quick charset check
    if not Base64_RE.match(value):
        return False
    # Must be divisible by 4
    if len(value) % 4 != 0:
        return False
    # Try decode to be sure
    try:
        _ = _b64_to_bytes(value)
        return True
    except Exception:
        return False


def _process_stream_event(
    data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
) -> Optional[Dict[str, Any]]:
    """Shared handler for SSE task events. Returns final task dict when completed, else None.
    If stopper is provided, it will be called on terminal events to end streaming.
    """
    status = data.get("status")
    output = data.get("output")
    logs = data.get("logs")

    if status == TaskStatus.COMPLETED:
        result = {
            **task,
            "status": data.get("status"),
            "output": data.get("output"),
            "logs": data.get("logs") or [],
        }
        if stopper:
            stopper()
        return result
    if status == TaskStatus.FAILED:
        if stopper:
            stopper()
        raise RuntimeError(data.get("error") or "task failed")
    if status == TaskStatus.CANCELLED:
        if stopper:
            stopper()
        raise RuntimeError("task cancelled")
    return None
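For orientation, here is a minimal usage sketch of the new synchronous client. It is not part of the package diff: it assumes the classes above are importable from inferencesh.client, that an INFERENCE_API_KEY environment variable holds a valid key, and that a local file named input.png exists. The "app" field in the /run payload is a hypothetical placeholder, since the exact request shape accepted by POST /run is not shown in this diff.

# Illustrative sketch only; not part of the released package contents above.
import os

from inferencesh.client import Inference, UploadFileOptions

client = Inference(
    api_key=os.environ["INFERENCE_API_KEY"],  # assumed env var holding an API key
    sse_mode="iter_lines",   # or "raw"; also configurable via INFERENCE_SSE_MODE
    sse_chunk_size=8192,     # also configurable via INFERENCE_SSE_READ_BYTES
)

# upload_file accepts bytes, a data URI, a base64 string, or an existing file path;
# it creates the file record via POST /files, then PUTs the bytes to the signed URL.
uploaded = client.upload_file("input.png", UploadFileOptions(public=False))

# run_sync posts to /run and follows the task's SSE stream until a terminal status.
# The "app" key is a hypothetical example field; only "input" handling is visible here.
task = client.run_sync({
    "app": "some-app-id",
    "input": {"image": uploaded["uri"], "prompt": "describe this image"},
})
print(task["status"], task.get("output"))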