inferencesh 0.3.1.tar.gz → 0.4.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of inferencesh might be problematic. See the registry's advisory for this release for more details.

Files changed (24)
  1. {inferencesh-0.3.1/src/inferencesh.egg-info → inferencesh-0.4.1}/PKG-INFO +63 -17
  2. inferencesh-0.4.1/README.md +128 -0
  3. {inferencesh-0.3.1 → inferencesh-0.4.1}/pyproject.toml +1 -1
  4. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/client.py +392 -173
  5. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/models/file.py +49 -3
  6. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/utils/download.py +15 -7
  7. {inferencesh-0.3.1 → inferencesh-0.4.1/src/inferencesh.egg-info}/PKG-INFO +63 -17
  8. inferencesh-0.3.1/README.md +0 -82
  9. {inferencesh-0.3.1 → inferencesh-0.4.1}/LICENSE +0 -0
  10. {inferencesh-0.3.1 → inferencesh-0.4.1}/setup.cfg +0 -0
  11. {inferencesh-0.3.1 → inferencesh-0.4.1}/setup.py +0 -0
  12. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/__init__.py +0 -0
  13. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/models/__init__.py +0 -0
  14. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/models/base.py +0 -0
  15. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/models/llm.py +0 -0
  16. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/utils/__init__.py +0 -0
  17. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh/utils/storage.py +0 -0
  18. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh.egg-info/SOURCES.txt +0 -0
  19. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh.egg-info/dependency_links.txt +0 -0
  20. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh.egg-info/entry_points.txt +0 -0
  21. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh.egg-info/requires.txt +0 -0
  22. {inferencesh-0.3.1 → inferencesh-0.4.1}/src/inferencesh.egg-info/top_level.txt +0 -0
  23. {inferencesh-0.3.1 → inferencesh-0.4.1}/tests/test_client.py +0 -0
  24. {inferencesh-0.3.1 → inferencesh-0.4.1}/tests/test_sdk.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: inferencesh
3
- Version: 0.3.1
3
+ Version: 0.4.1
4
4
  Summary: inference.sh Python SDK
5
5
  Author: Inference Shell Inc.
6
6
  Author-email: "Inference Shell Inc." <hello@inference.sh>
@@ -25,19 +25,65 @@ Dynamic: author
25
25
  Dynamic: license-file
26
26
  Dynamic: requires-python
27
27
 
28
- # inference.sh CLI
28
+ # inference.sh sdk
29
29
 
30
- Helper package for inference.sh Python applications.
30
+ helper package for inference.sh python applications.
31
31
 
32
- ## Installation
32
+ ## installation
33
33
 
34
34
  ```bash
35
35
  pip install infsh
36
36
  ```
37
37
 
38
- ## File Handling
38
+ ## client usage
39
39
 
40
- The `File` class provides a standardized way to handle files in the inference.sh ecosystem:
40
+ ```python
41
+ from infsh import Inference, TaskStatus
42
+
43
+ # create client
44
+ client = Inference(api_key="your-api-key")
45
+
46
+ # simple usage - wait for result
47
+ result = client.run({
48
+ "app": "your-app",
49
+ "input": {"key": "value"},
50
+ "variant": "default"
51
+ })
52
+ print(f"output: {result['output']}")
53
+
54
+ # get task info without waiting
55
+ task = client.run(params, wait=False)
56
+ print(f"task id: {task['id']}")
57
+
58
+ # stream updates (recommended)
59
+ for update in client.run(params, stream=True):
60
+ status = update.get("status")
61
+ print(f"status: {TaskStatus(status).name}")
62
+
63
+ if status == TaskStatus.COMPLETED:
64
+ print(f"output: {update.get('output')}")
65
+ break
66
+ elif status == TaskStatus.FAILED:
67
+ print(f"error: {update.get('error')}")
68
+ break
69
+
70
+ # async support
71
+ async def run_async():
72
+ from infsh import AsyncInference
73
+
74
+ client = AsyncInference(api_key="your-api-key")
75
+
76
+ # simple usage
77
+ result = await client.run(params)
78
+
79
+ # stream updates
80
+ async for update in await client.run(params, stream=True):
81
+ print(f"status: {TaskStatus(update['status']).name}")
82
+ ```
83
+
84
+ ## file handling
85
+
86
+ the `File` class provides a standardized way to handle files in the inference.sh ecosystem:
41
87
 
42
88
  ```python
43
89
  from infsh import File
@@ -68,15 +114,15 @@ print(file.filename) # basename of the file
68
114
  file.refresh_metadata()
69
115
  ```
70
116
 
71
- The `File` class automatically handles:
72
- - MIME type detection
73
- - File size calculation
74
- - Filename extraction from path
75
- - File existence checking
117
+ the `File` class automatically handles:
118
+ - mime type detection
119
+ - file size calculation
120
+ - filename extraction from path
121
+ - file existence checking
76
122
 
77
- ## Creating an App
123
+ ## creating an app
78
124
 
79
- To create an inference app, inherit from `BaseApp` and define your input/output types:
125
+ to create an inference app, inherit from `BaseApp` and define your input/output types:
80
126
 
81
127
  ```python
82
128
  from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
@@ -103,7 +149,7 @@ class MyApp(BaseApp):
103
149
  pass
104
150
  ```
105
151
 
