inferencesh 0.2.23__py3-none-any.whl → 0.4.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
inferencesh/client.py ADDED
@@ -0,0 +1,1081 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional, Callable, Generator, Union, Iterator, AsyncIterator
4
+ from dataclasses import dataclass
5
+ from enum import IntEnum
6
+ import json
7
+ import re
8
+ import time
9
+ import mimetypes
10
+ import os
11
+ from contextlib import AbstractContextManager
12
+ from typing import Protocol, runtime_checkable
13
+
14
+
15
+ class TaskStream(AbstractContextManager['TaskStream']):
16
+ """A context manager for streaming task updates.
17
+
18
+ This class provides a Pythonic interface for handling streaming updates from a task.
19
+ It can be used either as a context manager or as an iterator.
20
+
21
+ Example:
22
+ ```python
23
+ # As a context manager
24
+ with client.stream_task(task_id) as stream:
25
+ for update in stream:
26
+ print(f"Update: {update}")
27
+
28
+ # As an iterator
29
+ for update in client.stream_task(task_id):
30
+ print(f"Update: {update}")
31
+ ```
32
+ """
33
+ def __init__(
34
+ self,
35
+ task: Dict[str, Any],
36
+ client: Any,
37
+ auto_reconnect: bool = True,
38
+ max_reconnects: int = 5,
39
+ reconnect_delay_ms: int = 1000,
40
+ ):
41
+ self.task = task
42
+ self.client = client
43
+ self.task_id = task["id"]
44
+ self.auto_reconnect = auto_reconnect
45
+ self.max_reconnects = max_reconnects
46
+ self.reconnect_delay_ms = reconnect_delay_ms
47
+ self._final_task: Optional[Dict[str, Any]] = None
48
+ self._error: Optional[Exception] = None
49
+
50
+ def __enter__(self) -> 'TaskStream':
51
+ return self
52
+
53
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
54
+ pass
55
+
56
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
57
+ return self.stream()
58
+
59
+ @property
60
+ def result(self) -> Optional[Dict[str, Any]]:
61
+ """The final task result if completed, None otherwise."""
62
+ return self._final_task
63
+
64
+ @property
65
+ def error(self) -> Optional[Exception]:
66
+ """The error that occurred during streaming, if any."""
67
+ return self._error
68
+
69
+ def stream(self) -> Iterator[Dict[str, Any]]:
70
+ """Stream updates for this task.
71
+
72
+ Yields:
73
+ Dict[str, Any]: Task update events
74
+
75
+ Raises:
76
+ RuntimeError: If the task fails or is cancelled
77
+ """
78
+ try:
79
+ for update in self.client._stream_updates(
80
+ self.task_id,
81
+ self.task,
82
+ ):
83
+ if isinstance(update, Exception):
84
+ self._error = update
85
+ raise update
86
+ if update.get("status") == TaskStatus.COMPLETED:
87
+ self._final_task = update
88
+ yield update
89
+ except Exception as exc:
90
+ self._error = exc
91
+ raise
92
+
93
+
94
+ @runtime_checkable
95
+ class TaskCallback(Protocol):
96
+ """Protocol for task streaming callbacks."""
97
+ def on_update(self, data: Dict[str, Any]) -> None:
98
+ """Called when a task update is received."""
99
+ ...
100
+
101
+ def on_error(self, error: Exception) -> None:
102
+ """Called when an error occurs during task execution."""
103
+ ...
104
+
105
+ def on_complete(self, task: Dict[str, Any]) -> None:
106
+ """Called when a task completes successfully."""
107
+ ...
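+
+ # Illustrative sketch (not part of the SDK): a minimal object satisfying the
+ # TaskCallback protocol. Because the protocol is runtime_checkable, isinstance
+ # checks only verify that the three methods exist, so a duck-typed class like
+ # this one passes:
+ #
+ #     class PrintCallback:
+ #         def on_update(self, data): print("update:", data.get("status"))
+ #         def on_error(self, error): print("error:", error)
+ #         def on_complete(self, task): print("done:", task.get("id"))
+ #
+ #     assert isinstance(PrintCallback(), TaskCallback)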
108
+
109
+
110
+ # Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
111
+ def _require_requests():
112
+ try:
113
+ import requests # type: ignore
114
+ return requests
115
+ except Exception as exc: # pragma: no cover - dependency hint
116
+ raise RuntimeError(
117
+ "The 'requests' package is required for synchronous HTTP calls. Install with: pip install requests"
118
+ ) from exc
119
+
120
+
121
+ async def _require_aiohttp():
122
+ try:
123
+ import aiohttp # type: ignore
124
+ return aiohttp
125
+ except Exception as exc: # pragma: no cover - dependency hint
126
+ raise RuntimeError(
127
+ "The 'aiohttp' package is required for async HTTP calls. Install with: pip install aiohttp"
128
+ ) from exc
129
+
130
+
131
+ class TaskStatus(IntEnum):
132
+ RECEIVED = 1
133
+ QUEUED = 2
134
+ SCHEDULED = 3
135
+ PREPARING = 4
136
+ SERVING = 5
137
+ SETTING_UP = 6
138
+ RUNNING = 7
139
+ UPLOADING = 8
140
+ COMPLETED = 9
141
+ FAILED = 10
142
+ CANCELLED = 11
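+     # COMPLETED, FAILED, and CANCELLED are the terminal states. As an IntEnum,
+     # members compare equal to their integer values, so a check like
+     # `update.get("status") == TaskStatus.COMPLETED` also matches a raw 9
+     # coming straight from the JSON payload.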
143
+
144
+
145
+ Base64_RE = re.compile(r"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$")
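+ # The pattern matches a whole string of 4-character base64 groups followed by
+ # an optional padded tail ("=" or "==").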
146
+
147
+
148
+ @dataclass
149
+ class UploadFileOptions:
150
+ filename: Optional[str] = None
151
+ content_type: Optional[str] = None
152
+ path: Optional[str] = None
153
+ public: Optional[bool] = None
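+     # e.g. UploadFileOptions(filename="image.png", content_type="image/png",
+     # public=True); all fields are optional.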
154
+
155
+
156
+ class StreamManager:
157
+ """Simple SSE stream manager with optional auto-reconnect."""
158
+
159
+ def __init__(
160
+ self,
161
+ *,
162
+ create_event_source: Callable[[], Any],
163
+ auto_reconnect: bool = True,
164
+ max_reconnects: int = 5,
165
+ reconnect_delay_ms: int = 1000,
166
+ on_error: Optional[Callable[[Exception], None]] = None,
167
+ on_start: Optional[Callable[[], None]] = None,
168
+ on_stop: Optional[Callable[[], None]] = None,
169
+ on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
170
+ on_partial_data: Optional[Callable[[Dict[str, Any], list[str]], None]] = None,
171
+ ) -> None:
172
+ self._create_event_source = create_event_source
173
+ self._auto_reconnect = auto_reconnect
174
+ self._max_reconnects = max_reconnects
175
+ self._reconnect_delay_ms = reconnect_delay_ms
176
+ self._on_error = on_error
177
+ self._on_start = on_start
178
+ self._on_stop = on_stop
179
+ self._on_data = on_data
180
+ self._on_partial_data = on_partial_data
181
+
182
+ self._stopped = False
183
+ self._reconnect_attempts = 0
184
+ self._had_successful_connection = False
185
+
186
+ def stop(self) -> None:
187
+ self._stopped = True
188
+ if self._on_stop:
189
+ self._on_stop()
190
+
191
+ def connect(self) -> None:
192
+ self._stopped = False
193
+ self._reconnect_attempts = 0
194
+ while not self._stopped:
195
+ try:
196
+ if self._on_start:
197
+ self._on_start()
198
+ event_source = self._create_event_source()
199
+ try:
200
+ for data in event_source:
201
+ if self._stopped:
202
+ break
203
+ self._had_successful_connection = True
204
+
205
+ # Handle generic messages through on_data callback
206
+ # Try parsing as {data: T, fields: []} structure first
208
+ if (
209
+ isinstance(data, dict)
210
+ and "data" in data
211
+ and "fields" in data
212
+ and isinstance(data.get("fields"), list)
213
+ ):
214
+ # Partial data structure detected
215
+ if self._on_partial_data:
216
+ self._on_partial_data(data["data"], data["fields"])
217
+ elif self._on_data:
218
+ # Fall back to on_data with just the data if on_partial_data not provided
219
+ self._on_data(data["data"])
220
+ elif self._on_data:
221
+ # Otherwise treat the whole thing as data
222
+ self._on_data(data)
223
+
224
+ # Check again after processing in case callbacks stopped us
225
+ if self._stopped:
226
+ break
227
+ finally:
228
+ # Clean up the event source if it has a close method
229
+ try:
230
+ if hasattr(event_source, 'close'):
231
+ event_source.close()
232
+ except Exception:
233
+                     pass
234
+
235
+ # If we're stopped or don't want to auto-reconnect, break immediately
236
+ if self._stopped or not self._auto_reconnect:
237
+ break
238
+ except Exception as exc: # noqa: BLE001
239
+ if self._on_error:
240
+ self._on_error(exc)
241
+ if self._stopped:
242
+ break
243
+ # If never connected and exceeded attempts, stop
244
+ if not self._had_successful_connection:
245
+ self._reconnect_attempts += 1
246
+ if self._reconnect_attempts > self._max_reconnects:
247
+ break
248
+ time.sleep(self._reconnect_delay_ms / 1000.0)
249
+ else:
250
+ # Completed without exception - if we want to auto-reconnect only after success
251
+ if not self._auto_reconnect:
252
+ break
253
+ time.sleep(self._reconnect_delay_ms / 1000.0)
254
+
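+
+ # Usage sketch (illustrative; any iterable of parsed SSE dicts can serve as
+ # the event source):
+ #
+ #     manager = StreamManager(
+ #         create_event_source=lambda: iter([{"status": 7}, {"status": 9}]),
+ #         on_data=lambda d: print("data:", d),
+ #         auto_reconnect=False,
+ #     )
+ #     manager.connect()  # delivers both events to on_data, then returns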
255
+
256
+ class Inference:
257
+ """Synchronous client for inference.sh API, mirroring the JS SDK behavior.
258
+
259
+ Args:
260
+ api_key (str): The API key for authentication
261
+ base_url (Optional[str]): Override the default API base URL
262
+ sse_chunk_size (Optional[int]): Chunk size for SSE reading (default: 8192 bytes)
263
+ sse_mode (Optional[str]): SSE reading mode ('iter_lines' or 'raw', default: 'iter_lines')
264
+
265
+ The client supports performance tuning for SSE (Server-Sent Events) through:
266
+ 1. sse_chunk_size: Controls the buffer size for reading SSE data (default: 8KB)
267
+ - Larger values may improve performance but use more memory
268
+ - Can also be set via INFERENCE_SSE_READ_BYTES environment variable
269
+ 2. sse_mode: Controls how SSE data is read ('iter_lines' or 'raw')
270
+ - 'iter_lines': Uses requests' built-in line iteration (default)
271
+ - 'raw': Uses lower-level socket reading
272
+ - Can also be set via INFERENCE_SSE_MODE environment variable
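+
+     Example (a sketch; the API key value is a placeholder):
+         ```python
+         client = Inference(
+             api_key="YOUR_API_KEY",
+             sse_chunk_size=16384,  # 16KB reads instead of the 8KB default
+             sse_mode="raw",        # lower-level socket reads
+         )
+         ```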
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ *,
278
+ api_key: str,
279
+ base_url: Optional[str] = None,
280
+ sse_chunk_size: Optional[int] = None,
281
+ sse_mode: Optional[str] = None,
282
+ ) -> None:
283
+ self._api_key = api_key
284
+ self._base_url = base_url or "https://api.inference.sh"
285
+
286
+ # SSE configuration with environment variable fallbacks
287
+ self._sse_mode = sse_mode or os.getenv("INFERENCE_SSE_MODE") or "iter_lines"
288
+ self._sse_mode = self._sse_mode.lower()
289
+
290
+ # Default to 8KB chunks, can be overridden by parameter or env var
291
+ try:
292
+ env_chunk_size = os.getenv("INFERENCE_SSE_READ_BYTES")
293
+ if sse_chunk_size is not None:
294
+ self._sse_read_bytes = sse_chunk_size
295
+ elif env_chunk_size is not None:
296
+ self._sse_read_bytes = int(env_chunk_size)
297
+ else:
298
+ self._sse_read_bytes = 8192 # 8KB default
299
+ except Exception:
300
+ self._sse_read_bytes = 8192 # Default to 8KB chunks on error
301
+
302
+ # --------------- HTTP helpers ---------------
303
+ def _headers(self) -> Dict[str, str]:
304
+ return {
305
+ "Content-Type": "application/json",
306
+ "Authorization": f"Bearer {self._api_key}",
307
+ }
308
+
309
+ def _request(
310
+ self,
311
+ method: str,
312
+ endpoint: str,
313
+ *,
314
+ params: Optional[Dict[str, Any]] = None,
315
+ data: Optional[Dict[str, Any]] = None,
316
+ headers: Optional[Dict[str, str]] = None,
317
+ stream: bool = False,
318
+ timeout: Optional[float] = None,
319
+ ) -> Any:
320
+ requests = _require_requests()
321
+ url = f"{self._base_url}{endpoint}"
322
+ merged_headers = {**self._headers(), **(headers or {})}
323
+ resp = requests.request(
324
+ method=method.upper(),
325
+ url=url,
326
+ params=params,
327
+ data=json.dumps(data) if data is not None else None,
328
+ headers=merged_headers,
329
+ stream=stream,
330
+ timeout=timeout or 30,
331
+ )
332
+ if stream:
333
+ return resp
334
+ resp.raise_for_status()
335
+ payload = resp.json()
336
+ if not isinstance(payload, dict) or not payload.get("success", False):
337
+ message = None
338
+ if isinstance(payload, dict) and payload.get("error"):
339
+ err = payload["error"]
340
+ if isinstance(err, dict):
341
+ message = err.get("message")
342
+ else:
343
+ message = str(err)
344
+ raise RuntimeError(message or "Request failed")
345
+ return payload.get("data")
346
+
347
+ # --------------- Public API ---------------
348
+ def run(
349
+ self,
350
+ params: Dict[str, Any],
351
+ *,
352
+ wait: bool = True,
353
+ stream: bool = False,
354
+ auto_reconnect: bool = True,
355
+ max_reconnects: int = 5,
356
+ reconnect_delay_ms: int = 1000,
357
+ ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
358
+ """Run a task with optional streaming updates.
359
+
360
+ By default, this method waits for the task to complete and returns the final result.
361
+ You can set wait=False to get just the task info, or stream=True to get an iterator
362
+ of status updates.
363
+
364
+ Args:
365
+ params: Task parameters to pass to the API
366
+ wait: Whether to wait for task completion (default: True)
367
+ stream: Whether to return an iterator of updates (default: False)
368
+ auto_reconnect: Whether to automatically reconnect on connection loss
369
+ max_reconnects: Maximum number of reconnection attempts
370
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
371
+
372
+ Returns:
373
+ Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
374
+ - If wait=True and stream=False: The completed task data
375
+ - If wait=False: The created task info
376
+ - If stream=True: An iterator of task updates
377
+
378
+ Example:
379
+ ```python
380
+ # Simple usage - wait for result (default)
381
+ result = client.run(params)
382
+ print(f"Output: {result['output']}")
383
+
384
+ # Get task info without waiting
385
+ task = client.run(params, wait=False)
386
+ task_id = task["id"]
387
+
388
+ # Stream updates
389
+ stream = client.run(params, stream=True)
390
+ for update in stream:
391
+ print(f"Status: {update.get('status')}")
392
+ if update.get('status') == TaskStatus.COMPLETED:
393
+ print(f"Result: {update.get('output')}")
394
+ ```
395
+ """
396
+ # Create the task
397
+ processed_input = self._process_input_data(params.get("input"))
398
+ task = self._request("post", "/run", data={**params, "input": processed_input})
399
+
400
+ # Return immediately if not waiting
401
+ if not wait and not stream:
402
+ return _strip_task(task)
403
+
404
+ # Return stream if requested
405
+ if stream:
406
+ task_stream = TaskStream(
407
+ task=task,
408
+ client=self,
409
+ auto_reconnect=auto_reconnect,
410
+ max_reconnects=max_reconnects,
411
+ reconnect_delay_ms=reconnect_delay_ms,
412
+ )
413
+ return task_stream
414
+
415
+ # Otherwise wait for completion
416
+ return self.wait_for_completion(task["id"])
417
+
420
+ def cancel(self, task_id: str) -> None:
421
+ self._request("post", f"/tasks/{task_id}/cancel")
422
+
423
+ def get_task(self, task_id: str) -> Dict[str, Any]:
424
+ """Get the current state of a task.
425
+
426
+ Args:
427
+ task_id: The ID of the task to get
428
+
429
+ Returns:
430
+ Dict[str, Any]: The current task state
431
+ """
432
+ return self._request("get", f"/tasks/{task_id}")
433
+
434
+ def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
435
+ """Wait for a task to complete and return its final state.
436
+
437
+         This method streams task updates until the task reaches a terminal
+         state (completed, failed, or cancelled).
439
+
440
+ Args:
441
+ task_id: The ID of the task to wait for
442
+
443
+ Returns:
444
+ Dict[str, Any]: The final task state
445
+
446
+ Raises:
447
+ RuntimeError: If the task fails or is cancelled
448
+ """
449
+ with self.stream_task(task_id) as stream:
450
+ for update in stream:
451
+ if update.get("status") == TaskStatus.COMPLETED:
452
+ return update
453
+ elif update.get("status") == TaskStatus.FAILED:
454
+ raise RuntimeError(update.get("error") or "Task failed")
455
+ elif update.get("status") == TaskStatus.CANCELLED:
456
+ raise RuntimeError("Task cancelled")
457
+ raise RuntimeError("Stream ended without completion")
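+     # Usage sketch: submit without waiting, then block for the final state:
+     #     task = client.run(params, wait=False)
+     #     final = client.wait_for_completion(task["id"])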
458
+
459
+ # --------------- File upload ---------------
460
+ def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
461
+ options = options or UploadFileOptions()
462
+ content_type = options.content_type
463
+ raw_bytes: bytes
464
+ if isinstance(data, bytes):
465
+ raw_bytes = data
466
+ if not content_type:
467
+ content_type = "application/octet-stream"
468
+ else:
469
+ # Prefer local filesystem path if it exists
470
+ if os.path.exists(data):
471
+ path = data
472
+ guessed = mimetypes.guess_type(path)[0]
473
+ content_type = content_type or guessed or "application/octet-stream"
474
+ with open(path, "rb") as f:
475
+ raw_bytes = f.read()
476
+ if not options.filename:
477
+ options.filename = os.path.basename(path)
478
+ elif data.startswith("data:"):
479
+ # data URI
480
+ match = re.match(r"^data:([^;]+);base64,(.+)$", data)
481
+ if not match:
482
+ raise ValueError("Invalid base64 data URI format")
483
+ content_type = content_type or match.group(1)
484
+ raw_bytes = _b64_to_bytes(match.group(2))
485
+ elif _looks_like_base64(data):
486
+ raw_bytes = _b64_to_bytes(data)
487
+ content_type = content_type or "application/octet-stream"
488
+ else:
489
+ raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
490
+
491
+ file_req = {
492
+ "files": [
493
+ {
494
+ "uri": "",
495
+ "filename": options.filename,
496
+ "content_type": content_type,
497
+ "path": options.path,
498
+ "size": len(raw_bytes),
499
+ "public": options.public,
500
+ }
501
+ ]
502
+ }
503
+
504
+ created = self._request("post", "/files", data=file_req)
505
+ file_obj = created[0]
506
+ upload_url = file_obj.get("upload_url")
507
+ if not upload_url:
508
+ raise RuntimeError("No upload URL provided by the server")
509
+
510
+ # Upload to S3 (or compatible) signed URL
511
+ requests = _require_requests()
512
+ put_resp = requests.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type})
513
+ if not (200 <= put_resp.status_code < 300):
514
+ raise RuntimeError(f"Failed to upload file content: {put_resp.reason}")
515
+ return file_obj
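+     # Usage sketch (paths and names are illustrative):
+     #     file_obj = client.upload_file("/tmp/cat.png")
+     #     file_obj = client.upload_file(raw, UploadFileOptions(filename="blob.bin"))
+     #     uri = file_obj.get("uri")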
516
+
517
+ # --------------- Helpers ---------------
518
+ def stream_task(
519
+ self,
520
+ task_id: str,
521
+ *,
522
+ auto_reconnect: bool = True,
523
+ max_reconnects: int = 5,
524
+ reconnect_delay_ms: int = 1000,
525
+ ) -> TaskStream:
526
+ """Create a TaskStream for getting streaming updates from a task.
527
+
528
+ This provides a more Pythonic interface for handling task updates compared to callbacks.
529
+ The returned TaskStream can be used either as a context manager or as an iterator.
530
+
531
+ Args:
532
+ task_id: The ID of the task to stream
533
+ auto_reconnect: Whether to automatically reconnect on connection loss
534
+ max_reconnects: Maximum number of reconnection attempts
535
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
536
+
537
+ Returns:
538
+ TaskStream: A stream interface for the task
539
+
540
+ Example:
541
+ ```python
542
+ # Run a task
543
+ task = client.run(params)
544
+
545
+ # Stream updates using context manager
546
+ with client.stream_task(task["id"]) as stream:
547
+ for update in stream:
548
+ print(f"Status: {update.get('status')}")
549
+ if update.get("status") == TaskStatus.COMPLETED:
550
+ print(f"Result: {update.get('output')}")
551
+
552
+ # Or use as a simple iterator
553
+ for update in client.stream_task(task["id"]):
554
+ print(f"Update: {update}")
555
+ ```
556
+ """
557
+ task = self.get_task(task_id)
558
+ return TaskStream(
559
+ task=task,
560
+ client=self,
561
+ auto_reconnect=auto_reconnect,
562
+ max_reconnects=max_reconnects,
563
+ reconnect_delay_ms=reconnect_delay_ms,
564
+ )
565
+
566
+ def _stream_updates(
567
+ self,
568
+ task_id: str,
569
+ task: Dict[str, Any],
570
+ ) -> Generator[Union[Dict[str, Any], Exception], None, None]:
571
+ """Internal method to stream task updates."""
572
+ url = f"/tasks/{task_id}/stream"
573
+ resp = self._request(
574
+ "get",
575
+ url,
576
+ headers={
577
+ "Accept": "text/event-stream",
578
+ "Cache-Control": "no-cache",
579
+ "Accept-Encoding": "identity",
580
+ "Connection": "keep-alive",
581
+ },
582
+ stream=True,
583
+ timeout=60,
584
+ )
585
+ try:
586
+ for evt in self._iter_sse(resp):
587
+ try:
588
+ # Handle generic messages - try parsing as {data: T, fields: []} structure first
589
+ if (
590
+ isinstance(evt, dict)
591
+ and "data" in evt
592
+ and "fields" in evt
593
+ and isinstance(evt.get("fields"), list)
594
+ ):
595
+ # Partial data structure detected - extract just the data part
596
+ evt = evt["data"]
597
+
598
+ # Process the event to check for completion/errors
599
+ result = _process_stream_event(
600
+ evt,
601
+ task=task,
602
+ stopper=None, # We'll handle stopping via the iterator
603
+ )
604
+ if result is not None:
605
+ yield result
606
+ break
607
+ yield _strip_task(evt)
608
+ except Exception as exc:
609
+ yield exc
610
+ raise
611
+ finally:
612
+ try:
613
+ # Force close the underlying socket if possible
614
+ try:
615
+ raw = getattr(resp, 'raw', None)
616
+ if raw is not None:
617
+ raw.close()
618
+ except Exception:
619
+                     pass
620
+ # Close the response
621
+ resp.close()
622
+ except Exception:
623
+                 pass
624
+
625
+ def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
626
+ """Iterate JSON events from an SSE response."""
627
+ # Mode 1: raw socket readline (can reduce buffering in some environments)
628
+ if self._sse_mode == "raw":
629
+ raw = getattr(resp, "raw", None)
630
+ if raw is not None:
631
+ try:
632
+ # Avoid urllib3 decompression buffering
633
+ raw.decode_content = False # type: ignore[attr-defined]
634
+ except Exception:
635
+                     pass
636
+ buf = bytearray()
637
+ read_size = max(1, int(self._sse_read_bytes))
638
+ while True:
639
+ # Check if we've been asked to stop before reading more data
640
+ try:
641
+ if stream_manager and stream_manager._stopped: # type: ignore[attr-defined]
642
+ break
643
+ except Exception:
644
+                 pass
645
+
646
+ chunk = raw.read(read_size)
647
+ if not chunk:
648
+ break
649
+ for b in chunk:
650
+ if b == 10: # '\n'
651
+ try:
652
+ line = buf.decode(errors="ignore").rstrip("\r")
653
+ except Exception:
654
+ line = ""
655
+ buf.clear()
656
+ if not line:
657
+ continue
658
+ if line.startswith(":"):
659
+ continue
660
+ if line.startswith("data:"):
661
+ data_str = line[5:].lstrip()
662
+ if not data_str:
663
+ continue
664
+ try:
665
+ yield json.loads(data_str)
666
+ except json.JSONDecodeError:
667
+ continue
668
+ else:
669
+ buf.append(b)
670
+ return
671
+         # Mode 2: default iter_lines, honoring the configured read size
+         for line in resp.iter_lines(decode_unicode=True, chunk_size=self._sse_read_bytes):
673
+ # Check if we've been asked to stop before processing any more lines
674
+ try:
675
+ if stream_manager and stream_manager._stopped: # type: ignore[attr-defined]
676
+ break
677
+ except Exception:
678
+                 pass
679
+
680
+ if not line:
681
+ continue
682
+ if line.startswith(":"):
683
+ continue
684
+ if line.startswith("data:"):
685
+ data_str = line[5:].lstrip()
686
+ if not data_str:
687
+ continue
688
+ try:
689
+ yield json.loads(data_str)
690
+ except json.JSONDecodeError:
691
+ continue
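+         # For reference, each SSE event arrives as a "data:" line carrying
+         # JSON, e.g.:
+         #     data: {"id": "t1", "status": 7}
+         # Comment lines start with ":" and blank lines separate events.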
692
+
693
+ def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
694
+ if input_value is None:
695
+ return input_value
696
+
697
+ # Handle lists
698
+ if isinstance(input_value, list):
699
+ return [self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
700
+
701
+ # Handle dicts
702
+ if isinstance(input_value, dict):
703
+ processed: Dict[str, Any] = {}
704
+ for key, value in input_value.items():
705
+ processed[key] = self._process_input_data(value, f"{path}.{key}")
706
+ return processed
707
+
708
+ # Handle strings that are filesystem paths, data URIs, or base64
709
+ if isinstance(input_value, str):
710
+ # Prefer existing local file paths first to avoid misclassifying plain strings
711
+ if os.path.exists(input_value):
712
+ file_obj = self.upload_file(input_value)
713
+ return file_obj.get("uri")
714
+ if input_value.startswith("data:") or _looks_like_base64(input_value):
715
+ file_obj = self.upload_file(input_value)
716
+ return file_obj.get("uri")
717
+ return input_value
718
+
719
+ # Handle File-like objects from our models
720
+ try:
721
+ from .models.file import File as SDKFile # local import to avoid cycle
722
+ if isinstance(input_value, SDKFile):
723
+ # Prefer local path if present, else uri
724
+ src = input_value.path or input_value.uri
725
+ if not src:
726
+ return input_value
727
+ file_obj = self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
728
+ return file_obj.get("uri")
729
+         except ImportError:
+             pass
731
+
732
+ return input_value
733
+
734
+
735
+ class AsyncInference:
736
+ """Async client for inference.sh API, mirroring the JS SDK behavior."""
737
+
738
+ def __init__(self, *, api_key: str, base_url: Optional[str] = None) -> None:
739
+ self._api_key = api_key
740
+ self._base_url = base_url or "https://api.inference.sh"
741
+
742
+ # --------------- HTTP helpers ---------------
743
+ def _headers(self) -> Dict[str, str]:
744
+ return {
745
+ "Content-Type": "application/json",
746
+ "Authorization": f"Bearer {self._api_key}",
747
+ }
748
+
749
+ async def _request(
750
+ self,
751
+ method: str,
752
+ endpoint: str,
753
+ *,
754
+ params: Optional[Dict[str, Any]] = None,
755
+ data: Optional[Dict[str, Any]] = None,
756
+ headers: Optional[Dict[str, str]] = None,
757
+ timeout: Optional[float] = None,
758
+ expect_stream: bool = False,
759
+ ) -> Any:
760
+ aiohttp = await _require_aiohttp()
761
+ url = f"{self._base_url}{endpoint}"
762
+ merged_headers = {**self._headers(), **(headers or {})}
763
+ timeout_cfg = aiohttp.ClientTimeout(total=timeout or 30)
764
+ async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
765
+ async with session.request(
766
+ method=method.upper(),
767
+ url=url,
768
+ params=params,
769
+ json=data,
770
+ headers=merged_headers,
771
+ ) as resp:
772
+ if expect_stream:
773
+ return resp
774
+ payload = await resp.json()
775
+ if not isinstance(payload, dict) or not payload.get("success", False):
776
+ message = None
777
+ if isinstance(payload, dict) and payload.get("error"):
778
+ err = payload["error"]
779
+ if isinstance(err, dict):
780
+ message = err.get("message")
781
+ else:
782
+ message = str(err)
783
+ raise RuntimeError(message or "Request failed")
784
+ return payload.get("data")
785
+
786
+ # --------------- Public API ---------------
787
+ async def run(
788
+ self,
789
+ params: Dict[str, Any],
790
+ *,
791
+ wait: bool = True,
792
+ stream: bool = False,
793
+ auto_reconnect: bool = True,
794
+ max_reconnects: int = 5,
795
+ reconnect_delay_ms: int = 1000,
796
+ ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
797
+ """Run a task with optional streaming updates.
798
+
799
+ By default, this method waits for the task to complete and returns the final result.
800
+ You can set wait=False to get just the task info, or stream=True to get an iterator
801
+ of status updates.
802
+
803
+ Args:
804
+ params: Task parameters to pass to the API
805
+ wait: Whether to wait for task completion (default: True)
806
+ stream: Whether to return an iterator of updates (default: False)
807
+ auto_reconnect: Whether to automatically reconnect on connection loss
808
+ max_reconnects: Maximum number of reconnection attempts
809
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
810
+
811
+ Returns:
812
+ Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
813
+ - If wait=True and stream=False: The completed task data
814
+ - If wait=False: The created task info
815
+ - If stream=True: An iterator of task updates
816
+
817
+ Example:
818
+ ```python
819
+ # Simple usage - wait for result (default)
820
+ result = await client.run(params)
821
+ print(f"Output: {result['output']}")
822
+
823
+ # Get task info without waiting
824
+ task = await client.run(params, wait=False)
825
+ task_id = task["id"]
826
+
827
+ # Stream updates
828
+ stream = await client.run(params, stream=True)
829
+ for update in stream:
830
+ print(f"Status: {update.get('status')}")
831
+ if update.get('status') == TaskStatus.COMPLETED:
832
+ print(f"Result: {update.get('output')}")
833
+ ```
834
+ """
835
+ # Create the task
836
+ processed_input = await self._process_input_data(params.get("input"))
837
+ task = await self._request("post", "/run", data={**params, "input": processed_input})
838
+
839
+ # Return immediately if not waiting
840
+ if not wait and not stream:
841
+ return task
842
+
843
+ # Return stream if requested
844
+ if stream:
845
+ task_stream = TaskStream(
846
+ task=task,
847
+ client=self,
848
+ auto_reconnect=auto_reconnect,
849
+ max_reconnects=max_reconnects,
850
+ reconnect_delay_ms=reconnect_delay_ms,
851
+ )
852
+ return task_stream
853
+
854
+ # Otherwise wait for completion
855
+ return await self.wait_for_completion(task["id"])
856
+
857
+ async def cancel(self, task_id: str) -> None:
858
+ await self._request("post", f"/tasks/{task_id}/cancel")
859
+
860
+ async def get_task(self, task_id: str) -> Dict[str, Any]:
861
+ """Get the current state of a task.
862
+
863
+ Args:
864
+ task_id: The ID of the task to get
865
+
866
+ Returns:
867
+ Dict[str, Any]: The current task state
868
+ """
869
+ return await self._request("get", f"/tasks/{task_id}")
870
+
871
+ async def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
872
+ """Wait for a task to complete and return its final state.
873
+
874
+ This method polls the task status until it reaches a terminal state
875
+ (completed, failed, or cancelled).
876
+
877
+ Args:
878
+ task_id: The ID of the task to wait for
879
+
880
+ Returns:
881
+ Dict[str, Any]: The final task state
882
+
883
+ Raises:
884
+ RuntimeError: If the task fails or is cancelled
885
+ """
886
+         # The async client has no SSE streaming helper, so poll get_task
+         # until the task reaches a terminal state.
+         while True:
+             task = await self.get_task(task_id)
+             status = task.get("status")
+             if status == TaskStatus.COMPLETED:
+                 return task
+             if status == TaskStatus.FAILED:
+                 raise RuntimeError(task.get("error") or "Task failed")
+             if status == TaskStatus.CANCELLED:
+                 raise RuntimeError("Task cancelled")
+             await _async_sleep(1.0)
895
+
896
+ # --------------- File upload ---------------
897
+ async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
898
+ options = options or UploadFileOptions()
899
+ content_type = options.content_type
900
+ raw_bytes: bytes
901
+ if isinstance(data, bytes):
902
+ raw_bytes = data
903
+ if not content_type:
904
+ content_type = "application/octet-stream"
905
+ else:
906
+ if os.path.exists(data):
907
+ path = data
908
+ guessed = mimetypes.guess_type(path)[0]
909
+ content_type = content_type or guessed or "application/octet-stream"
910
+ async with await _aio_open_file(path, "rb") as f:
911
+ raw_bytes = await f.read() # type: ignore[attr-defined]
912
+ if not options.filename:
913
+ options.filename = os.path.basename(path)
914
+ elif data.startswith("data:"):
915
+ match = re.match(r"^data:([^;]+);base64,(.+)$", data)
916
+ if not match:
917
+ raise ValueError("Invalid base64 data URI format")
918
+ content_type = content_type or match.group(1)
919
+ raw_bytes = _b64_to_bytes(match.group(2))
920
+ elif _looks_like_base64(data):
921
+ raw_bytes = _b64_to_bytes(data)
922
+ content_type = content_type or "application/octet-stream"
923
+ else:
924
+ raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
925
+
926
+ file_req = {
927
+ "files": [
928
+ {
929
+ "uri": "",
930
+ "filename": options.filename,
931
+ "content_type": content_type,
932
+ "path": options.path,
933
+ "size": len(raw_bytes),
934
+ "public": options.public,
935
+ }
936
+ ]
937
+ }
938
+
939
+ created = await self._request("post", "/files", data=file_req)
940
+ file_obj = created[0]
941
+ upload_url = file_obj.get("upload_url")
942
+ if not upload_url:
943
+ raise RuntimeError("No upload URL provided by the server")
944
+
945
+ aiohttp = await _require_aiohttp()
946
+ timeout_cfg = aiohttp.ClientTimeout(total=60)
947
+ async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
948
+ async with session.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type}) as resp:
949
+ if resp.status // 100 != 2:
950
+ raise RuntimeError(f"Failed to upload file content: {resp.reason}")
951
+ return file_obj
952
+
953
+ # --------------- Helpers ---------------
954
+ async def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
955
+ if input_value is None:
956
+ return input_value
957
+
958
+ if isinstance(input_value, list):
959
+ return [await self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
960
+
961
+ if isinstance(input_value, dict):
962
+ processed: Dict[str, Any] = {}
963
+ for key, value in input_value.items():
964
+ processed[key] = await self._process_input_data(value, f"{path}.{key}")
965
+ return processed
966
+
967
+ if isinstance(input_value, str):
968
+ if os.path.exists(input_value):
969
+ file_obj = await self.upload_file(input_value)
970
+ return file_obj.get("uri")
971
+ if input_value.startswith("data:") or _looks_like_base64(input_value):
972
+ file_obj = await self.upload_file(input_value)
973
+ return file_obj.get("uri")
974
+ return input_value
975
+
976
+ try:
977
+ from .models.file import File as SDKFile # local import
978
+ if isinstance(input_value, SDKFile):
979
+ src = input_value.path or input_value.uri
980
+ if not src:
981
+ return input_value
982
+ file_obj = await self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
983
+ return file_obj.get("uri")
984
+         except ImportError:
+             pass
986
+
987
+ return input_value
988
+
989
+     async def _aiter_sse(self, resp: Any) -> AsyncIterator[Dict[str, Any]]:
990
+ async for raw_line in resp.content: # type: ignore[attr-defined]
991
+ try:
992
+ line = raw_line.decode().rstrip("\n")
993
+ except Exception:
994
+ continue
995
+ if not line:
996
+ continue
997
+ if line.startswith(":"):
998
+ continue
999
+ if line.startswith("data:"):
1000
+ data_str = line[5:].lstrip()
1001
+ if not data_str:
1002
+ continue
1003
+ try:
1004
+ yield json.loads(data_str)
1005
+ except json.JSONDecodeError:
1006
+ continue
1007
+
1008
+
1009
+ # --------------- small async utilities ---------------
1010
+ async def _async_sleep(seconds: float) -> None:
1011
+ import asyncio
1012
+
1013
+ await asyncio.sleep(seconds)
1014
+
1015
+
1016
+ def _b64_to_bytes(b64: str) -> bytes:
1017
+ import base64
1018
+
1019
+ return base64.b64decode(b64)
1020
+
1021
+
1022
+ async def _aio_open_file(path: str, mode: str):
1023
+ import aiofiles # type: ignore
1024
+
1025
+ return await aiofiles.open(path, mode)
1026
+
1027
+
1028
+ def _looks_like_base64(value: str) -> bool:
1029
+ # Reject very short strings to avoid matching normal words like "hi"
1030
+ if len(value) < 16:
1031
+ return False
1032
+ # Quick charset check
1033
+ if not Base64_RE.match(value):
1034
+ return False
1035
+ # Must be divisible by 4
1036
+ if len(value) % 4 != 0:
1037
+ return False
1038
+ # Try decode to be sure
1039
+ try:
1040
+ _ = _b64_to_bytes(value)
1041
+ return True
1042
+ except Exception:
1043
+ return False
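+ # e.g. _looks_like_base64("aGVsbG8gd29ybGQh") is True (16 chars of valid
+ # base64), while _looks_like_base64("hello") is False (too short).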
1044
+
1045
+
1046
+ def _strip_task(task: Dict[str, Any]) -> Dict[str, Any]:
1047
+ """Strip task to essential fields."""
1048
+ return {
1049
+ "id": task.get("id"),
1050
+ "created_at": task.get("created_at"),
1051
+ "updated_at": task.get("updated_at"),
1052
+ "input": task.get("input"),
1053
+ "output": task.get("output"),
1054
+ "logs": task.get("logs"),
1055
+ "status": task.get("status"),
1056
+ }
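+ # Any other keys on the task payload are dropped; only the seven fields above
+ # survive.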
1057
+
1058
+ def _process_stream_event(
1059
+ data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
1060
+ ) -> Optional[Dict[str, Any]]:
1061
+ """Shared handler for SSE task events. Returns final task dict when completed, else None.
1062
+ If stopper is provided, it will be called on terminal events to end streaming.
1063
+ """
1064
+ status = data.get("status")
1065
+
1066
+ if status == TaskStatus.COMPLETED:
1067
+ result = _strip_task(data)
1068
+ if stopper:
1069
+ stopper()
1070
+ return result
1071
+ if status == TaskStatus.FAILED:
1072
+ if stopper:
1073
+ stopper()
1074
+ raise RuntimeError(data.get("error") or "task failed")
1075
+ if status == TaskStatus.CANCELLED:
1076
+ if stopper:
1077
+ stopper()
1078
+ raise RuntimeError("task cancelled")
1079
+ return None
1080
+
1081
+