inferencesh 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
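
For orientation, here is a minimal usage sketch of the synchronous client added in this release. It is illustrative only: the app identifier and input fields below are hypothetical placeholders, not values taken from the package.

    from inferencesh.client import Inference

    client = Inference(api_key="YOUR_API_KEY")
    # run_sync blocks until the task reaches a terminal state, following the
    # task's SSE stream and reconnecting if the connection drops
    task = client.run_sync({"app": "some/app", "input": {"prompt": "hello"}})
    print(task["output"])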


inferencesh/client.py ADDED
@@ -0,0 +1,830 @@
+ from __future__ import annotations
+
+ from typing import Any, AsyncGenerator, Callable, Dict, Generator, Optional, Union
+ from dataclasses import dataclass
+ from enum import IntEnum
+ import json
+ import re
+ import time
+ import mimetypes
+ import os
+
+
+ # Deliberately lazy-import requests/aiohttp to avoid a hard dependency at import time
+ def _require_requests():
+     try:
+         import requests  # type: ignore
+         return requests
+     except Exception as exc:  # pragma: no cover - dependency hint
+         raise RuntimeError(
+             "The 'requests' package is required for synchronous HTTP calls. Install with: pip install requests"
+         ) from exc
+
+
+ async def _require_aiohttp():
+     try:
+         import aiohttp  # type: ignore
+         return aiohttp
+     except Exception as exc:  # pragma: no cover - dependency hint
+         raise RuntimeError(
+             "The 'aiohttp' package is required for async HTTP calls. Install with: pip install aiohttp"
+         ) from exc
+
+
+ class TaskStatus(IntEnum):
+     RECEIVED = 1
+     QUEUED = 2
+     SCHEDULED = 3
+     PREPARING = 4
+     SERVING = 5
+     SETTING_UP = 6
+     RUNNING = 7
+     UPLOADING = 8
+     COMPLETED = 9
+     FAILED = 10
+     CANCELLED = 11
+
+
+ BASE64_RE = re.compile(r"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)?$")
+
+
+ @dataclass
+ class UploadFileOptions:
+     filename: Optional[str] = None
+     content_type: Optional[str] = None
+     path: Optional[str] = None
+     public: Optional[bool] = None
+
+
+ class StreamManager:
+     """Simple SSE stream manager with optional auto-reconnect."""
+
+     def __init__(
+         self,
+         *,
+         create_event_source: Callable[[], Any],
+         auto_reconnect: bool = True,
+         max_reconnects: int = 5,
+         reconnect_delay_ms: int = 1000,
+         on_error: Optional[Callable[[Exception], None]] = None,
+         on_start: Optional[Callable[[], None]] = None,
+         on_stop: Optional[Callable[[], None]] = None,
+         on_data: Optional[Callable[[Dict[str, Any]], None]] = None,
+     ) -> None:
+         self._create_event_source = create_event_source
+         self._auto_reconnect = auto_reconnect
+         self._max_reconnects = max_reconnects
+         self._reconnect_delay_ms = reconnect_delay_ms
+         self._on_error = on_error
+         self._on_start = on_start
+         self._on_stop = on_stop
+         self._on_data = on_data
+
+         self._stopped = False
+         self._reconnect_attempts = 0
+         self._had_successful_connection = False
+
+     def stop(self) -> None:
+         self._stopped = True
+         if self._on_stop:
+             self._on_stop()
+
+     def connect(self) -> None:
+         self._stopped = False
+         self._reconnect_attempts = 0
+         while not self._stopped:
+             try:
+                 if self._on_start:
+                     self._on_start()
+                 event_source = self._create_event_source()
+                 try:
+                     for data in event_source:
+                         if self._stopped:
+                             break
+                         self._had_successful_connection = True
+                         if self._on_data:
+                             self._on_data(data)
+                         # Check again after processing in case on_data stopped us
+                         if self._stopped:
+                             break
+                 finally:
+                     # Best-effort cleanup: close the event source if it supports it
+                     close = getattr(event_source, "close", None)
+                     if close is not None:
+                         try:
+                             close()
+                         except Exception:
+                             pass
+
+                 # If we're stopped or don't want to auto-reconnect, stop immediately
+                 if self._stopped or not self._auto_reconnect:
+                     break
+             except Exception as exc:  # noqa: BLE001
+                 if self._on_error:
+                     self._on_error(exc)
+                 if self._stopped:
+                     break
+                 # If we never connected successfully, cap the number of attempts
+                 if not self._had_successful_connection:
+                     self._reconnect_attempts += 1
+                     if self._reconnect_attempts > self._max_reconnects:
+                         break
+                 time.sleep(self._reconnect_delay_ms / 1000.0)
+             else:
+                 # Stream ended cleanly; wait, then reconnect
+                 time.sleep(self._reconnect_delay_ms / 1000.0)
+
+
+ class Inference:
+     """Synchronous client for the inference.sh API, mirroring the JS SDK behavior.
+
+     Args:
+         api_key (str): The API key for authentication
+         base_url (Optional[str]): Override the default API base URL
+         sse_chunk_size (Optional[int]): Chunk size for SSE reading (default: 8192 bytes)
+         sse_mode (Optional[str]): SSE reading mode ('iter_lines' or 'raw', default: 'iter_lines')
+
+     The client supports performance tuning for SSE (Server-Sent Events) through:
+     1. sse_chunk_size: Controls the buffer size for reading SSE data (default: 8KB)
+        - Larger values may improve throughput but use more memory
+        - Can also be set via the INFERENCE_SSE_READ_BYTES environment variable
+     2. sse_mode: Controls how SSE data is read ('iter_lines' or 'raw')
+        - 'iter_lines': Uses requests' built-in line iteration (default)
+        - 'raw': Uses lower-level socket reading
+        - Can also be set via the INFERENCE_SSE_MODE environment variable
+     """
+
+     def __init__(
+         self,
+         *,
+         api_key: str,
+         base_url: Optional[str] = None,
+         sse_chunk_size: Optional[int] = None,
+         sse_mode: Optional[str] = None,
+     ) -> None:
+         self._api_key = api_key
+         self._base_url = base_url or "https://api.inference.sh"
+
+         # SSE configuration with environment variable fallbacks
+         self._sse_mode = (sse_mode or os.getenv("INFERENCE_SSE_MODE") or "iter_lines").lower()
+
+         # Default to 8KB chunks; can be overridden by parameter or env var
+         try:
+             env_chunk_size = os.getenv("INFERENCE_SSE_READ_BYTES")
+             if sse_chunk_size is not None:
+                 self._sse_read_bytes = sse_chunk_size
+             elif env_chunk_size is not None:
+                 self._sse_read_bytes = int(env_chunk_size)
+             else:
+                 self._sse_read_bytes = 8192  # 8KB default
+         except Exception:
+             self._sse_read_bytes = 8192  # Fall back to 8KB chunks on a bad env value
+
+     # --------------- HTTP helpers ---------------
+     def _headers(self) -> Dict[str, str]:
+         return {
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self._api_key}",
+         }
+
+     def _request(
+         self,
+         method: str,
+         endpoint: str,
+         *,
+         params: Optional[Dict[str, Any]] = None,
+         data: Optional[Dict[str, Any]] = None,
+         headers: Optional[Dict[str, str]] = None,
+         stream: bool = False,
+         timeout: Optional[float] = None,
+     ) -> Any:
+         requests = _require_requests()
+         url = f"{self._base_url}{endpoint}"
+         merged_headers = {**self._headers(), **(headers or {})}
+         resp = requests.request(
+             method=method.upper(),
+             url=url,
+             params=params,
+             data=json.dumps(data) if data is not None else None,
+             headers=merged_headers,
+             stream=stream,
+             timeout=timeout or 30,
+         )
+         # Surface HTTP errors before handing a streaming response to the SSE reader
+         resp.raise_for_status()
+         if stream:
+             return resp
+         payload = resp.json()
+         if not isinstance(payload, dict) or not payload.get("success", False):
+             message = None
+             if isinstance(payload, dict) and payload.get("error"):
+                 err = payload["error"]
+                 if isinstance(err, dict):
+                     message = err.get("message")
+                 else:
+                     message = str(err)
+             raise RuntimeError(message or "Request failed")
+         return payload.get("data")
+
+     # --------------- Public API ---------------
+     def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         processed_input = self._process_input_data(params.get("input"))
+         task = self._request("post", "/run", data={**params, "input": processed_input})
+         return task
+
+     def run_sync(
+         self,
+         params: Dict[str, Any],
+         *,
+         auto_reconnect: bool = True,
+         max_reconnects: int = 5,
+         reconnect_delay_ms: int = 1000,
+     ) -> Dict[str, Any]:
+         processed_input = self._process_input_data(params.get("input"))
+         task = self._request("post", "/run", data={**params, "input": processed_input})
+         task_id = task["id"]
+
+         final_task: Optional[Dict[str, Any]] = None
+         stream_error: Optional[Exception] = None
+
+         def on_data(data: Dict[str, Any]) -> None:
+             nonlocal final_task
+             result = _process_stream_event(
+                 data,
+                 task=task,
+                 stopper=lambda: manager.stop(),
+             )
+             if result is not None:
+                 final_task = result
+
+         def on_error(exc: Exception) -> None:
+             # Record the error instead of raising; raising here would defeat
+             # the manager's auto-reconnect logic
+             nonlocal stream_error
+             stream_error = exc
+
+         manager = StreamManager(
+             create_event_source=None,  # type: ignore[arg-type]  # assigned below
+             auto_reconnect=auto_reconnect,
+             max_reconnects=max_reconnects,
+             reconnect_delay_ms=reconnect_delay_ms,
+             on_data=on_data,
+             on_error=on_error,
+         )
+
+         def create_event_source() -> Generator[Dict[str, Any], None, None]:
+             resp = self._request(
+                 "get",
+                 f"/tasks/{task_id}/stream",
+                 headers={
+                     "Accept": "text/event-stream",
+                     "Cache-Control": "no-cache",
+                     "Accept-Encoding": "identity",
+                     "Connection": "keep-alive",
+                 },
+                 stream=True,
+                 timeout=60,
+             )
+             try:
+                 for evt in self._iter_sse(resp, stream_manager=manager):
+                     yield evt
+             finally:
+                 # Closing the response also releases the underlying connection
+                 resp.close()
+
+         # create_event_source needs a reference to the manager, so attach it
+         # after construction
+         manager._create_event_source = create_event_source
+
+         # Connect and wait for completion
+         manager.connect()
+
+         # At this point we should have a final task state
+         if final_task is not None:
+             return final_task
+
+         # Fall back to fetching the latest task state
+         latest = self.get_task(task_id)
+         status = latest.get("status")
+         if status == TaskStatus.COMPLETED:
+             return latest
+         if status == TaskStatus.FAILED:
+             raise RuntimeError(latest.get("error") or "task failed")
+         if status == TaskStatus.CANCELLED:
+             raise RuntimeError("task cancelled")
+         if stream_error is not None:
+             raise stream_error
+         raise RuntimeError("Stream ended without completion")
+
+     def cancel(self, task_id: str) -> None:
+         self._request("post", f"/tasks/{task_id}/cancel")
+
+     def get_task(self, task_id: str) -> Dict[str, Any]:
+         return self._request("get", f"/tasks/{task_id}")
+
+     # --------------- File upload ---------------
+     def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
+         options = options or UploadFileOptions()
+         content_type = options.content_type
+         raw_bytes: bytes
+         if isinstance(data, bytes):
+             raw_bytes = data
+             if not content_type:
+                 content_type = "application/octet-stream"
+         else:
+             # Prefer a local filesystem path if it exists
+             if os.path.exists(data):
+                 path = data
+                 guessed = mimetypes.guess_type(path)[0]
+                 content_type = content_type or guessed or "application/octet-stream"
+                 with open(path, "rb") as f:
+                     raw_bytes = f.read()
+                 if not options.filename:
+                     options.filename = os.path.basename(path)
+             elif data.startswith("data:"):
+                 # data URI
+                 match = re.match(r"^data:([^;]+);base64,(.+)$", data)
+                 if not match:
+                     raise ValueError("Invalid base64 data URI format")
+                 content_type = content_type or match.group(1)
+                 raw_bytes = _b64_to_bytes(match.group(2))
+             elif _looks_like_base64(data):
+                 raw_bytes = _b64_to_bytes(data)
+                 content_type = content_type or "application/octet-stream"
+             else:
+                 raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
+
+         file_req = {
+             "files": [
+                 {
+                     "uri": "",
+                     "filename": options.filename,
+                     "content_type": content_type,
+                     "path": options.path,
+                     "size": len(raw_bytes),
+                     "public": options.public,
+                 }
+             ]
+         }
+
+         created = self._request("post", "/files", data=file_req)
+         file_obj = created[0]
+         upload_url = file_obj.get("upload_url")
+         if not upload_url:
+             raise RuntimeError("No upload URL provided by the server")
+
+         # Upload to the S3 (or compatible) signed URL
+         requests = _require_requests()
+         put_resp = requests.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type})
+         if not (200 <= put_resp.status_code < 300):
+             raise RuntimeError(f"Failed to upload file content: {put_resp.reason}")
+         return file_obj
+
+     # --------------- Helpers ---------------
+     def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
+         """Iterate JSON events from an SSE response."""
+         # Mode 1: raw socket reads (can reduce buffering in some environments)
+         if self._sse_mode == "raw":
+             raw = getattr(resp, "raw", None)
+             if raw is not None:
+                 try:
+                     # Avoid urllib3 decompression buffering
+                     raw.decode_content = False  # type: ignore[attr-defined]
+                 except Exception:
+                     pass  # not critical; continue with default behavior
+                 buf = bytearray()
+                 read_size = max(1, int(self._sse_read_bytes))
+                 while True:
+                     # Stop reading if the stream manager has been stopped
+                     if stream_manager is not None and stream_manager._stopped:  # type: ignore[attr-defined]
+                         break
+                     chunk = raw.read(read_size)
+                     if not chunk:
+                         break
+                     for b in chunk:
+                         if b == 10:  # '\n'
+                             line = buf.decode(errors="ignore").rstrip("\r")
+                             buf.clear()
+                             if not line:
+                                 continue
+                             if line.startswith(":"):
+                                 continue
+                             if line.startswith("data:"):
+                                 data_str = line[5:].lstrip()
+                                 if not data_str:
+                                     continue
+                                 try:
+                                     yield json.loads(data_str)
+                                 except json.JSONDecodeError:
+                                     continue
+                         else:
+                             buf.append(b)
+                 return
+         # Mode 2: default iter_lines with a reasonable chunk size (8KB)
+         for line in resp.iter_lines(decode_unicode=True, chunk_size=8192):
+             # Stop processing if the stream manager has been stopped
+             if stream_manager is not None and stream_manager._stopped:  # type: ignore[attr-defined]
+                 break
+             if not line:
+                 continue
+             if line.startswith(":"):
+                 continue
+             if line.startswith("data:"):
+                 data_str = line[5:].lstrip()
+                 if not data_str:
+                     continue
+                 try:
+                     yield json.loads(data_str)
+                 except json.JSONDecodeError:
+                     continue
+
+     def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
+         if input_value is None:
+             return input_value
+
+         # Handle lists
+         if isinstance(input_value, list):
+             return [self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
+
+         # Handle dicts
+         if isinstance(input_value, dict):
+             processed: Dict[str, Any] = {}
+             for key, value in input_value.items():
+                 processed[key] = self._process_input_data(value, f"{path}.{key}")
+             return processed
+
+         # Handle strings that are filesystem paths, data URIs, or base64
+         if isinstance(input_value, str):
+             # Prefer existing local file paths first to avoid misclassifying plain strings
+             if os.path.exists(input_value):
+                 file_obj = self.upload_file(input_value)
+                 return file_obj.get("uri")
+             if input_value.startswith("data:") or _looks_like_base64(input_value):
+                 file_obj = self.upload_file(input_value)
+                 return file_obj.get("uri")
+             return input_value
+
+         # Handle File-like objects from our models
+         try:
+             from .models.file import File as SDKFile  # local import to avoid a cycle
+         except ImportError:
+             SDKFile = None  # type: ignore[assignment]
+         if SDKFile is not None and isinstance(input_value, SDKFile):
+             # Prefer the local path if present, else the uri
+             src = input_value.path or input_value.uri
+             if not src:
+                 return input_value
+             file_obj = self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
+             return file_obj.get("uri")
+
+         return input_value
+
+
+ class AsyncInference:
+     """Async client for the inference.sh API, mirroring the JS SDK behavior."""
+
+     def __init__(self, *, api_key: str, base_url: Optional[str] = None) -> None:
+         self._api_key = api_key
+         self._base_url = base_url or "https://api.inference.sh"
+
+     # --------------- HTTP helpers ---------------
+     def _headers(self) -> Dict[str, str]:
+         return {
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self._api_key}",
+         }
+
+     async def _request(
+         self,
+         method: str,
+         endpoint: str,
+         *,
+         params: Optional[Dict[str, Any]] = None,
+         data: Optional[Dict[str, Any]] = None,
+         headers: Optional[Dict[str, str]] = None,
+         timeout: Optional[float] = None,
+     ) -> Any:
+         aiohttp = await _require_aiohttp()
+         url = f"{self._base_url}{endpoint}"
+         merged_headers = {**self._headers(), **(headers or {})}
+         timeout_cfg = aiohttp.ClientTimeout(total=timeout or 30)
+         async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+             async with session.request(
+                 method=method.upper(),
+                 url=url,
+                 params=params,
+                 json=data,
+                 headers=merged_headers,
+             ) as resp:
+                 payload = await resp.json()
+                 if not isinstance(payload, dict) or not payload.get("success", False):
+                     message = None
+                     if isinstance(payload, dict) and payload.get("error"):
+                         err = payload["error"]
+                         if isinstance(err, dict):
+                             message = err.get("message")
+                         else:
+                             message = str(err)
+                     raise RuntimeError(message or "Request failed")
+                 return payload.get("data")
+
+     # --------------- Public API ---------------
+     async def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         processed_input = await self._process_input_data(params.get("input"))
+         task = await self._request("post", "/run", data={**params, "input": processed_input})
+         return task
+
+     async def run_sync(
+         self,
+         params: Dict[str, Any],
+         *,
+         auto_reconnect: bool = True,
+         max_reconnects: int = 5,
+         reconnect_delay_ms: int = 1000,
+     ) -> Dict[str, Any]:
+         processed_input = await self._process_input_data(params.get("input"))
+         task = await self._request("post", "/run", data={**params, "input": processed_input})
+         task_id = task["id"]
+
+         aiohttp = await _require_aiohttp()
+         final_task: Optional[Dict[str, Any]] = None
+         terminal_error: Optional[Exception] = None
+         reconnect_attempts = 0
+         had_success = False
+
+         while True:
+             try:
+                 # The stream is opened here rather than through _request so the
+                 # session stays open while events are consumed. No total deadline:
+                 # only an idle-socket timeout.
+                 timeout_cfg = aiohttp.ClientTimeout(total=None, sock_read=60)
+                 async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+                     async with session.get(
+                         f"{self._base_url}/tasks/{task_id}/stream",
+                         headers={
+                             **self._headers(),
+                             "Accept": "text/event-stream",
+                             "Cache-Control": "no-cache",
+                             "Accept-Encoding": "identity",
+                             "Connection": "keep-alive",
+                         },
+                     ) as resp:
+                         resp.raise_for_status()
+                         had_success = True
+                         async for data in self._aiter_sse(resp):
+                             try:
+                                 result = _process_stream_event(data, task=task, stopper=None)
+                             except RuntimeError as exc:
+                                 # Terminal failure/cancellation: never reconnect
+                                 terminal_error = exc
+                                 break
+                             if result is not None:
+                                 final_task = result
+                                 break
+                 if final_task is not None or terminal_error is not None:
+                     break
+             except Exception:  # noqa: BLE001
+                 if not auto_reconnect:
+                     raise
+                 if not had_success:
+                     reconnect_attempts += 1
+                     if reconnect_attempts > max_reconnects:
+                         raise
+                 await _async_sleep(reconnect_delay_ms / 1000.0)
+             else:
+                 if not auto_reconnect:
+                     break
+                 await _async_sleep(reconnect_delay_ms / 1000.0)
+
+         if terminal_error is not None:
+             raise terminal_error
+         if final_task is not None:
+             return final_task
+
+         # Fallback: fetch the latest task state in case the stream ended without
+         # a terminal event
+         latest = await self.get_task(task_id)
+         status = latest.get("status")
+         if status == TaskStatus.COMPLETED:
+             return latest
+         if status == TaskStatus.FAILED:
+             raise RuntimeError(latest.get("error") or "task failed")
+         if status == TaskStatus.CANCELLED:
+             raise RuntimeError("task cancelled")
+         raise RuntimeError("Stream ended without completion")
+
+     async def cancel(self, task_id: str) -> None:
+         await self._request("post", f"/tasks/{task_id}/cancel")
+
+     async def get_task(self, task_id: str) -> Dict[str, Any]:
+         return await self._request("get", f"/tasks/{task_id}")
+
+     # --------------- File upload ---------------
+     async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
+         options = options or UploadFileOptions()
+         content_type = options.content_type
+         raw_bytes: bytes
+         if isinstance(data, bytes):
+             raw_bytes = data
+             if not content_type:
+                 content_type = "application/octet-stream"
+         else:
+             if os.path.exists(data):
+                 path = data
+                 guessed = mimetypes.guess_type(path)[0]
+                 content_type = content_type or guessed or "application/octet-stream"
+                 raw_bytes = await _read_file_bytes(path)
+                 if not options.filename:
+                     options.filename = os.path.basename(path)
+             elif data.startswith("data:"):
+                 match = re.match(r"^data:([^;]+);base64,(.+)$", data)
+                 if not match:
+                     raise ValueError("Invalid base64 data URI format")
+                 content_type = content_type or match.group(1)
+                 raw_bytes = _b64_to_bytes(match.group(2))
+             elif _looks_like_base64(data):
+                 raw_bytes = _b64_to_bytes(data)
+                 content_type = content_type or "application/octet-stream"
+             else:
+                 raise ValueError("upload_file expected bytes, data URI, base64 string, or existing file path")
+
+         file_req = {
+             "files": [
+                 {
+                     "uri": "",
+                     "filename": options.filename,
+                     "content_type": content_type,
+                     "path": options.path,
+                     "size": len(raw_bytes),
+                     "public": options.public,
+                 }
+             ]
+         }
+
+         created = await self._request("post", "/files", data=file_req)
+         file_obj = created[0]
+         upload_url = file_obj.get("upload_url")
+         if not upload_url:
+             raise RuntimeError("No upload URL provided by the server")
+
+         aiohttp = await _require_aiohttp()
+         timeout_cfg = aiohttp.ClientTimeout(total=60)
+         async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
+             async with session.put(upload_url, data=raw_bytes, headers={"Content-Type": content_type}) as resp:
+                 if resp.status // 100 != 2:
+                     raise RuntimeError(f"Failed to upload file content: {resp.reason}")
+         return file_obj
+
+     # --------------- Helpers ---------------
+     async def _process_input_data(self, input_value: Any, path: str = "root") -> Any:
+         if input_value is None:
+             return input_value
+
+         if isinstance(input_value, list):
+             return [await self._process_input_data(item, f"{path}[{idx}]") for idx, item in enumerate(input_value)]
+
+         if isinstance(input_value, dict):
+             processed: Dict[str, Any] = {}
+             for key, value in input_value.items():
+                 processed[key] = await self._process_input_data(value, f"{path}.{key}")
+             return processed
+
+         if isinstance(input_value, str):
+             if os.path.exists(input_value):
+                 file_obj = await self.upload_file(input_value)
+                 return file_obj.get("uri")
+             if input_value.startswith("data:") or _looks_like_base64(input_value):
+                 file_obj = await self.upload_file(input_value)
+                 return file_obj.get("uri")
+             return input_value
+
+         try:
+             from .models.file import File as SDKFile  # local import to avoid a cycle
+         except ImportError:
+             SDKFile = None  # type: ignore[assignment]
+         if SDKFile is not None and isinstance(input_value, SDKFile):
+             src = input_value.path or input_value.uri
+             if not src:
+                 return input_value
+             file_obj = await self.upload_file(src, UploadFileOptions(filename=input_value.filename, content_type=input_value.content_type))
+             return file_obj.get("uri")
+
+         return input_value
+
+     async def _aiter_sse(self, resp: Any) -> AsyncGenerator[Dict[str, Any], None]:
+         async for raw_line in resp.content:  # type: ignore[attr-defined]
+             line = raw_line.decode(errors="ignore").rstrip("\r\n")
+             if not line:
+                 continue
+             if line.startswith(":"):
+                 continue
+             if line.startswith("data:"):
+                 data_str = line[5:].lstrip()
+                 if not data_str:
+                     continue
+                 try:
+                     yield json.loads(data_str)
+                 except json.JSONDecodeError:
+                     continue
+
+
+ # --------------- small module-level utilities ---------------
+ async def _async_sleep(seconds: float) -> None:
+     import asyncio
+
+     await asyncio.sleep(seconds)
+
+
+ def _b64_to_bytes(b64: str) -> bytes:
+     import base64
+
+     return base64.b64decode(b64)
+
+
+ async def _read_file_bytes(path: str) -> bytes:
+     """Read a file in a worker thread so large reads do not block the event loop."""
+     import asyncio
+
+     def _read() -> bytes:
+         with open(path, "rb") as f:
+             return f.read()
+
+     loop = asyncio.get_running_loop()
+     return await loop.run_in_executor(None, _read)
+
+
+ def _looks_like_base64(value: str) -> bool:
+     # Reject very short strings to avoid matching ordinary words like "hi"
+     if len(value) < 16:
+         return False
+     # Quick charset check
+     if not BASE64_RE.match(value):
+         return False
+     # Length must be divisible by 4
+     if len(value) % 4 != 0:
+         return False
+     # Try decoding to be sure
+     try:
+         _ = _b64_to_bytes(value)
+         return True
+     except Exception:
+         return False
+
+
+ def _process_stream_event(
+     data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
+ ) -> Optional[Dict[str, Any]]:
+     """Shared handler for SSE task events.
+
+     Returns the final task dict when the task completes, else None. If a
+     stopper is provided, it is called on terminal events to end streaming.
+     """
+     status = data.get("status")
+
+     if status == TaskStatus.COMPLETED:
+         result = {
+             **task,
+             "status": status,
+             "output": data.get("output"),
+             "logs": data.get("logs") or [],
+         }
+         if stopper:
+             stopper()
+         return result
+     if status == TaskStatus.FAILED:
+         if stopper:
+             stopper()
+         raise RuntimeError(data.get("error") or "task failed")
+     if status == TaskStatus.CANCELLED:
+         if stopper:
+             stopper()
+         raise RuntimeError("task cancelled")
+     return None
+
+
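
For completeness, a matching sketch for the async client defined above (again illustrative; the app identifier, input fields, and file name are placeholders):

    import asyncio
    from inferencesh.client import AsyncInference

    async def main() -> None:
        client = AsyncInference(api_key="YOUR_API_KEY")
        # upload_file accepts a file path, data URI, base64 string, or raw bytes
        file_obj = await client.upload_file("input.png")
        task = await client.run_sync({"app": "some/app", "input": {"image": file_obj.get("uri")}})
        print(task["output"])

    asyncio.run(main())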