106
- The app lifecycle has three main methods:
107
- - `setup()`: Called when the app starts, use it to initialize models
108
- - `run()`: Called for each inference request
109
- - `unload()`: Called when shutting down, use it to free resources
152
+ app lifecycle has three main methods:
153
+ - `setup()`: called when the app starts, use it to initialize models
154
+ - `run()`: called for each inference request
155
+ - `unload()`: called when shutting down, use it to free resources
@@ -0,0 +1,128 @@
1
+ # inference.sh sdk
2
+
3
+ helper package for inference.sh python applications.
4
+
5
+ ## installation
6
+
7
+ ```bash
8
+ pip install infsh
9
+ ```
10
+
11
+ ## client usage
12
+
13
+ ```python
14
+ from infsh import Inference, TaskStatus
15
+
16
+ # create client
17
+ client = Inference(api_key="your-api-key")
18
+
19
+ # simple usage - wait for result
20
+ result = client.run({
21
+ "app": "your-app",
22
+ "input": {"key": "value"},
23
+ "variant": "default"
24
+ })
25
+ print(f"output: {result['output']}")
26
+
27
+ # get task info without waiting
28
+ task = client.run(params, wait=False)
29
+ print(f"task id: {task['id']}")
30
+
31
+ # stream updates (recommended)
32
+ for update in client.run(params, stream=True):
33
+ status = update.get("status")
34
+ print(f"status: {TaskStatus(status).name}")
35
+
36
+ if status == TaskStatus.COMPLETED:
37
+ print(f"output: {update.get('output')}")
38
+ break
39
+ elif status == TaskStatus.FAILED:
40
+ print(f"error: {update.get('error')}")
41
+ break
42
+
43
+ # async support
44
+ async def run_async():
45
+ from infsh import AsyncInference
46
+
47
+ client = AsyncInference(api_key="your-api-key")
48
+
49
+ # simple usage
50
+ result = await client.run(params)
51
+
52
+ # stream updates
53
+ async for update in await client.run(params, stream=True):
54
+ print(f"status: {TaskStatus(update['status']).name}")
55
+ ```
56
+
57
+ ## file handling
58
+
59
+ the `File` class provides a standardized way to handle files in the inference.sh ecosystem:
60
+
61
+ ```python
62
+ from infsh import File
63
+
64
+ # Basic file creation
65
+ file = File(path="/path/to/file.png")
66
+
67
+ # File with explicit metadata
68
+ file = File(
69
+ path="/path/to/file.png",
70
+ content_type="image/png",
71
+ filename="custom_name.png",
72
+ size=1024 # in bytes
73
+ )
74
+
75
+ # Create from path (automatically populates metadata)
76
+ file = File.from_path("/path/to/file.png")
77
+
78
+ # Check if file exists
79
+ exists = file.exists()
80
+
81
+ # Access file metadata
82
+ print(file.content_type) # automatically detected if not specified
83
+ print(file.size) # file size in bytes
84
+ print(file.filename) # basename of the file
85
+
86
+ # Refresh metadata (useful if file has changed)
87
+ file.refresh_metadata()
88
+ ```
89
+
90
+ the `File` class automatically handles:
91
+ - mime type detection
92
+ - file size calculation
93
+ - filename extraction from path
94
+ - file existence checking
95
+
96
+ ## creating an app
97
+
98
+ to create an inference app, inherit from `BaseApp` and define your input/output types:
99
+
100
+ ```python
101
+ from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
102
+
103
+ class AppInput(BaseAppInput):
104
+ image: str # URL or file path to image
105
+ mask: str # URL or file path to mask
106
+
107
+ class AppOutput(BaseAppOutput):
108
+ image: File
109
+
110
+ class MyApp(BaseApp):
111
+ async def setup(self):
112
+ # Initialize your model here
113
+ pass
114
+
115
+ async def run(self, app_input: AppInput) -> AppOutput:
116
+ # Process input and return output
117
+ result_path = "/tmp/result.png"
118
+ return AppOutput(image=File(path=result_path))
119
+
120
+ async def unload(self):
121
+ # Clean up resources
122
+ pass
123
+ ```
124
+
125
+ app lifecycle has three main methods:
126
+ - `setup()`: called when the app starts, use it to initialize models
127
+ - `run()`: called for each inference request
128
+ - `unload()`: called when shutting down, use it to free resources
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "inferencesh"
7
- version = "0.3.1"
7
+ version = "0.4.1"
8
8
  description = "inference.sh Python SDK"
