inferencesh 0.2.37__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of inferencesh might be problematic.
- {inferencesh-0.2.37/src/inferencesh.egg-info → inferencesh-0.3.1}/PKG-INFO +5 -1
- {inferencesh-0.2.37 → inferencesh-0.3.1}/pyproject.toml +7 -1
- {inferencesh-0.2.37 → inferencesh-0.3.1}/setup.py +2 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/__init__.py +5 -0
- inferencesh-0.3.1/src/inferencesh/client.py +830 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/models/file.py +27 -3
- {inferencesh-0.2.37 → inferencesh-0.3.1/src/inferencesh.egg-info}/PKG-INFO +5 -1
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh.egg-info/SOURCES.txt +2 -0
- inferencesh-0.3.1/src/inferencesh.egg-info/requires.txt +13 -0
- inferencesh-0.3.1/tests/test_client.py +159 -0
- inferencesh-0.2.37/src/inferencesh.egg-info/requires.txt +0 -6
- {inferencesh-0.2.37 → inferencesh-0.3.1}/LICENSE +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/README.md +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/setup.cfg +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/models/__init__.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/models/base.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/models/llm.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/utils/__init__.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/utils/download.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/utils/storage.py +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh.egg-info/dependency_links.txt +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh.egg-info/entry_points.txt +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh.egg-info/top_level.txt +0 -0
- {inferencesh-0.2.37 → inferencesh-0.3.1}/tests/test_sdk.py +0 -0

{inferencesh-0.2.37/src/inferencesh.egg-info → inferencesh-0.3.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: inferencesh
-Version: 0.2.37
+Version: 0.3.1
 Summary: inference.sh Python SDK
 Author: Inference Shell Inc.
 Author-email: "Inference Shell Inc." <hello@inference.sh>
@@ -14,9 +14,13 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pydantic>=2.0.0
 Requires-Dist: tqdm>=4.67.0
+Requires-Dist: requests>=2.31.0
 Provides-Extra: test
 Requires-Dist: pytest>=7.0.0; extra == "test"
 Requires-Dist: pytest-cov>=4.0.0; extra == "test"
+Provides-Extra: async
+Requires-Dist: aiohttp>=3.9.0; python_version >= "3.8" and extra == "async"
+Requires-Dist: aiofiles>=23.2.1; python_version >= "3.8" and extra == "async"
 Dynamic: author
 Dynamic: license-file
 Dynamic: requires-python

{inferencesh-0.2.37 → inferencesh-0.3.1}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "inferencesh"
-version = "0.2.37"
+version = "0.3.1"
 description = "inference.sh Python SDK"
 authors = [
     {name = "Inference Shell Inc.", email = "hello@inference.sh"},
@@ -19,6 +19,8 @@ classifiers = [
 dependencies = [
     "pydantic>=2.0.0",
     "tqdm>=4.67.0",
+    # Required for the synchronous client and examples
+    "requests>=2.31.0",
 ]
 
 [project.urls]
@@ -42,3 +44,7 @@ test = [
     "pytest>=7.0.0",
     "pytest-cov>=4.0.0",
 ]
+async = [
+    "aiohttp>=3.9.0; python_version >= '3.8'",
+    "aiofiles>=23.2.1; python_version >= '3.8'",
+]
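
With the dependency changes above, requests becomes a hard requirement of the base package while aiohttp and aiofiles sit behind an optional "async" extra. A minimal sketch of how this surfaces for users (standard pip extras syntax shown in comments; exact environment handling is up to the reader):

    # pip install inferencesh             -> sync client only (pulls in requests)
    # pip install "inferencesh[async]"    -> also installs aiohttp and aiofiles for AsyncInference
    from inferencesh import Inference, AsyncInference

    # Importing AsyncInference does not require aiohttp up front: client.py
    # resolves it lazily via _require_aiohttp() when the first async request is made.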

{inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/__init__.py

@@ -17,6 +17,7 @@ from .models import (
     timing_context,
 )
 from .utils import StorageDir, download
+from .client import Inference, AsyncInference, UploadFileOptions, TaskStatus
 
 __all__ = [
     "BaseApp",
@@ -33,4 +34,8 @@ __all__ = [
     "timing_context",
     "StorageDir",
     "download",
+    "Inference",
+    "AsyncInference",
+    "UploadFileOptions",
+    "TaskStatus",
 ]
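
With the new exports in place, a minimal synchronous usage sketch looks roughly like the following (the app identifier and input fields are hypothetical placeholders; run, run_sync and upload_file are the methods added in client.py below):

    from inferencesh import Inference, UploadFileOptions

    client = Inference(api_key="YOUR_API_KEY")  # arguments are keyword-only

    # Fire-and-forget: returns the created task (including its "id")
    task = client.run({
        "app": "some/app",                     # hypothetical app identifier
        "input": {"text": "hello"},
        "worker_selection_mode": "private",
    })

    # Blocking variant: follows the task's SSE stream until a terminal status
    result = client.run_sync({
        "app": "some/app",
        "input": {"text": "hello"},
        "worker_selection_mode": "private",
    })
    print(result["output"], result["logs"])

    # Explicit file upload (local path, bytes, data URI, or base64 are accepted)
    file_obj = client.upload_file("image.png", UploadFileOptions(public=True))
    print(file_obj.get("uri"))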

inferencesh-0.3.1/src/inferencesh/client.py

@@ -0,0 +1,830 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Optional, Callable, Generator, Union
+from dataclasses import dataclass
+from enum import IntEnum
+import json
+import re
+import time
+import mimetypes
+import os
+
+
+# Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
+def _require_requests():
+    try:
+        import requests  # type: ignore
+        return requests
+    except Exception as exc:  # pragma: no cover - dependency hint
+        raise RuntimeError(
+            "The 'requests' package is required for synchronous HTTP calls. Install with: pip install requests"
+        ) from exc
+
+
+async def _require_aiohttp():
+    try:
+        import aiohttp  # type: ignore
+        return aiohttp
+    except Exception as exc:  # pragma: no cover - dependency hint
+        raise RuntimeError(
+            "The 'aiohttp' package is required for async HTTP calls. Install with: pip install aiohttp"
+        ) from exc
+
+
+class TaskStatus(IntEnum):
+    RECEIVED = 1
+    QUEUED = 2
+    SCHEDULED = 3
+    PREPARING = 4
+    SERVING = 5
+    SETTING_UP = 6
+    RUNNING = 7
+    UPLOADING = 8
+    COMPLETED = 9
+    FAILED = 10
+    CANCELLED = 11
+
+
+Base64_RE = re.compile(r"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$")
+
+
+@dataclass
+class UploadFileOptions:
+    filename: Optional[str] = None
+    content_type: Optional[str] = None
+    path: Optional[str] = None
+    public: Optional[bool] = None
+
+
+class StreamManager:
+    """Simple SSE stream manager with optional auto-reconnect."""
+
+    def __init__(
+        self,
+        *,
+        create_event_source: Callable[[], Any],
+        auto_reconnect: bool = True,
+        max_reconnects: int = 5,
+        reconnect_delay_ms: int = 1000,
+        on_error: Optional[Callable[[Exception], None]] = None,
+        on_start: Optional[Callable[[], None]] = None,
+        on_stop: Optional[Callable[[], None]] = None,
+        on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
+    ) -> None:
+        self._create_event_source = create_event_source
+        self._auto_reconnect = auto_reconnect
+        self._max_reconnects = max_reconnects
+        self._reconnect_delay_ms = reconnect_delay_ms
+        self._on_error = on_error
+        self._on_start = on_start
+        self._on_stop = on_stop
+        self._on_data = on_data
+
+        self._stopped = False
+        self._reconnect_attempts = 0
+        self._had_successful_connection = False
+
+    def stop(self) -> None:
+        self._stopped = True
+        if self._on_stop:
+            self._on_stop()
+
+    def connect(self) -> None:
+        self._stopped = False
+        self._reconnect_attempts = 0
+        while not self._stopped:
+            try:
+                if self._on_start:
+                    self._on_start()
+                event_source = self._create_event_source()
+                try:
+                    for data in event_source:
+                        if self._stopped:
+                            break
+                        self._had_successful_connection = True
+                        if self._on_data:
+                            self._on_data(data)
+                        # Check again after processing in case on_data stopped us
+                        if self._stopped:
+                            break
+                finally:
+                    # Clean up the event source if it has a close method
+                    try:
+                        if hasattr(event_source, 'close'):
+                            event_source.close()
+                    except Exception:
+                        raise
+
+                # If we're stopped or don't want to auto-reconnect, break immediately
+                if self._stopped or not self._auto_reconnect:
+                    break
+            except Exception as exc:  # noqa: BLE001
+                if self._on_error:
+                    self._on_error(exc)
+                if self._stopped:
+                    break
+                # If never connected and exceeded attempts, stop
+                if not self._had_successful_connection:
+                    self._reconnect_attempts += 1
+                    if self._reconnect_attempts > self._max_reconnects:
+                        break
+                time.sleep(self._reconnect_delay_ms / 1000.0)
+            else:
+                # Completed without exception - if we want to auto-reconnect only after success
+                if not self._auto_reconnect:
+                    break
+                time.sleep(self._reconnect_delay_ms / 1000.0)
+
+
+class Inference:
+    """Synchronous client for inference.sh API, mirroring the JS SDK behavior.
+
+    Args:
+        api_key (str): The API key for authentication
+        base_url (Optional[str]): Override the default API base URL
+        sse_chunk_size (Optional[int]): Chunk size for SSE reading (default: 8192 bytes)
+        sse_mode (Optional[str]): SSE reading mode ('iter_lines' or 'raw', default: 'iter_lines')
+
+    The client supports performance tuning for SSE (Server-Sent Events) through:
+    1. sse_chunk_size: Controls the buffer size for reading SSE data (default: 8KB)
+       - Larger values may improve performance but use more memory
+       - Can also be set via INFERENCE_SSE_READ_BYTES environment variable
+    2. sse_mode: Controls how SSE data is read ('iter_lines' or 'raw')
+       - 'iter_lines': Uses requests' built-in line iteration (default)
+       - 'raw': Uses lower-level socket reading
+       - Can also be set via INFERENCE_SSE_MODE environment variable
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        base_url: Optional[str] = None,
+        sse_chunk_size: Optional[int] = None,
+        sse_mode: Optional[str] = None,
+    ) -> None:
+        self._api_key = api_key
+        self._base_url = base_url or "https://api.inference.sh"
+
+        # SSE configuration with environment variable fallbacks
+        self._sse_mode = sse_mode or os.getenv("INFERENCE_SSE_MODE") or "iter_lines"
+        self._sse_mode = self._sse_mode.lower()
+
+        # Default to 8KB chunks, can be overridden by parameter or env var
+        try:
+            env_chunk_size = os.getenv("INFERENCE_SSE_READ_BYTES")
+            if sse_chunk_size is not None:
+                self._sse_read_bytes = sse_chunk_size
+            elif env_chunk_size is not None:
+                self._sse_read_bytes = int(env_chunk_size)
+            else:
+                self._sse_read_bytes = 8192  # 8KB default
+        except Exception:
+            self._sse_read_bytes = 8192  # Default to 8KB chunks on error
+
+    # --------------- HTTP helpers ---------------
+    def _headers(self) -> Dict[str, str]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self._api_key}",
+        }
+
+    def _request(
+        self,
+        method: str,
+        endpoint: str,
+        *,
+        params: Optional[Dict[str, Any]] = None,
+        data: Optional[Dict[str, Any]] = None,
+        headers: Optional[Dict[str, str]] = None,
+        stream: bool = False,
+        timeout: Optional[float] = None,
+    ) -> Any:
+        requests = _require_requests()
+        url = f"{self._base_url}{endpoint}"
+        merged_headers = {**self._headers(), **(headers or {})}
+        resp = requests.request(
+            method=method.upper(),
+            url=url,
+            params=params,
+            data=json.dumps(data) if data is not None else None,
+            headers=merged_headers,
+            stream=stream,
+            timeout=timeout or 30,
+        )
+        if stream:
+            return resp
+        resp.raise_for_status()
+        payload = resp.json()
+        if not isinstance(payload, dict) or not payload.get("success", False):
+            message = None
+            if isinstance(payload, dict) and payload.get("error"):
+                err = payload["error"]
+                if isinstance(err, dict):
+                    message = err.get("message")
+                else:
+                    message = str(err)
+            raise RuntimeError(message or "Request failed")
+        return payload.get("data")
+
+    # --------------- Public API ---------------
+    def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        processed_input = self._process_input_data(params.get("input"))
+        task = self._request("post", "/run", data={**params, "input": processed_input})
+        return task
+
+    def run_sync(
+        self,
+        params: Dict[str, Any],
+        *,
+        auto_reconnect: bool = True,
+        max_reconnects: int = 5,
+        reconnect_delay_ms: int = 1000,
+    ) -> Dict[str, Any]:
+        processed_input = self._process_input_data(params.get("input"))
+        task = self._request("post", "/run", data={**params, "input": processed_input})
+        task_id = task["id"]
+
+        final_task: Optional[Dict[str, Any]] = None
+
+        def on_data(data: Dict[str, Any]) -> None:
+            nonlocal final_task
+            try:
+                result = _process_stream_event(
+                    data,
+                    task=task,
+                    stopper=lambda: manager.stop(),
+                )
+                if result is not None:
+                    final_task = result
+            except Exception as exc:
+                raise
+
+        def on_error(exc: Exception) -> None:
+            raise exc
+
+        def on_start() -> None:
+            pass
+
+        def on_stop() -> None:
+            pass
+
+        manager = StreamManager(
+            create_event_source=None,  # We'll set this after defining it
+            auto_reconnect=auto_reconnect,
+            max_reconnects=max_reconnects,
+            reconnect_delay_ms=reconnect_delay_ms,
+            on_data=on_data,
+            on_error=on_error,
+            on_start=on_start,
+            on_stop=on_stop,
+        )
+
+        def create_event_source() -> Generator[Dict[str, Any], None, None]:
+            url = f"/tasks/{task_id}/stream"
+            resp = self._request(
+                "get",
+                url,
+                headers={
+                    "Accept": "text/event-stream",
+                    "Cache-Control": "no-cache",
+                    "Accept-Encoding": "identity",
+                    "Connection": "keep-alive",
+                },
+                stream=True,
+                timeout=60,
+            )
+
+            try:
+                last_event_at = time.perf_counter()
+                for evt in self._iter_sse(resp, stream_manager=manager):
+                    yield evt
+            finally:
+                try:
+                    # Force close the underlying socket if possible
+                    try:
+                        raw = getattr(resp, 'raw', None)
+                        if raw is not None:
+                            raw.close()
+                    except Exception:
+                        raise
+                    # Close the response
+                    resp.close()
+                except Exception:
+                    raise
+
+        # Update the create_event_source function in the manager
+        manager._create_event_source = create_event_source
+
+        # Connect and wait for completion
+        manager.connect()
+
+        # At this point, we should have a final task state
+        if final_task is not None:
+            return final_task
+
+        # Try to fetch the latest state as a fallback
+        try:
+            latest = self.get_task(task_id)
+            status = latest.get("status")
+            if status == TaskStatus.COMPLETED:
+                return latest
+            if status == TaskStatus.FAILED:
+                raise RuntimeError(latest.get("error") or "task failed")
+            if status == TaskStatus.CANCELLED:
+                raise RuntimeError("task cancelled")
+        except Exception as exc:
+            raise
+
+        raise RuntimeError("Stream ended without completion")
+
+    def cancel(self, task_id: str) -> None:
+        self._request("post", f"/tasks/{task_id}/cancel")
+
+    def get_task(self, task_id: str) -> Dict[str, Any]:
+        return self._request("get", f"/tasks/{task_id}")
+
+    # --------------- File upload ---------------
+    def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
+        options = options or UploadFileOptions()
+        content_type = options.content_type
+        raw_bytes: bytes
+        if isinstance(data, bytes):
+            raw_bytes = data
+            if not content_type:
+                content_type = "application/octet-stream"
+        else:
+            # Prefer local filesystem path if it exists
+            if os.path.exists(data):
+                path = data
+                guessed = mimetypes.guess_type(path)[0]
+                content_type = content_type or guessed or "application/octet-stream"
+                with open(path, "rb") as f:
+                    raw_bytes = f.read()
+                if not options.filename:
+                    options.filename = os.path.basename(path)
+            elif data.startswith("data:"):
+                # data URI
+                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
+                if not match:
+                    raise ValueError("Invalid base64 data URI format")
+                content_type = content_type or match.group(1)
+                raw_bytes = _b64_to_bytes(match.group(2))
+            elif _looks_like_base64(data):
+                raw_bytes = _b64_to_bytes(data)
+                content_type = content_type or "application/octet-stream"
+            else:
+                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
+
+        file_req = {
+            "files": [
+                {
+                    "uri": "",
+                    "filename": options.filename,
+                    "content_type": content_type,
+                    "path": options.path,
+                    "size": len(raw_bytes),
+                    "public": options.public,
+                }
+            ]
+        }
+
+        created = self._request("post", "/files", data=file_req)
+        file_obj = created[0]
+        upload_url = file_obj.get("upload_url")
+        if not upload_url:
+            raise RuntimeError("No upload URL provided by the server")
+
+        # Upload to S3 (or compatible) signed URL
+        requests = _require_requests()
+        put_resp = requests.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type})
+        if not (200 <= put_resp.status_code < 300):
+            raise RuntimeError(f"Failed to upload file content: {put_resp.reason}")
+        return file_obj
+
+    # --------------- Helpers ---------------
+    def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
+        """Iterate JSON events from an SSE response."""
+        # Mode 1: raw socket readline (can reduce buffering in some environments)
+        if self._sse_mode == "raw":
+            raw = getattr(resp, "raw", None)
+            if raw is not None:
+                try:
+                    # Avoid urllib3 decompression buffering
+                    raw.decode_content = False  # type: ignore[attr-defined]
+                except Exception:
+                    raise
+                buf = bytearray()
+                read_size = max(1, int(self._sse_read_bytes))
+                while True:
+                    # Check if we've been asked to stop before reading more data
+                    try:
+                        if stream_manager and stream_manager._stopped:  # type: ignore[attr-defined]
+                            break
+                    except Exception:
+                        raise
+
+                    chunk = raw.read(read_size)
+                    if not chunk:
+                        break
+                    for b in chunk:
+                        if b == 10:  # '\n'
+                            try:
+                                line = buf.decode(errors="ignore").rstrip("\r")
+                            except Exception:
+                                line = ""
+                            buf.clear()
+                            if not line:
+                                continue
+                            if line.startswith(":"):
+                                continue
+                            if line.startswith("data:"):
+                                data_str = line[5:].lstrip()
+                                if not data_str:
+                                    continue
+                                try:
+                                    yield json.loads(data_str)
+                                except json.JSONDecodeError:
+                                    continue
+                        else:
+                            buf.append(b)
+                return
+        # Mode 2: default iter_lines with reasonable chunk size (8KB)
+        for line in resp.iter_lines(decode_unicode=True, chunk_size=8192):
+            # Check if we've been asked to stop before processing any more lines
+            try:
+                if stream_manager and stream_manager._stopped:  # type: ignore[attr-defined]
+                    break
+            except Exception:
+                raise
+
+            if not line:
+                continue
+            if line.startswith(":"):
+                continue
+            if line.startswith("data:"):
+                data_str = line[5:].lstrip()
+                if not data_str:
+                    continue
+                try:
+                    yield json.loads(data_str)
+                except json.JSONDecodeError:
+                    continue
+
+    def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
+        if input_value is None:
+            return input_value
+
+        # Handle lists
+        if isinstance(input_value, list):
+            return [self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
+
+        # Handle dicts
+        if isinstance(input_value, dict):
+            processed: Dict[str, Any] = {}
+            for key, value in input_value.items():
+                processed[key] = self._process_input_data(value, f"{path}.{key}")
+            return processed
+
+        # Handle strings that are filesystem paths, data URIs, or base64
+        if isinstance(input_value, str):
+            # Prefer existing local file paths first to avoid misclassifying plain strings
+            if os.path.exists(input_value):
+                file_obj = self.upload_file(input_value)
+                return file_obj.get("uri")
+            if input_value.startswith("data:") or _looks_like_base64(input_value):
+                file_obj = self.upload_file(input_value)
+                return file_obj.get("uri")
+            return input_value
+
+        # Handle File-like objects from our models
+        try:
+            from .models.file import File as SDKFile  # local import to avoid cycle
+            if isinstance(input_value, SDKFile):
+                # Prefer local path if present, else uri
+                src = input_value.path or input_value.uri
+                if not src:
+                    return input_value
+                file_obj = self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
+                return file_obj.get("uri")
+        except Exception:
+            raise
+
+        return input_value
+
+
+class AsyncInference:
+    """Async client for inference.sh API, mirroring the JS SDK behavior."""
+
+    def __init__(self, *, api_key: str, base_url: Optional[str] = None) -> None:
+        self._api_key = api_key
+        self._base_url = base_url or "https://api.inference.sh"
+
+    # --------------- HTTP helpers ---------------
+    def _headers(self) -> Dict[str, str]:
+        return {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self._api_key}",
+        }
+
+    async def _request(
+        self,
+        method: str,
+        endpoint: str,
+        *,
+        params: Optional[Dict[str, Any]] = None,
+        data: Optional[Dict[str, Any]] = None,
+        headers: Optional[Dict[str, str]] = None,
+        timeout: Optional[float] = None,
+        expect_stream: bool = False,
+    ) -> Any:
+        aiohttp = await _require_aiohttp()
+        url = f"{self._base_url}{endpoint}"
+        merged_headers = {**self._headers(), **(headers or {})}
+        timeout_cfg = aiohttp.ClientTimeout(total=timeout or 30)
+        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+            async with session.request(
+                method=method.upper(),
+                url=url,
+                params=params,
+                json=data,
+                headers=merged_headers,
+            ) as resp:
+                if expect_stream:
+                    return resp
+                payload = await resp.json()
+                if not isinstance(payload, dict) or not payload.get("success", False):
+                    message = None
+                    if isinstance(payload, dict) and payload.get("error"):
+                        err = payload["error"]
+                        if isinstance(err, dict):
+                            message = err.get("message")
+                        else:
+                            message = str(err)
+                    raise RuntimeError(message or "Request failed")
+                return payload.get("data")
+
+    # --------------- Public API ---------------
+    async def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        processed_input = await self._process_input_data(params.get("input"))
+        task = await self._request("post", "/run", data={**params, "input": processed_input})
+        return task
+
+    async def run_sync(
+        self,
+        params: Dict[str, Any],
+        *,
+        auto_reconnect: bool = True,
+        max_reconnects: int = 5,
+        reconnect_delay_ms: int = 1000,
+    ) -> Dict[str, Any]:
+        processed_input = await self._process_input_data(params.get("input"))
+        task = await self._request("post", "/run", data={**params, "input": processed_input})
+        task_id = task["id"]
+
+        final_task: Optional[Dict[str, Any]] = None
+        reconnect_attempts = 0
+        had_success = False
+
+        while True:
+            try:
+                resp = await self._request(
+                    "get",
+                    f"/tasks/{task_id}/stream",
+                    headers={
+                        "Accept": "text/event-stream",
+                        "Cache-Control": "no-cache",
+                        "Accept-Encoding": "identity",
+                        "Connection": "keep-alive",
+                    },
+                    timeout=60,
+                    expect_stream=True,
+                )
+                had_success = True
+                async for data in self._aiter_sse(resp):
+                    result = _process_stream_event(
+                        data,
+                        task=task,
+                        stopper=None,
+                    )
+                    if result is not None:
+                        final_task = result
+                        break
+                if final_task is not None:
+                    break
+            except Exception as exc:  # noqa: BLE001
+                if not auto_reconnect:
+                    raise
+                if not had_success:
+                    reconnect_attempts += 1
+                    if reconnect_attempts > max_reconnects:
+                        raise
+                await _async_sleep(reconnect_delay_ms / 1000.0)
+            else:
+                if not auto_reconnect:
+                    break
+                await _async_sleep(reconnect_delay_ms / 1000.0)
+
+        if final_task is None:
+            # Fallback: fetch latest task state in case stream ended without a terminal event
+            try:
+                latest = await self.get_task(task_id)
+                status = latest.get("status")
+                if status == TaskStatus.COMPLETED:
+                    return latest
+                if status == TaskStatus.FAILED:
+                    raise RuntimeError(latest.get("error") or "task failed")
+                if status == TaskStatus.CANCELLED:
+                    raise RuntimeError("task cancelled")
+            except Exception:
+                raise
+            raise RuntimeError("Stream ended without completion")
+        return final_task
+
+    async def cancel(self, task_id: str) -> None:
+        await self._request("post", f"/tasks/{task_id}/cancel")
+
+    async def get_task(self, task_id: str) -> Dict[str, Any]:
+        return await self._request("get", f"/tasks/{task_id}")
+
+    # --------------- File upload ---------------
+    async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
+        options = options or UploadFileOptions()
+        content_type = options.content_type
+        raw_bytes: bytes
+        if isinstance(data, bytes):
+            raw_bytes = data
+            if not content_type:
+                content_type = "application/octet-stream"
+        else:
+            if os.path.exists(data):
+                path = data
+                guessed = mimetypes.guess_type(path)[0]
+                content_type = content_type or guessed or "application/octet-stream"
+                async with await _aio_open_file(path, "rb") as f:
+                    raw_bytes = await f.read()  # type: ignore[attr-defined]
+                if not options.filename:
+                    options.filename = os.path.basename(path)
+            elif data.startswith("data:"):
+                match = re.match(r"^data:([^;]+);base64,(.+)$", data)
+                if not match:
+                    raise ValueError("Invalid base64 data URI format")
+                content_type = content_type or match.group(1)
+                raw_bytes = _b64_to_bytes(match.group(2))
+            elif _looks_like_base64(data):
+                raw_bytes = _b64_to_bytes(data)
+                content_type = content_type or "application/octet-stream"
+            else:
+                raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
+
+        file_req = {
+            "files": [
+                {
+                    "uri": "",
+                    "filename": options.filename,
+                    "content_type": content_type,
+                    "path": options.path,
+                    "size": len(raw_bytes),
+                    "public": options.public,
+                }
+            ]
+        }
+
+        created = await self._request("post", "/files", data=file_req)
+        file_obj = created[0]
+        upload_url = file_obj.get("upload_url")
+        if not upload_url:
+            raise RuntimeError("No upload URL provided by the server")
+
+        aiohttp = await _require_aiohttp()
+        timeout_cfg = aiohttp.ClientTimeout(total=60)
+        async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+            async with session.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type}) as resp:
+                if resp.status // 100 != 2:
+                    raise RuntimeError(f"Failed to upload file content: {resp.reason}")
+        return file_obj
+
+    # --------------- Helpers ---------------
+    async def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
+        if input_value is None:
+            return input_value
+
+        if isinstance(input_value, list):
+            return [await self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
+
+        if isinstance(input_value, dict):
+            processed: Dict[str, Any] = {}
+            for key, value in input_value.items():
+                processed[key] = await self._process_input_data(value, f"{path}.{key}")
+            return processed
+
+        if isinstance(input_value, str):
+            if os.path.exists(input_value):
+                file_obj = await self.upload_file(input_value)
+                return file_obj.get("uri")
+            if input_value.startswith("data:") or _looks_like_base64(input_value):
+                file_obj = await self.upload_file(input_value)
+                return file_obj.get("uri")
+            return input_value
+
+        try:
+            from .models.file import File as SDKFile  # local import
+            if isinstance(input_value, SDKFile):
+                src = input_value.path or input_value.uri
+                if not src:
+                    return input_value
+                file_obj = await self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
+                return file_obj.get("uri")
+        except Exception:
+            raise
+
+        return input_value
+
+    async def _aiter_sse(self, resp: Any) -> Generator[Dict[str, Any], None, None]:
+        async for raw_line in resp.content:  # type: ignore[attr-defined]
+            try:
+                line = raw_line.decode().rstrip("\n")
+            except Exception:
+                continue
+            if not line:
+                continue
+            if line.startswith(":"):
+                continue
+            if line.startswith("data:"):
+                data_str = line[5:].lstrip()
+                if not data_str:
+                    continue
+                try:
+                    yield json.loads(data_str)
+                except json.JSONDecodeError:
+                    continue
+
+
+# --------------- small async utilities ---------------
+async def _async_sleep(seconds: float) -> None:
+    import asyncio
+
+    await asyncio.sleep(seconds)
+
+
+def _b64_to_bytes(b64: str) -> bytes:
+    import base64
+
+    return base64.b64decode(b64)
+
+
+async def _aio_open_file(path: str, mode: str):
+    import aiofiles  # type: ignore
+
+    return await aiofiles.open(path, mode)
+
+
+def _looks_like_base64(value: str) -> bool:
+    # Reject very short strings to avoid matching normal words like "hi"
+    if len(value) < 16:
+        return False
+    # Quick charset check
+    if not Base64_RE.match(value):
+        return False
+    # Must be divisible by 4
+    if len(value) % 4 != 0:
+        return False
+    # Try decode to be sure
+    try:
+        _ = _b64_to_bytes(value)
+        return True
+    except Exception:
+        return False
+
+
+def _process_stream_event(
+    data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
+) -> Optional[Dict[str, Any]]:
+    """Shared handler for SSE task events. Returns final task dict when completed, else None.
+    If stopper is provided, it will be called on terminal events to end streaming.
+    """
+    status = data.get("status")
+    output = data.get("output")
+    logs = data.get("logs")
+
+    if status == TaskStatus.COMPLETED:
+        result = {
+            **task,
+            "status": data.get("status"),
+            "output": data.get("output"),
+            "logs": data.get("logs") or [],
+        }
+        if stopper:
+            stopper()
+        return result
+    if status == TaskStatus.FAILED:
+        if stopper:
+            stopper()
+        raise RuntimeError(data.get("error") or "task failed")
+    if status == TaskStatus.CANCELLED:
+        if stopper:
+            stopper()
+        raise RuntimeError("task cancelled")
+    return None
+
+
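
For the async counterpart, a comparable sketch (it assumes the optional "async" extra is installed, since AsyncInference lazily imports aiohttp; the app identifier is again a hypothetical placeholder):

    import asyncio

    from inferencesh import AsyncInference

    async def main() -> None:
        client = AsyncInference(api_key="YOUR_API_KEY")
        # run_sync awaits the task's SSE stream until a terminal status is reached
        result = await client.run_sync({
            "app": "some/app",
            "input": {"text": "hello"},
            "worker_selection_mode": "private",
        })
        print(result["output"])

    asyncio.run(main())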

{inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh/models/file.py

@@ -97,17 +97,41 @@ class File(BaseModel):
         print(f"Downloading URL: {original_url} to {self._tmp_path}")
         try:
             with urllib.request.urlopen(req) as response:
-                total_size = int(response.headers.get('content-length', 0))
+                # Safely retrieve content-length if available
+                total_size = 0
+                try:
+                    if hasattr(response, 'headers') and response.headers is not None:
+                        # urllib may expose headers as an email.message.Message
+                        cl = response.headers.get('content-length')
+                        total_size = int(cl) if cl is not None else 0
+                    elif hasattr(response, 'getheader'):
+                        cl = response.getheader('content-length')
+                        total_size = int(cl) if cl is not None else 0
+                except Exception:
+                    total_size = 0
+
                 block_size = 1024  # 1 Kibibyte
 
                 with tqdm(total=total_size, unit='iB', unit_scale=True) as pbar:
                     with open(self._tmp_path, 'wb') as out_file:
                         while True:
-                            buffer = response.read(block_size)
+                            non_chunking = False
+                            try:
+                                buffer = response.read(block_size)
+                            except TypeError:
+                                # Some mocks (or minimal implementations) expose read() without size
+                                buffer = response.read()
+                                non_chunking = True
                             if not buffer:
                                 break
                             out_file.write(buffer)
-                            pbar.update(len(buffer))
+                            try:
+                                pbar.update(len(buffer))
+                            except Exception:
+                                pass
+                            if non_chunking:
+                                # If we read the whole body at once, exit loop
+                                break
 
             self.path = self._tmp_path
         except (urllib.error.URLError, urllib.error.HTTPError) as e:

{inferencesh-0.2.37 → inferencesh-0.3.1/src/inferencesh.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: inferencesh
-Version: 0.2.37
+Version: 0.3.1
 Summary: inference.sh Python SDK
 Author: Inference Shell Inc.
 Author-email: "Inference Shell Inc." <hello@inference.sh>
@@ -14,9 +14,13 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: pydantic>=2.0.0
 Requires-Dist: tqdm>=4.67.0
+Requires-Dist: requests>=2.31.0
 Provides-Extra: test
 Requires-Dist: pytest>=7.0.0; extra == "test"
 Requires-Dist: pytest-cov>=4.0.0; extra == "test"
+Provides-Extra: async
+Requires-Dist: aiohttp>=3.9.0; python_version >= "3.8" and extra == "async"
+Requires-Dist: aiofiles>=23.2.1; python_version >= "3.8" and extra == "async"
 Dynamic: author
 Dynamic: license-file
 Dynamic: requires-python

{inferencesh-0.2.37 → inferencesh-0.3.1}/src/inferencesh.egg-info/SOURCES.txt

@@ -3,6 +3,7 @@ README.md
 pyproject.toml
 setup.py
 src/inferencesh/__init__.py
+src/inferencesh/client.py
 src/inferencesh.egg-info/PKG-INFO
 src/inferencesh.egg-info/SOURCES.txt
 src/inferencesh.egg-info/dependency_links.txt
@@ -16,4 +17,5 @@ src/inferencesh/models/llm.py
 src/inferencesh/utils/__init__.py
 src/inferencesh/utils/download.py
 src/inferencesh/utils/storage.py
+tests/test_client.py
 tests/test_sdk.py

inferencesh-0.3.1/tests/test_client.py

@@ -0,0 +1,159 @@
+import io
+import json
+import os
+import types
+from contextlib import contextmanager
+from typing import Iterator
+
+import pytest
+
+from inferencesh import Inference
+
+
+class DummyResponse:
+    def __init__(self, status_code=200, json_data=None, text="", lines=None):
+        self.status_code = status_code
+        self._json_data = json_data if json_data is not None else {"success": True, "data": {}}
+        self.text = text
+        self._lines = lines or []
+
+    def json(self):
+        return self._json_data
+
+    def raise_for_status(self):
+        if not (200 <= self.status_code < 300):
+            raise RuntimeError(f"HTTP error {self.status_code}")
+
+    def iter_lines(self, decode_unicode=False):
+        for line in self._lines:
+            yield line
+
+    def close(self):
+        pass
+
+
+@pytest.fixture(autouse=True)
+def patch_requests(monkeypatch):
+    calls = []
+
+    def fake_request(method, url, params=None, data=None, headers=None, stream=False, timeout=None):
+        calls.append({
+            "method": method,
+            "url": url,
+            "params": params,
+            "data": data,
+            "headers": headers,
+            "stream": stream,
+            "timeout": timeout,
+        })
+
+        # Create task
+        if url.endswith("/run") and method.upper() == "POST":
+            body = json.loads(data)
+            return DummyResponse(json_data={
+                "success": True,
+                "data": {
+                    "id": "task_123",
+                    "status": 1,
+                    "input": body.get("input"),
+                },
+            })
+
+        # SSE stream
+        if url.endswith("/tasks/task_123/stream") and stream:
+            # Minimal SSE: send a completed event
+            event_payload = json.dumps({
+                "status": 9,  # COMPLETED
+                "output": {"ok": True},
+                "logs": ["done"],
+            })
+            lines = [
+                f"data: {event_payload}",
+                "",  # dispatch
+            ]
+            return DummyResponse(status_code=200, lines=lines)
+
+        # Cancel
+        if url.endswith("/tasks/task_123/cancel") and method.upper() == "POST":
+            return DummyResponse(json_data={"success": True, "data": None})
+
+        # Files create
+        if url.endswith("/files") and method.upper() == "POST":
+            upload_url = "https://upload.example.com/file"
+            return DummyResponse(json_data={
+                "success": True,
+                "data": [
+                    {
+                        "id": "file_1",
+                        "uri": "https://cloud.inference.sh/u/user/file_1.png",
+                        "upload_url": upload_url,
+                    }
+                ],
+            })
+
+        return DummyResponse()
+
+    class FakeRequestsModule:
+        def __init__(self):
+            self.put_calls = []
+
+        def request(self, *args, **kwargs):
+            return fake_request(*args, **kwargs)
+
+        def put(self, url, data=None, headers=None):
+            self.put_calls.append({"url": url, "size": len(data or b"")})
+            return DummyResponse(status_code=200)
+
+    fake_requests = FakeRequestsModule()
+
+    def require_requests():
+        return fake_requests
+
+    # Patch internal loader
+    from inferencesh import client as client_mod
+    monkeypatch.setattr(client_mod, "_require_requests", require_requests)
+
+    yield fake_requests
+
+
+def test_run_and_run_sync_mocked(tmp_path):
+    client = Inference(api_key="test")
+
+    # run() should return a task id
+    task = client.run({
+        "app": "some/app",
+        "input": {"text": "hello"},
+        "worker_selection_mode": "private",
+    })
+    assert task["id"] == "task_123"
+
+    # run_sync should consume SSE and return final result merged
+    result = client.run_sync({
+        "app": "some/app",
+        "input": {"text": "hello"},
+        "worker_selection_mode": "private",
+    })
+    assert result["id"] == "task_123"
+    assert result["output"] == {"ok": True}
+    assert result["logs"] == ["done"]
+
+
+def test_upload_and_recursive_input(monkeypatch, tmp_path, patch_requests):
+    # Create a small file
+    file_path = tmp_path / "image.png"
+    file_path.write_bytes(b"PNGDATA")
+
+    client = Inference(api_key="test")
+
+    # Input contains a local path - should be uploaded and replaced by uri before /run
+    task = client.run({
+        "app": "some/app",
+        "input": {"image": str(file_path)},
+        "worker_selection_mode": "private",
+    })
+
+    # The mocked /run echoes input; ensure it is not the raw path anymore (upload replaced it)
+    assert task["input"]["image"] != str(file_path)
+    assert task["input"]["image"].startswith("https://cloud.inference.sh/")
+
+