9
9
  authors = [
10
10
  {name = "Inference Shell Inc.", email = "hello@inference.sh"},
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, Optional, Callable, Generator, Union
3
+ from typing import Any, Dict, Optional, Callable, Generator, Union, Iterator
4
4
  from dataclasses import dataclass
5
5
  from enum import IntEnum
6
6
  import json
@@ -8,6 +8,103 @@ import re
8
8
  import time
9
9
  import mimetypes
10
10
  import os
11
+ from contextlib import AbstractContextManager
12
+ from typing import Protocol, runtime_checkable
13
+
14
+
15
+ class TaskStream(AbstractContextManager['TaskStream']):
16
+ """A context manager for streaming task updates.
17
+
18
+ This class provides a Pythonic interface for handling streaming updates from a task.
19
+ It can be used either as a context manager or as an iterator.
20
+
21
+ Example:
22
+ ```python
23
+ # As a context manager
24
+ with client.stream_task(task_id) as stream:
25
+ for update in stream:
26
+ print(f"Update: {update}")
27
+
28
+ # As an iterator
29
+ for update in client.stream_task(task_id):
30
+ print(f"Update: {update}")
31
+ ```
32
+ """
33
+ def __init__(
34
+ self,
35
+ task: Dict[str, Any],
36
+ client: Any,
37
+ auto_reconnect: bool = True,
38
+ max_reconnects: int = 5,
39
+ reconnect_delay_ms: int = 1000,
40
+ ):
41
+ self.task = task
42
+ self.client = client
43
+ self.task_id = task["id"]
44
+ self.auto_reconnect = auto_reconnect
45
+ self.max_reconnects = max_reconnects
46
+ self.reconnect_delay_ms = reconnect_delay_ms
47
+ self._final_task: Optional[Dict[str, Any]] = None
48
+ self._error: Optional[Exception] = None
49
+
50
+ def __enter__(self) -> 'TaskStream':
51
+ return self
52
+
53
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
54
+ pass
55
+
56
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
57
+ return self.stream()
58
+
59
+ @property
60
+ def result(self) -> Optional[Dict[str, Any]]:
61
+ """The final task result if completed, None otherwise."""
62
+ return self._final_task
63
+
64
+ @property
65
+ def error(self) -> Optional[Exception]:
66
+ """The error that occurred during streaming, if any."""
67
+ return self._error
68
+
69
+ def stream(self) -> Iterator[Dict[str, Any]]:
70
+ """Stream updates for this task.
71
+
72
+ Yields:
73
+ Dict[str, Any]: Task update events
74
+
75
+ Raises:
76
+ RuntimeError: If the task fails or is cancelled
77
+ """
78
+ try:
79
+ for update in self.client._stream_updates(
80
+ self.task_id,
81
+ self.task,
82
+ ):
83
+ if isinstance(update, Exception):
84
+ self._error = update
85
+ raise update
86
+ if update.get("status") == TaskStatus.COMPLETED:
87
+ self._final_task = update
88
+ yield update
89
+ except Exception as exc:
90
+ self._error = exc
91
+ raise
92
+
93
+
94
+ @runtime_checkable
95
+ class TaskCallback(Protocol):
96
+ """Protocol for task streaming callbacks."""
97
+ def on_update(self, data: Dict[str, Any]) -> None:
98
+ """Called when a task update is received."""
99
+ ...
100
+
101
+ def on_error(self, error: Exception) -> None:
102
+ """Called when an error occurs during task execution."""
103
+ ...
104
+
105
+ def on_complete(self, task: Dict[str, Any]) -> None:
106
+ """Called when a task completes successfully."""
107
+ ...
11
108
 
12
109
 
13
110
  # Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
@@ -228,122 +325,116 @@ class Inference:
228
325
  return payload.get("data")
229
326
 
230
327
  # --------------- Public API ---------------
231
- def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
232
- processed_input = self._process_input_data(params.get("input"))
233
- task = self._request("post", "/run", data={**params, "input": processed_input})
234
- return task
235
-
236
- def run_sync(
328
+ def run(
237
329
  self,
238
330
  params: Dict[str, Any],
239
331
  *,
332
+ wait: bool = True,
333
+ stream: bool = False,
240
334
  auto_reconnect: bool = True,
241
335
  max_reconnects: int = 5,
242
336
  reconnect_delay_ms: int = 1000,
243
- ) -> Dict[str, Any]:
337
+ ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
338
+ """Run a task with optional streaming updates.
339
+
340
+ By default, this method waits for the task to complete and returns the final result.
341
+ You can set wait=False to get just the task info, or stream=True to get an iterator
342
+ of status updates.
343
+
344
+ Args:
345
+ params: Task parameters to pass to the API
346
+ wait: Whether to wait for task completion (default: True)
347
+ stream: Whether to return an iterator of updates (default: False)
348
+ auto_reconnect: Whether to automatically reconnect on connection loss
349
+ max_reconnects: Maximum number of reconnection attempts
350
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
351
+
352
+ Returns:
353
+ Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
354
+ - If wait=True and stream=False: The completed task data
355
+ - If wait=False: The created task info
356
+ - If stream=True: An iterator of task updates
357
+
358
+ Example:
359
+ ```python
360
+ # Simple usage - wait for result (default)
361
+ result = client.run(params)
362
+ print(f"Output: {result['output']}")
363
+
364
+ # Get task info without waiting
365
+ task = client.run(params, wait=False)
366
+ task_id = task["id"]
367
+
368
+ # Stream updates
369
+ for update in client.run(params, stream=True):
370
+ print(f"Status: {update.get('status')}")
371
+ if update.get('status') == TaskStatus.COMPLETED:
372
+ print(f"Result: {update.get('output')}")
373
+ ```
374
+ """
375
+ # Create the task
244
376
  processed_input = self._process_input_data(params.get("input"))
245
377
  task = self._request("post", "/run", data={**params, "input": processed_input})
246
- task_id = task["id"]
247
-
248
- final_task: Optional[Dict[str, Any]] = None
249
-
250
- def on_data(data: Dict[str, Any]) -> None:
251
- nonlocal final_task
252
- try:
253
- result = _process_stream_event(
254
- data,
255
- task=task,
256
- stopper=lambda: manager.stop(),
257
- )
258
- if result is not None:
259
- final_task = result
260
- except Exception as exc:
261
- raise
262
-
263
- def on_error(exc: Exception) -> None:
264
- raise exc
265
-
266
- def on_start() -> None:
267
- pass
268
-
269
- def on_stop() -> None:
270
- pass
271
-
272
- manager = StreamManager(
273
- create_event_source=None, # We'll set this after defining it
274
- auto_reconnect=auto_reconnect,
275
- max_reconnects=max_reconnects,
276
- reconnect_delay_ms=reconnect_delay_ms,
277
- on_data=on_data,
278
- on_error=on_error,
279
- on_start=on_start,
280
- on_stop=on_stop,
281
- )
282
-
283
- def create_event_source() -> Generator[Dict[str, Any], None, None]:
284
- url = f"/tasks/{task_id}/stream"
285
- resp = self._request(
286
- "get",
287
- url,
288
- headers={
289
- "Accept": "text/event-stream",
290
- "Cache-Control": "no-cache",
291
- "Accept-Encoding": "identity",
292
- "Connection": "keep-alive",
293
- },
294
- stream=True,
295
- timeout=60,
378
+
379
+ # Return immediately if not waiting
380
+ if not wait and not stream:
381
+ return _strip_task(task)
382
+
383
+ # Return stream if requested
384
+ if stream:
385
+ task_stream = TaskStream(
386
+ task=task,
387
+ client=self,
388
+ auto_reconnect=auto_reconnect,
389
+ max_reconnects=max_reconnects,
390
+ reconnect_delay_ms=reconnect_delay_ms,
296
391
  )
392
+ return task_stream
297
393
 
298
- try:
299
- last_event_at = time.perf_counter()
300
- for evt in self._iter_sse(resp, stream_manager=manager):
301
- yield evt
302
- finally:
303
- try:
304
- # Force close the underlying socket if possible
305
- try:
306
- raw = getattr(resp, 'raw', None)
307
- if raw is not None:
308
- raw.close()
309
- except Exception:
310
- raise
311
- # Close the response
312
- resp.close()
313
- except Exception:
314
- raise
315
-
316
- # Update the create_event_source function in the manager
317
- manager._create_event_source = create_event_source
394
+ # Otherwise wait for completion
395
+ return self.wait_for_completion(task["id"])
318
396
 
319
- # Connect and wait for completion
320
- manager.connect()
321
397
 
322
- # At this point, we should have a final task state
323
- if final_task is not None:
324
- return final_task
325
-
326
- # Try to fetch the latest state as a fallback
327
- try:
328
- latest = self.get_task(task_id)
329
- status = latest.get("status")
330
- if status == TaskStatus.COMPLETED:
331
- return latest
332
- if status == TaskStatus.FAILED:
333
- raise RuntimeError(latest.get("error") or "task failed")
334
- if status == TaskStatus.CANCELLED:
335
- raise RuntimeError("task cancelled")
336
- except Exception as exc:
337
- raise
338
-
339
- raise RuntimeError("Stream ended without completion")
340
398
 
341
399
  def cancel(self, task_id: str) -> None:
342
400
  self._request("post", f"/tasks/{task_id}/cancel")
343
401
 
344
402
  def get_task(self, task_id: str) -> Dict[str, Any]:
403
+ """Get the current state of a task.
404
+
405
+ Args:
406
+ task_id: The ID of the task to get
407
+
408
+ Returns:
409
+ Dict[str, Any]: The current task state
410
+ """
345
411
  return self._request("get", f"/tasks/{task_id}")
346
412
 
413
+ def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
414
+ """Wait for a task to complete and return its final state.
415
+
416
+ This method polls the task status until it reaches a terminal state
417
+ (completed, failed, or cancelled).
418
+
419
+ Args:
420
+ task_id: The ID of the task to wait for
421
+
422
+ Returns:
423
+ Dict[str, Any]: The final task state
424
+
425
+ Raises:
426
+ RuntimeError: If the task fails or is cancelled
427
+ """
428
+ with self.stream_task(task_id) as stream:
429
+ for update in stream:
430
+ if update.get("status") == TaskStatus.COMPLETED:
431
+ return update
432
+ elif update.get("status") == TaskStatus.FAILED:
433
+ raise RuntimeError(update.get("error") or "Task failed")
434
+ elif update.get("status") == TaskStatus.CANCELLED:
435
+ raise RuntimeError("Task cancelled")
436
+ raise RuntimeError("Stream ended without completion")
437
+
347
438
  # --------------- File upload ---------------
348
439
  def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
349
440
  options = options or UploadFileOptions()
@@ -403,6 +494,103 @@ class Inference:
403
494
  return file_obj
404
495
 
405
496
  # --------------- Helpers ---------------
497
+ def stream_task(
498
+ self,
499
+ task_id: str,
500
+ *,
501
+ auto_reconnect: bool = True,
502
+ max_reconnects: int = 5,
503
+ reconnect_delay_ms: int = 1000,
504
+ ) -> TaskStream:
505
+ """Create a TaskStream for getting streaming updates from a task.
506
+
507
+ This provides a more Pythonic interface for handling task updates compared to callbacks.
508
+ The returned TaskStream can be used either as a context manager or as an iterator.
509
+
510
+ Args:
511
+ task_id: The ID of the task to stream
512
+ auto_reconnect: Whether to automatically reconnect on connection loss
513
+ max_reconnects: Maximum number of reconnection attempts
514
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
515
+
516
+ Returns:
517
+ TaskStream: A stream interface for the task
518
+
519
+ Example:
520
+ ```python
521
+ # Run a task
522
+ task = client.run(params)
523
+
524
+ # Stream updates using context manager
525
+ with client.stream_task(task["id"]) as stream:
526
+ for update in stream:
527
+ print(f"Status: {update.get('status')}")
528
+ if update.get("status") == TaskStatus.COMPLETED:
529
+ print(f"Result: {update.get('output')}")
530
+
531
+ # Or use as a simple iterator
532
+ for update in client.stream_task(task["id"]):
533
+ print(f"Update: {update}")
534
+ ```
535
+ """
536
+ task = self.get_task(task_id)
537
+ return TaskStream(
538
+ task=task,
539
+ client=self,
540
+ auto_reconnect=auto_reconnect,
541
+ max_reconnects=max_reconnects,
542
+ reconnect_delay_ms=reconnect_delay_ms,
543
+ )
544
+
545
+ def _stream_updates(
546
+ self,
547
+ task_id: str,
548
+ task: Dict[str, Any],
549
+ ) -> Generator[Union[Dict[str, Any], Exception], None, None]:
550
+ """Internal method to stream task updates."""
551
+ url = f"/tasks/{task_id}/stream"
552
+ resp = self._request(
553
+ "get",
554
+ url,
555
+ headers={
556
+ "Accept": "text/event-stream",
557
+ "Cache-Control": "no-cache",
558
+ "Accept-Encoding": "identity",
559
+ "Connection": "keep-alive",
560
+ },
561
+ stream=True,
562
+ timeout=60,
563
+ )
564
+ try:
565
+ for evt in self._iter_sse(resp):
566
+ try:
567
+ # Process the event to check for completion/errors
568
+ result = _process_stream_event(
569
+ evt,
570
+ task=task,
571
+ stopper=None, # We'll handle stopping via the iterator
572
+ )
573
+ if result is not None:
574
+ yield result
575
+ break
576
+ yield _strip_task(evt)
577
+ except Exception as exc:
578
+ yield exc
579
+ raise
580
+ finally:
581
+ try:
582
+ # Force close the underlying socket if possible
583
+ try:
584
+ raw = getattr(resp, 'raw', None)
585
+ if raw is not None:
586
+ raw.close()
587
+ except Exception:
588
+ raise
589
+ # Close the response
590
+ resp.close()
591
+ except Exception:
592
+ raise
593
+
406
594
  def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
407
595
  """Iterate JSON events from an SSE response."""
408
596
  # Mode 1: raw socket readline (can reduce buffering in some environments)
@@ -565,88 +753,114 @@ class AsyncInference:
565
753
  return payload.get("data")
566
754
 
567
755
  # --------------- Public API ---------------
568
- async def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
569
- processed_input = await self._process_input_data(params.get("input"))
570
- task = await self._request("post", "/run", data={**params, "input": processed_input})
571
- return task
572
-
573
- async def run_sync(
756
+ async def run(
574
757
  self,
575
758
  params: Dict[str, Any],
576
759
  *,
760
+ wait: bool = True,
761
+ stream: bool = False,
577
762
  auto_reconnect: bool = True,
578
763
  max_reconnects: int = 5,
579
764
  reconnect_delay_ms: int = 1000,
580
- ) -> Dict[str, Any]:
765
+ ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
766
+ """Run a task with optional streaming updates.
767
+
768
+ By default, this method waits for the task to complete and returns the final result.
769
+ You can set wait=False to get just the task info, or stream=True to get an iterator
770
+ of status updates.
771
+
772
+ Args:
773
+ params: Task parameters to pass to the API
774
+ wait: Whether to wait for task completion (default: True)
775
+ stream: Whether to return an iterator of updates (default: False)
776
+ auto_reconnect: Whether to automatically reconnect on connection loss
777
+ max_reconnects: Maximum number of reconnection attempts
778
+ reconnect_delay_ms: Delay between reconnection attempts in milliseconds
779
+
780
+ Returns:
781
+ Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
782
+ - If wait=True and stream=False: The completed task data
783
+ - If wait=False: The created task info
784
+ - If stream=True: An iterator of task updates
785
+
786
+ Example:
787
+ ```python
788
+ # Simple usage - wait for result (default)
789
+ result = await client.run(params)
790
+ print(f"Output: {result['output']}")
791
+
792
+ # Get task info without waiting
793
+ task = await client.run(params, wait=False)
794
+ task_id = task["id"]
795
+
796
+ # Stream updates
797
+ async for update in await client.run(params, stream=True):
798
+ print(f"Status: {update.get('status')}")
799
+ if update.get('status') == TaskStatus.COMPLETED:
800
+ print(f"Result: {update.get('output')}")
801
+ ```
802
+ """
803
+ # Create the task
581
804
  processed_input = await self._process_input_data(params.get("input"))
582
805
  task = await self._request("post", "/run", data={**params, "input": processed_input})
583
- task_id = task["id"]
584
-
585
- final_task: Optional[Dict[str, Any]] = None
586
- reconnect_attempts = 0
587
- had_success = False
588
-
589
- while True:
590
- try:
591
- resp = await self._request(
592
- "get",
593
- f"/tasks/{task_id}/stream",
594
- headers={
595
- "Accept": "text/event-stream",
596
- "Cache-Control": "no-cache",
597
- "Accept-Encoding": "identity",
598
- "Connection": "keep-alive",
599
- },
600
- timeout=60,
601
- expect_stream=True,
602
- )
603
- had_success = True
604
- async for data in self._aiter_sse(resp):
605
- result = _process_stream_event(
606
- data,
607
- task=task,
608
- stopper=None,
609
- )
610
- if result is not None:
611
- final_task = result
612
- break
613
- if final_task is not None:
614
- break
615
- except Exception as exc: # noqa: BLE001
616
- if not auto_reconnect:
617
- raise
618
- if not had_success:
619
- reconnect_attempts += 1
620
- if reconnect_attempts > max_reconnects:
621
- raise
622
- await _async_sleep(reconnect_delay_ms / 1000.0)
623
- else:
624
- if not auto_reconnect:
625
- break
626
- await _async_sleep(reconnect_delay_ms / 1000.0)
627
-
628
- if final_task is None:
629
- # Fallback: fetch latest task state in case stream ended without a terminal event
630
- try:
631
- latest = await self.get_task(task_id)
632
- status = latest.get("status")
633
- if status == TaskStatus.COMPLETED:
634
- return latest
635
- if status == TaskStatus.FAILED:
636
- raise RuntimeError(latest.get("error") or "task failed")
637
- if status == TaskStatus.CANCELLED:
638
- raise RuntimeError("task cancelled")
639
- except Exception:
640
- raise
641
- raise RuntimeError("Stream ended without completion")
642
- return final_task
806
+
807
+ # Return immediately if not waiting
808
+ if not wait and not stream:
809
+ return task
810
+
811
+ # Return stream if requested
812
+ if stream:
813
+ task_stream = TaskStream(
814
+ task=task,
815
+ client=self,
816
+ auto_reconnect=auto_reconnect,
817
+ max_reconnects=max_reconnects,
818
+ reconnect_delay_ms=reconnect_delay_ms,
819
+ )
820
+ return task_stream
821
+
822
+ # Otherwise wait for completion
823
+ return await self.wait_for_completion(task["id"])
643
824
 
644
825
  async def cancel(self, task_id: str) -> None:
645
826
  await self._request("post", f"/tasks/{task_id}/cancel")
646
827
 
647
828
  async def get_task(self, task_id: str) -> Dict[str, Any]:
829
+ """Get the current state of a task.
830
+
831
+ Args:
832
+ task_id: The ID of the task to get
833
+
834
+ Returns:
835
+ Dict[str, Any]: The current task state
836
+ """
648
837
  return await self._request("get", f"/tasks/{task_id}")
649
838
 
839
+ async def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
840
+ """Wait for a task to complete and return its final state.
841
+
842
+ This method polls the task status until it reaches a terminal state
843
+ (completed, failed, or cancelled).
844
+
845
+ Args:
846
+ task_id: The ID of the task to wait for
847
+
848
+ Returns:
849
+ Dict[str, Any]: The final task state
850
+
851
+ Raises:
852
+ RuntimeError: If the task fails or is cancelled
853
+ """
854
+ with self.stream_task(task_id) as stream:
855
+ async for update in stream:
856
+ if update.get("status") == TaskStatus.COMPLETED:
857
+ return update
858
+ elif update.get("status") == TaskStatus.FAILED:
859
+ raise RuntimeError(update.get("error") or "Task failed")
860
+ elif update.get("status") == TaskStatus.CANCELLED:
861
+ raise RuntimeError("Task cancelled")
862
+ raise RuntimeError("Stream ended without completion")
863
+
650
864
  # --------------- File upload ---------------
651
865
  async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
652
866
  options = options or UploadFileOptions()
@@ -797,6 +1011,18 @@ def _looks_like_base64(value: str) -> bool:
797
1011
  return False
798
1012
 
799
1013
 
1014
+ def _strip_task(task: Dict[str, Any]) -> Dict[str, Any]:
1015
+ """Strip task to essential fields."""
1016
+ return {
1017
+ "id": task.get("id"),
1018
+ "created_at": task.get("created_at"),
1019
+ "updated_at": task.get("updated_at"),
1020
+ "input": task.get("input"),
1021
+ "output": task.get("output"),
1022
+ "logs": task.get("logs"),
1023
+ "status": task.get("status"),
1024
+ }
1025
+
800
1026
  def _process_stream_event(
801
1027
  data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
802
1028
  ) -> Optional[Dict[str, Any]]:
@@ -804,16 +1030,9 @@ def _process_stream_event(
804
1030
  If stopper is provided, it will be called on terminal events to end streaming.
805
1031
  """
806
1032
  status = data.get("status")
807
- output = data.get("output")
808
- logs = data.get("logs")
809
1033
 
810
1034
  if status == TaskStatus.COMPLETED:
811
- result = {
812
- **task,
813
- "status": data.get("status"),
814
- "output": data.get("output"),
815
- "logs": data.get("logs") or [],
816
- }
1035
+ result = _strip_task(data)
817
1036
  if stopper:
818
1037
  stopper()
819
1038
  return result
@@ -5,11 +5,45 @@ import os
5
5
  import urllib.request
6
6
  import urllib.parse
7
7
  import tempfile
8
+ import hashlib
9
+ from pathlib import Path
8
10
  from tqdm import tqdm
9
11
 
10
12
 
11
13
  class File(BaseModel):
12
14
  """A class representing a file in the inference.sh ecosystem."""
15
+
16
+ @classmethod
17
+ def get_cache_dir(cls) -> Path:
18
+ """Get the cache directory path based on environment variables or default location."""
19
+ if cache_dir := os.environ.get("FILE_CACHE_DIR"):
20
+ path = Path(cache_dir)
21
+ else:
22
+ path = Path.home() / ".cache" / "inferencesh" / "files"
23
+ path.mkdir(parents=True, exist_ok=True)
24
+ return path
25
+
26
+ def _get_cache_path(self, url: str) -> Path:
27
+ """Get the cache path for a URL using a hash-based directory structure."""
28
+ # Parse URL components
29
+ parsed_url = urllib.parse.urlparse(url)
30
+
31
+ # Create hash from URL path and query parameters for uniqueness
32
+ url_components = parsed_url.netloc + parsed_url.path
33
+ if parsed_url.query:
34
+ url_components += '?' + parsed_url.query
35
+ url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12]
36
+
37
+ # Get filename from URL or use default
38
+ filename = os.path.basename(parsed_url.path)
39
+ if not filename:
40
+ filename = 'download'
41
+
42
+ # Create hash directory in cache
43
+ cache_dir = self.get_cache_dir() / url_hash
44
+ cache_dir.mkdir(exist_ok=True)
45
+
46
+ return cache_dir / filename
13
47
  uri: Optional[str] = Field(default=None) # Original location (URL or file path)
14
48
  path: Optional[str] = None # Resolved local file path
15
49
  content_type: Optional[str] = None # MIME type of the file
@@ -74,11 +108,20 @@ class File(BaseModel):
74
108
  return parsed.scheme in ('http', 'https')
75
109
 
76
110
  def _download_url(self) -> None:
77
- """Download the URL to a temporary file and update the path."""
111
+ """Download the URL to the cache directory and update the path."""
78
112
  original_url = self.uri
113
+ cache_path = self._get_cache_path(original_url)
114
+
115
+ # If file exists in cache, use it
116
+ if cache_path.exists():
117
+ print(f"Using cached file: {cache_path}")
118
+ self.path = str(cache_path)
119
+ return
120
+
121
+ print(f"Downloading URL: {original_url} to {cache_path}")
79
122
  tmp_file = None
80
123
  try:
81
- # Create a temporary file with a suffix based on the URL path
124
+ # Download to temporary file first to avoid partial downloads in cache
82
125
  suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
83
126
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
84
127
  self._tmp_path = tmp_file.name
@@ -133,7 +176,10 @@ class File(BaseModel):
133
176
  # If we read the whole body at once, exit loop
134
177
  break
135
178
 
136
- self.path = self._tmp_path
179
+ # Move the temporary file to the cache location
180
+ os.replace(self._tmp_path, cache_path)
181
+ self._tmp_path = None # Prevent deletion in __del__
182
+ self.path = str(cache_path)
137
183
  except (urllib.error.URLError, urllib.error.HTTPError) as e:
138
184
  raise RuntimeError(f"Failed to download URL {original_url}: {str(e)}")
139
185
  except IOError as e:
@@ -24,16 +24,24 @@ def download(url: str, directory: Union[str, Path, StorageDir]) -> str:
24
24
  dir_path = Path(directory)
25
25
  dir_path.mkdir(exist_ok=True)
26
26
 
27
- # Create hash directory from URL
28
- url_hash = hashlib.sha256(url.encode()).hexdigest()[:12]
29
- hash_dir = dir_path / url_hash
30
- hash_dir.mkdir(exist_ok=True)
27
+ # Parse URL components
28
+ parsed_url = urllib.parse.urlparse(url)
31
29
 
32
- # Keep original filename
33
- filename = os.path.basename(urllib.parse.urlparse(url).path)
30
+ # Create hash from URL path and query parameters for uniqueness
31
+ url_components = parsed_url.netloc + parsed_url.path
32
+ if parsed_url.query:
33
+ url_components += '?' + parsed_url.query
34
+ url_hash = hashlib.sha256(url_components.encode()).hexdigest()[:12]
35
+
36
+ # Keep original filename or use a default
37
+ filename = os.path.basename(parsed_url.path)
34
38
  if not filename:
35
39
  filename = 'download'
36
-
40
+
41
+ # Create hash directory and store file
42
+ hash_dir = dir_path / url_hash
43
+ hash_dir.mkdir(exist_ok=True)
44
+
37
45
  output_path = hash_dir / filename
38
46
 
39
47
  # If file exists in directory and it's not a temp directory, return it
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: inferencesh
3
- Version: 0.3.1
3
+ Version: 0.4.1
4
4
  Summary: inference.sh Python SDK
5
5
  Author: Inference Shell Inc.
6
6
  Author-email: "Inference Shell Inc." <hello@inference.sh>
@@ -25,19 +25,65 @@ Dynamic: author
25
25
  Dynamic: license-file
26
26
  Dynamic: requires-python
27
27
 
28
- # inference.sh CLI
28
+ # inference.sh sdk
29
29
 
30
- Helper package for inference.sh Python applications.
30
+ helper package for inference.sh python applications.
31
31
 
32
- ## Installation
32
+ ## installation
33
33
 
34
34
  ```bash
35
35
  pip install infsh
36
36
  ```
37
37
 
38
- ## File Handling
38
+ ## client usage
39
39
 
40
- The `File` class provides a standardized way to handle files in the inference.sh ecosystem:
40
+ ```python
41
+ from infsh import Inference, TaskStatus
42
+
43
+ # create client
44
+ client = Inference(api_key="your-api-key")
45
+
46
+ # simple usage - wait for result
47
+ result = client.run({
48
+ "app": "your-app",
49
+ "input": {"key": "value"},
50
+ "variant": "default"
51
+ })
52
+ print(f"output: {result['output']}")
53
+
54
+ # get task info without waiting
55
+ task = client.run(params, wait=False)
56
+ print(f"task id: {task['id']}")
57
+
58
+ # stream updates (recommended)
59
+ for update in client.run(params, stream=True):
60
+ status = update.get("status")
61
+ print(f"status: {TaskStatus(status).name}")
62
+
63
+ if status == TaskStatus.COMPLETED:
64
+ print(f"output: {update.get('output')}")
65
+ break
66
+ elif status == TaskStatus.FAILED:
67
+ print(f"error: {update.get('error')}")
68
+ break
69
+
70
+ # async support
71
+ async def run_async():
72
+ from infsh import AsyncInference
73
+
74
+ client = AsyncInference(api_key="your-api-key")
75
+
76
+ # simple usage
77
+ result = await client.run(params)
78
+
79
+ # stream updates
80
+ async for update in await client.run(params, stream=True):
81
+ print(f"status: {TaskStatus(update['status']).name}")
82
+ ```
83
+
84
+ ## file handling
85
+
86
+ the `File` class provides a standardized way to handle files in the inference.sh ecosystem:
41
87
 
42
88
  ```python
43
89
  from infsh import File
@@ -68,15 +114,15 @@ print(file.filename) # basename of the file
68
114
  file.refresh_metadata()
69
115
  ```
70
116
 
71
- The `File` class automatically handles:
72
- - MIME type detection
73
- - File size calculation
74
- - Filename extraction from path
75
- - File existence checking
117
+ the `File` class automatically handles:
118
+ - mime type detection
119
+ - file size calculation
120
+ - filename extraction from path
121
+ - file existence checking
76
122
 
77
- ## Creating an App
123
+ ## creating an app
78
124
 
79
- To create an inference app, inherit from `BaseApp` and define your input/output types:
125
+ to create an inference app, inherit from `BaseApp` and define your input/output types:
80
126
 
81
127
  ```python
82
128
  from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
@@ -103,7 +149,7 @@ class MyApp(BaseApp):
103
149
  pass
104
150
  ```
105
151
 
106
- The app lifecycle has three main methods:
107
- - `setup()`: Called when the app starts, use it to initialize models
108
- - `run()`: Called for each inference request
109
- - `unload()`: Called when shutting down, use it to free resources
152
+ app lifecycle has three main methods:
153
+ - `setup()`: called when the app starts, use it to initialize models
154
+ - `run()`: called for each inference request
155
+ - `unload()`: called when shutting down, use it to free resources
@@ -1,82 +0,0 @@
1
- # inference.sh CLI
2
-
3
- Helper package for inference.sh Python applications.
4
-
5
- ## Installation
6
-
7
- ```bash
8
- pip install infsh
9
- ```
10
-
11
- ## File Handling
12
-
13
- The `File` class provides a standardized way to handle files in the inference.sh ecosystem:
14
-
15
- ```python
16
- from infsh import File
17
-
18
- # Basic file creation
19
- file = File(path="/path/to/file.png")
20
-
21
- # File with explicit metadata
22
- file = File(
23
- path="/path/to/file.png",
24
- content_type="image/png",
25
- filename="custom_name.png",
26
- size=1024 # in bytes
27
- )
28
-
29
- # Create from path (automatically populates metadata)
30
- file = File.from_path("/path/to/file.png")
31
-
32
- # Check if file exists
33
- exists = file.exists()
34
-
35
- # Access file metadata
36
- print(file.content_type) # automatically detected if not specified
37
- print(file.size) # file size in bytes
38
- print(file.filename) # basename of the file
39
-
40
- # Refresh metadata (useful if file has changed)
41
- file.refresh_metadata()
42
- ```
43
-
44
- The `File` class automatically handles:
45
- - MIME type detection
46
- - File size calculation
47
- - Filename extraction from path
48
- - File existence checking
49
-
50
- ## Creating an App
51
-
52
- To create an inference app, inherit from `BaseApp` and define your input/output types:
53
-
54
- ```python
55
- from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
56
-
57
- class AppInput(BaseAppInput):
58
- image: str # URL or file path to image
59
- mask: str # URL or file path to mask
60
-
61
- class AppOutput(BaseAppOutput):
62
- image: File
63
-
64
- class MyApp(BaseApp):
65
- async def setup(self):
66
- # Initialize your model here
67
- pass
68
-
69
- async def run(self, app_input: AppInput) -> AppOutput:
70
- # Process input and return output
71
- result_path = "/tmp/result.png"
72
- return AppOutput(image=File(path=result_path))
73
-
74
- async def unload(self):
75
- # Clean up resources
76
- pass
77
- ```
78
-
79
- The app lifecycle has three main methods:
80
- - `setup()`: Called when the app starts, use it to initialize models
81
- - `run()`: Called for each inference request
82
- - `unload()`: Called when shutting down, use it to free resources
File without changes
File without changes
File without changes