inferencesh 0.4.0-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of inferencesh has been flagged as potentially problematic.

inferencesh/client.py CHANGED
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from typing import Any, Dict, Optional, Callable, Generator, Union
+ from typing import Any, Dict, Optional, Callable, Generator, Union, Iterator
  from dataclasses import dataclass
  from enum import IntEnum
  import json
@@ -8,6 +8,103 @@ import re
  import time
  import mimetypes
  import os
+ from contextlib import AbstractContextManager
+ from typing import Protocol, runtime_checkable
+
+
+ class TaskStream(AbstractContextManager['TaskStream']):
+     """A context manager for streaming task updates.
+
+     This class provides a Pythonic interface for handling streaming updates from a task.
+     It can be used either as a context manager or as an iterator.
+
+     Example:
+         ```python
+         # As a context manager
+         with client.stream_task(task_id) as stream:
+             for update in stream:
+                 print(f"Update: {update}")
+
+         # As an iterator
+         for update in client.stream_task(task_id):
+             print(f"Update: {update}")
+         ```
+     """
+     def __init__(
+         self,
+         task: Dict[str, Any],
+         client: Any,
+         auto_reconnect: bool = True,
+         max_reconnects: int = 5,
+         reconnect_delay_ms: int = 1000,
+     ):
+         self.task = task
+         self.client = client
+         self.task_id = task["id"]
+         self.auto_reconnect = auto_reconnect
+         self.max_reconnects = max_reconnects
+         self.reconnect_delay_ms = reconnect_delay_ms
+         self._final_task: Optional[Dict[str, Any]] = None
+         self._error: Optional[Exception] = None
+
+     def __enter__(self) -> 'TaskStream':
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+         pass
+
+     def __iter__(self) -> Iterator[Dict[str, Any]]:
+         return self.stream()
+
+     @property
+     def result(self) -> Optional[Dict[str, Any]]:
+         """The final task result if completed, None otherwise."""
+         return self._final_task
+
+     @property
+     def error(self) -> Optional[Exception]:
+         """The error that occurred during streaming, if any."""
+         return self._error
+
+     def stream(self) -> Iterator[Dict[str, Any]]:
+         """Stream updates for this task.
+
+         Yields:
+             Dict[str, Any]: Task update events
+
+         Raises:
+             RuntimeError: If the task fails or is cancelled
+         """
+         try:
+             for update in self.client._stream_updates(
+                 self.task_id,
+                 self.task,
+             ):
+                 if isinstance(update, Exception):
+                     self._error = update
+                     raise update
+                 if update.get("status") == TaskStatus.COMPLETED:
+                     self._final_task = update
+                 yield update
+         except Exception as exc:
+             self._error = exc
+             raise
+
+
+ @runtime_checkable
+ class TaskCallback(Protocol):
+     """Protocol for task streaming callbacks."""
+     def on_update(self, data: Dict[str, Any]) -> None:
+         """Called when a task update is received."""
+         ...
+
+     def on_error(self, error: Exception) -> None:
+         """Called when an error occurs during task execution."""
+         ...
+
+     def on_complete(self, task: Dict[str, Any]) -> None:
+         """Called when a task completes successfully."""
+         ...


  # Deliberately do lazy imports for requests/aiohttp to avoid hard dependency at import time
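The new `TaskCallback` protocol is declared `@runtime_checkable` but no conforming implementation appears anywhere in this diff. A minimal sketch of one (the `PrintingCallback` name is ours, and the `inferencesh.client` import path is an assumption based on the RECORD listing below):

```python
from typing import Any, Dict

from inferencesh.client import TaskCallback  # assumed import path

class PrintingCallback:
    """Satisfies TaskCallback structurally; Protocols need no inheritance."""

    def on_update(self, data: Dict[str, Any]) -> None:
        print(f"update: {data.get('status')}")

    def on_error(self, error: Exception) -> None:
        print(f"error: {error!r}")

    def on_complete(self, task: Dict[str, Any]) -> None:
        print(f"complete: {task.get('output')}")

# Because TaskCallback is @runtime_checkable, isinstance() verifies the
# method names at runtime (signatures are not checked).
assert isinstance(PrintingCallback(), TaskCallback)
```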
@@ -228,122 +325,116 @@ class Inference:
          return payload.get("data")

      # --------------- Public API ---------------
-     def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
-         processed_input = self._process_input_data(params.get("input"))
-         task = self._request("post", "/run", data={**params, "input": processed_input})
-         return task
-
-     def run_sync(
+     def run(
          self,
          params: Dict[str, Any],
          *,
+         wait: bool = True,
+         stream: bool = False,
          auto_reconnect: bool = True,
          max_reconnects: int = 5,
          reconnect_delay_ms: int = 1000,
-     ) -> Dict[str, Any]:
+     ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
+         """Run a task with optional streaming updates.
+
+         By default, this method waits for the task to complete and returns the final result.
+         You can set wait=False to get just the task info, or stream=True to get an iterator
+         of status updates.
+
+         Args:
+             params: Task parameters to pass to the API
+             wait: Whether to wait for task completion (default: True)
+             stream: Whether to return an iterator of updates (default: False)
+             auto_reconnect: Whether to automatically reconnect on connection loss
+             max_reconnects: Maximum number of reconnection attempts
+             reconnect_delay_ms: Delay between reconnection attempts in milliseconds
+
+         Returns:
+             Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
+                 - If wait=True and stream=False: The completed task data
+                 - If wait=False: The created task info
+                 - If stream=True: An iterator of task updates
+
+         Example:
+             ```python
+             # Simple usage - wait for result (default)
+             result = client.run(params)
+             print(f"Output: {result['output']}")
+
+             # Get task info without waiting
+             task = client.run(params, wait=False)
+             task_id = task["id"]
+
+             # Stream updates
+             for update in client.run(params, stream=True):
+                 print(f"Status: {update.get('status')}")
+                 if update.get('status') == TaskStatus.COMPLETED:
+                     print(f"Result: {update.get('output')}")
+             ```
+         """
+         # Create the task
          processed_input = self._process_input_data(params.get("input"))
          task = self._request("post", "/run", data={**params, "input": processed_input})
-         task_id = task["id"]
-
-         final_task: Optional[Dict[str, Any]] = None
-
-         def on_data(data: Dict[str, Any]) -> None:
-             nonlocal final_task
-             try:
-                 result = _process_stream_event(
-                     data,
-                     task=task,
-                     stopper=lambda: manager.stop(),
-                 )
-                 if result is not None:
-                     final_task = result
-             except Exception as exc:
-                 raise
-
-         def on_error(exc: Exception) -> None:
-             raise exc
-
-         def on_start() -> None:
-             pass
-
-         def on_stop() -> None:
-             pass
-
-         manager = StreamManager(
-             create_event_source=None,  # We'll set this after defining it
-             auto_reconnect=auto_reconnect,
-             max_reconnects=max_reconnects,
-             reconnect_delay_ms=reconnect_delay_ms,
-             on_data=on_data,
-             on_error=on_error,
-             on_start=on_start,
-             on_stop=on_stop,
-         )
-
-         def create_event_source() -> Generator[Dict[str, Any], None, None]:
-             url = f"/tasks/{task_id}/stream"
-             resp = self._request(
-                 "get",
-                 url,
-                 headers={
-                     "Accept": "text/event-stream",
-                     "Cache-Control": "no-cache",
-                     "Accept-Encoding": "identity",
-                     "Connection": "keep-alive",
-                 },
-                 stream=True,
-                 timeout=60,
+
+         # Return immediately if not waiting
+         if not wait and not stream:
+             return _strip_task(task)
+
+         # Return stream if requested
+         if stream:
+             task_stream = TaskStream(
+                 task=task,
+                 client=self,
+                 auto_reconnect=auto_reconnect,
+                 max_reconnects=max_reconnects,
+                 reconnect_delay_ms=reconnect_delay_ms,
              )
+             return task_stream

-             try:
-                 last_event_at = time.perf_counter()
-                 for evt in self._iter_sse(resp, stream_manager=manager):
-                     yield evt
-             finally:
-                 try:
-                     # Force close the underlying socket if possible
-                     try:
-                         raw = getattr(resp, 'raw', None)
-                         if raw is not None:
-                             raw.close()
-                     except Exception:
-                         raise
-                     # Close the response
-                     resp.close()
-                 except Exception:
-                     raise
-
-         # Update the create_event_source function in the manager
-         manager._create_event_source = create_event_source
+         # Otherwise wait for completion
+         return self.wait_for_completion(task["id"])

-         # Connect and wait for completion
-         manager.connect()

-         # At this point, we should have a final task state
-         if final_task is not None:
-             return final_task
-
-         # Try to fetch the latest state as a fallback
-         try:
-             latest = self.get_task(task_id)
-             status = latest.get("status")
-             if status == TaskStatus.COMPLETED:
-                 return latest
-             if status == TaskStatus.FAILED:
-                 raise RuntimeError(latest.get("error") or "task failed")
-             if status == TaskStatus.CANCELLED:
-                 raise RuntimeError("task cancelled")
-         except Exception as exc:
-             raise
-
-         raise RuntimeError("Stream ended without completion")

      def cancel(self, task_id: str) -> None:
          self._request("post", f"/tasks/{task_id}/cancel")

      def get_task(self, task_id: str) -> Dict[str, Any]:
+         """Get the current state of a task.
+
+         Args:
+             task_id: The ID of the task to get
+
+         Returns:
+             Dict[str, Any]: The current task state
+         """
          return self._request("get", f"/tasks/{task_id}")

+     def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
+         """Wait for a task to complete and return its final state.
+
+         This method polls the task status until it reaches a terminal state
+         (completed, failed, or cancelled).
+
+         Args:
+             task_id: The ID of the task to wait for
+
+         Returns:
+             Dict[str, Any]: The final task state
+
+         Raises:
+             RuntimeError: If the task fails or is cancelled
+         """
+         with self.stream_task(task_id) as stream:
+             for update in stream:
+                 if update.get("status") == TaskStatus.COMPLETED:
+                     return update
+                 elif update.get("status") == TaskStatus.FAILED:
+                     raise RuntimeError(update.get("error") or "Task failed")
+                 elif update.get("status") == TaskStatus.CANCELLED:
+                     raise RuntimeError("Task cancelled")
+         raise RuntimeError("Stream ended without completion")
+
      # --------------- File upload ---------------
      def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
          options = options or UploadFileOptions()
@@ -403,6 +494,103 @@ class Inference:
          return file_obj

      # --------------- Helpers ---------------
+     def stream_task(
+         self,
+         task_id: str,
+         *,
+         auto_reconnect: bool = True,
+         max_reconnects: int = 5,
+         reconnect_delay_ms: int = 1000,
+     ) -> TaskStream:
+         """Create a TaskStream for getting streaming updates from a task.
+
+         This provides a more Pythonic interface for handling task updates compared to callbacks.
+         The returned TaskStream can be used either as a context manager or as an iterator.
+
+         Args:
+             task_id: The ID of the task to stream
+             auto_reconnect: Whether to automatically reconnect on connection loss
+             max_reconnects: Maximum number of reconnection attempts
+             reconnect_delay_ms: Delay between reconnection attempts in milliseconds
+
+         Returns:
+             TaskStream: A stream interface for the task
+
+         Example:
+             ```python
+             # Run a task
+             task = client.run(params)
+
+             # Stream updates using context manager
+             with client.stream_task(task["id"]) as stream:
+                 for update in stream:
+                     print(f"Status: {update.get('status')}")
+                     if update.get("status") == TaskStatus.COMPLETED:
+                         print(f"Result: {update.get('output')}")
+
+             # Or use as a simple iterator
+             for update in client.stream_task(task["id"]):
+                 print(f"Update: {update}")
+             ```
+         """
+         task = self.get_task(task_id)
+         return TaskStream(
+             task=task,
+             client=self,
+             auto_reconnect=auto_reconnect,
+             max_reconnects=max_reconnects,
+             reconnect_delay_ms=reconnect_delay_ms,
+         )
+
+     def _stream_updates(
+         self,
+         task_id: str,
+         task: Dict[str, Any],
+     ) -> Generator[Union[Dict[str, Any], Exception], None, None]:
+         """Internal method to stream task updates."""
+         url = f"/tasks/{task_id}/stream"
+         resp = self._request(
+             "get",
+             url,
+             headers={
+                 "Accept": "text/event-stream",
+                 "Cache-Control": "no-cache",
+                 "Accept-Encoding": "identity",
+                 "Connection": "keep-alive",
+             },
+             stream=True,
+             timeout=60,
+         )
+         try:
+             for evt in self._iter_sse(resp):
+                 try:
+                     # Process the event to check for completion/errors
+                     result = _process_stream_event(
+                         evt,
+                         task=task,
+                         stopper=None,  # We'll handle stopping via the iterator
+                     )
+                     if result is not None:
+                         yield result
+                         break
+                     yield _strip_task(evt)
+                 except Exception as exc:
+                     yield exc
+                     raise
+         finally:
+             try:
+                 # Force close the underlying socket if possible
+                 try:
+                     raw = getattr(resp, 'raw', None)
+                     if raw is not None:
+                         raw.close()
+                 except Exception:
+                     raise
+                 # Close the response
+                 resp.close()
+             except Exception:
+                 raise
+
      def _iter_sse(self, resp: Any, stream_manager: Optional[Any] = None) -> Generator[Dict[str, Any], None, None]:
          """Iterate JSON events from an SSE response."""
          # Mode 1: raw socket readline (can reduce buffering in some environments)
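Taken together, these hunks fold the old callback-driven `run_sync` into a single `run` entry point backed by `TaskStream`. A sketch of how a 0.4.0 call site might migrate (the task payload is a placeholder; the import follows the README later in this diff):

```python
from inferencesh import Inference, TaskStatus

client = Inference(api_key="your-api-key")
params = {"app": "your-app", "input": {"key": "value"}}

# 0.4.0: client.run() returned immediately and client.run_sync() blocked.
# 0.4.2: client.run() blocks by default...
result = client.run(params)

# ...and stream=True returns a TaskStream whose .result fills in
# once a COMPLETED event is observed.
stream = client.run(params, stream=True)
for update in stream:
    if update.get("status") == TaskStatus.COMPLETED:
        break
print(stream.result)
```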
@@ -565,88 +753,114 @@ class AsyncInference:
          return payload.get("data")

      # --------------- Public API ---------------
-     async def run(self, params: Dict[str, Any]) -> Dict[str, Any]:
-         processed_input = await self._process_input_data(params.get("input"))
-         task = await self._request("post", "/run", data={**params, "input": processed_input})
-         return task
-
-     async def run_sync(
+     async def run(
          self,
          params: Dict[str, Any],
          *,
+         wait: bool = True,
+         stream: bool = False,
          auto_reconnect: bool = True,
          max_reconnects: int = 5,
          reconnect_delay_ms: int = 1000,
-     ) -> Dict[str, Any]:
+     ) -> Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
+         """Run a task with optional streaming updates.
+
+         By default, this method waits for the task to complete and returns the final result.
+         You can set wait=False to get just the task info, or stream=True to get an iterator
+         of status updates.
+
+         Args:
+             params: Task parameters to pass to the API
+             wait: Whether to wait for task completion (default: True)
+             stream: Whether to return an iterator of updates (default: False)
+             auto_reconnect: Whether to automatically reconnect on connection loss
+             max_reconnects: Maximum number of reconnection attempts
+             reconnect_delay_ms: Delay between reconnection attempts in milliseconds
+
+         Returns:
+             Union[Dict[str, Any], TaskStream, Iterator[Dict[str, Any]]]:
+                 - If wait=True and stream=False: The completed task data
+                 - If wait=False: The created task info
+                 - If stream=True: An iterator of task updates
+
+         Example:
+             ```python
+             # Simple usage - wait for result (default)
+             result = await client.run(params)
+             print(f"Output: {result['output']}")
+
+             # Get task info without waiting
+             task = await client.run(params, wait=False)
+             task_id = task["id"]
+
+             # Stream updates
+             async for update in await client.run(params, stream=True):
+                 print(f"Status: {update.get('status')}")
+                 if update.get('status') == TaskStatus.COMPLETED:
+                     print(f"Result: {update.get('output')}")
+             ```
+         """
+         # Create the task
          processed_input = await self._process_input_data(params.get("input"))
          task = await self._request("post", "/run", data={**params, "input": processed_input})
-         task_id = task["id"]
-
-         final_task: Optional[Dict[str, Any]] = None
-         reconnect_attempts = 0
-         had_success = False
-
-         while True:
-             try:
-                 resp = await self._request(
-                     "get",
-                     f"/tasks/{task_id}/stream",
-                     headers={
-                         "Accept": "text/event-stream",
-                         "Cache-Control": "no-cache",
-                         "Accept-Encoding": "identity",
-                         "Connection": "keep-alive",
-                     },
-                     timeout=60,
-                     expect_stream=True,
-                 )
-                 had_success = True
-                 async for data in self._aiter_sse(resp):
-                     result = _process_stream_event(
-                         data,
-                         task=task,
-                         stopper=None,
-                     )
-                     if result is not None:
-                         final_task = result
-                         break
-                 if final_task is not None:
-                     break
-             except Exception as exc:  # noqa: BLE001
-                 if not auto_reconnect:
-                     raise
-                 if not had_success:
-                     reconnect_attempts += 1
-                     if reconnect_attempts > max_reconnects:
-                         raise
-                     await _async_sleep(reconnect_delay_ms / 1000.0)
-             else:
-                 if not auto_reconnect:
-                     break
-                 await _async_sleep(reconnect_delay_ms / 1000.0)
-
-         if final_task is None:
-             # Fallback: fetch latest task state in case stream ended without a terminal event
-             try:
-                 latest = await self.get_task(task_id)
-                 status = latest.get("status")
-                 if status == TaskStatus.COMPLETED:
-                     return latest
-                 if status == TaskStatus.FAILED:
-                     raise RuntimeError(latest.get("error") or "task failed")
-                 if status == TaskStatus.CANCELLED:
-                     raise RuntimeError("task cancelled")
-             except Exception:
-                 raise
-             raise RuntimeError("Stream ended without completion")
-         return final_task
+
+         # Return immediately if not waiting
+         if not wait and not stream:
+             return task
+
+         # Return stream if requested
+         if stream:
+             task_stream = TaskStream(
+                 task=task,
+                 client=self,
+                 auto_reconnect=auto_reconnect,
+                 max_reconnects=max_reconnects,
+                 reconnect_delay_ms=reconnect_delay_ms,
+             )
+             return task_stream
+
+         # Otherwise wait for completion
+         return await self.wait_for_completion(task["id"])

      async def cancel(self, task_id: str) -> None:
          await self._request("post", f"/tasks/{task_id}/cancel")

      async def get_task(self, task_id: str) -> Dict[str, Any]:
+         """Get the current state of a task.
+
+         Args:
+             task_id: The ID of the task to get
+
+         Returns:
+             Dict[str, Any]: The current task state
+         """
          return await self._request("get", f"/tasks/{task_id}")

+     async def wait_for_completion(self, task_id: str) -> Dict[str, Any]:
+         """Wait for a task to complete and return its final state.
+
+         This method polls the task status until it reaches a terminal state
+         (completed, failed, or cancelled).
+
+         Args:
+             task_id: The ID of the task to wait for
+
+         Returns:
+             Dict[str, Any]: The final task state
+
+         Raises:
+             RuntimeError: If the task fails or is cancelled
+         """
+         with self.stream_task(task_id) as stream:
+             async for update in stream:
+                 if update.get("status") == TaskStatus.COMPLETED:
+                     return update
+                 elif update.get("status") == TaskStatus.FAILED:
+                     raise RuntimeError(update.get("error") or "Task failed")
+                 elif update.get("status") == TaskStatus.CANCELLED:
+                     raise RuntimeError("Task cancelled")
+         raise RuntimeError("Stream ended without completion")
+
      # --------------- File upload ---------------
      async def upload_file(self, data: Union[str, bytes], options: Optional[UploadFileOptions] = None) -> Dict[str, Any]:
          options = options or UploadFileOptions()
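The async client mirrors the sync changes one-for-one. A minimal driver, assuming the `async` extra (aiohttp) from the metadata below is installed; the import follows the README later in this diff:

```python
import asyncio

from inferencesh import AsyncInference

async def main() -> None:
    client = AsyncInference(api_key="your-api-key")
    # wait=True (the default) resolves to the terminal task dict,
    # raising RuntimeError on FAILED or CANCELLED.
    result = await client.run({"app": "your-app", "input": {"key": "value"}})
    print(result.get("output"))

asyncio.run(main())
```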
@@ -797,6 +1011,18 @@ def _looks_like_base64(value: str) -> bool:
      return False


+ def _strip_task(task: Dict[str, Any]) -> Dict[str, Any]:
+     """Strip task to essential fields."""
+     return {
+         "id": task.get("id"),
+         "created_at": task.get("created_at"),
+         "updated_at": task.get("updated_at"),
+         "input": task.get("input"),
+         "output": task.get("output"),
+         "logs": task.get("logs"),
+         "status": task.get("status"),
+     }
+

  def _process_stream_event(
      data: Dict[str, Any], *, task: Dict[str, Any], stopper: Optional[Callable[[], None]] = None
  ) -> Optional[Dict[str, Any]]:
@@ -804,16 +1030,9 @@ def _process_stream_event(
      If stopper is provided, it will be called on terminal events to end streaming.
      """
      status = data.get("status")
-     output = data.get("output")
-     logs = data.get("logs")

      if status == TaskStatus.COMPLETED:
-         result = {
-             **task,
-             "status": data.get("status"),
-             "output": data.get("output"),
-             "logs": data.get("logs") or [],
-         }
+         result = _strip_task(data)
          if stopper:
              stopper()
          return result
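For reference, `_strip_task` whitelists exactly seven keys, so any extra server-side fields on an event are dropped before they reach the caller. A quick illustration (the event payload and the `queue_position` field are invented, and the private import path is an assumption):

```python
from inferencesh.client import _strip_task  # private helper, assumed importable

event = {
    "id": "task_123",
    "created_at": "2024-01-01T00:00:00Z",
    "updated_at": "2024-01-01T00:00:05Z",
    "input": {"key": "value"},
    "output": {"text": "hello"},
    "logs": [],
    "status": 0,  # placeholder; real events carry a TaskStatus value
    "queue_position": 7,  # hypothetical extra field
}

stripped = _strip_task(event)
assert "queue_position" not in stripped
assert stripped["id"] == "task_123"
```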
inferencesh/models/file.py CHANGED
@@ -4,7 +4,6 @@ import mimetypes
  import os
  import urllib.request
  import urllib.parse
- import tempfile
  import hashlib
  from pathlib import Path
  from tqdm import tqdm
@@ -119,12 +118,10 @@ class File(BaseModel):
              return

          print(f"Downloading URL: {original_url} to {cache_path}")
-         tmp_file = None
          try:
-             # Download to temporary file first to avoid partial downloads in cache
-             suffix = os.path.splitext(urllib.parse.urlparse(original_url).path)[1]
-             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
-             self._tmp_path = tmp_file.name
+             # Download to a temporary filename in the final directory
+             tmp_path = str(cache_path) + '.tmp'
+             self._tmp_path = tmp_path

              # Set up request with user agent
              headers = {
@@ -176,8 +173,8 @@ class File(BaseModel):
                  # If we read the whole body at once, exit loop
                  break

-             # Move the temporary file to the cache location
-             os.replace(self._tmp_path, cache_path)
+             # Rename the temporary file to the final name
+             os.rename(self._tmp_path, cache_path)
              self._tmp_path = None  # Prevent deletion in __del__
              self.path = str(cache_path)
          except (urllib.error.URLError, urllib.error.HTTPError) as e:
@@ -186,7 +183,7 @@ class File(BaseModel):
              raise RuntimeError(f"Failed to write downloaded file to {self._tmp_path}: {str(e)}")
          except Exception as e:
              # Clean up temp file if something went wrong
-             if tmp_file is not None and hasattr(self, '_tmp_path'):
+             if hasattr(self, '_tmp_path') and self._tmp_path:
                  try:
                      os.unlink(self._tmp_path)
                  except (OSError, IOError):
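The hunks above replace a `tempfile.NamedTemporaryFile` (which may be created on a different filesystem than the cache directory) with a `.tmp` sibling of the destination, so the finishing rename never has to cross devices. A standalone sketch of the same pattern (URL and path are placeholders):

```python
import os
import urllib.request

def download_atomic(url: str, dest: str) -> None:
    """Fetch url into dest without ever exposing a partially written file."""
    tmp = dest + ".tmp"  # same directory, hence same filesystem, as dest
    try:
        with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out:
            while True:
                chunk = resp.read(64 * 1024)
                if not chunk:
                    break
                out.write(chunk)
        # Atomic on POSIX when source and destination share a filesystem;
        # note os.rename raises on Windows if dest already exists.
        os.rename(tmp, dest)
    except BaseException:
        try:
            os.unlink(tmp)  # drop the partial download
        except OSError:
            pass
        raise
```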
inferencesh-0.4.2.dist-info/METADATA ADDED
@@ -0,0 +1,196 @@
+ Metadata-Version: 2.4
+ Name: inferencesh
+ Version: 0.4.2
+ Summary: inference.sh Python SDK
+ Author-email: "Inference Shell Inc." <hello@inference.sh>
+ Project-URL: Homepage, https://github.com/inference-sh/sdk
+ Project-URL: Bug Tracker, https://github.com/inference-sh/sdk/issues
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: pydantic>=2.0.0
+ Requires-Dist: tqdm>=4.67.0
+ Requires-Dist: requests>=2.31.0
+ Provides-Extra: test
+ Requires-Dist: pytest>=7.0.0; extra == "test"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "test"
+ Provides-Extra: async
+ Requires-Dist: aiohttp>=3.9.0; python_version >= "3.8" and extra == "async"
+ Requires-Dist: aiofiles>=23.2.1; python_version >= "3.8" and extra == "async"
+ Dynamic: license-file
+
+ # inference.sh sdk
+
+ helper package for inference.sh python applications.
+
+ ## installation
+
+ ```bash
+ pip install infsh
+ ```
+
+ ## client usage
+
+ ```python
+ from inferencesh import Inference, TaskStatus
+
+ # Create client
+ client = Inference(api_key="your-api-key")
+
+ # Simple synchronous usage
+ try:
+     task = client.run({
+         "app": "your-app",
+         "input": {"key": "value"},
+         "infra": "cloud",
+         "variant": "default"
+     })
+
+     print(f"Task ID: {task.get('id')}")
+
+     if task.get("status") == TaskStatus.COMPLETED:
+         print("✓ Task completed successfully!")
+         print(f"Output: {task.get('output')}")
+     else:
+         status = task.get("status")
+         status_name = TaskStatus(status).name if status is not None else "UNKNOWN"
+         print(f"✗ Task did not complete. Final status: {status_name}")
+
+ except Exception as exc:
+     print(f"Error: {type(exc).__name__}: {exc}")
+     raise  # Re-raise to see full traceback
+
+ # Streaming updates (recommended)
+ try:
+     for update in client.run(
+         {
+             "app": "your-app",
+             "input": {"key": "value"},
+             "infra": "cloud",
+             "variant": "default"
+         },
+         stream=True  # Enable streaming updates
+     ):
+         status = update.get("status")
+         status_name = TaskStatus(status).name if status is not None else "UNKNOWN"
+         print(f"Status: {status_name}")
+
+         if status == TaskStatus.COMPLETED:
+             print("✓ Task completed!")
+             print(f"Output: {update.get('output')}")
+             break
+         elif status == TaskStatus.FAILED:
+             print(f"✗ Task failed: {update.get('error')}")
+             break
+         elif status == TaskStatus.CANCELLED:
+             print("✗ Task was cancelled")
+             break
+
+ except Exception as exc:
+     print(f"Error: {type(exc).__name__}: {exc}")
+     raise  # Re-raise to see full traceback
+
+ # Async support
+ async def run_async():
+     from inferencesh import AsyncInference
+
+     client = AsyncInference(api_key="your-api-key")
+
+     # Simple usage
+     result = await client.run({
+         "app": "your-app",
+         "input": {"key": "value"},
+         "infra": "cloud",
+         "variant": "default"
+     })
+
+     # Stream updates
+     async for update in await client.run(
+         {
+             "app": "your-app",
+             "input": {"key": "value"},
+             "infra": "cloud",
+             "variant": "default"
+         },
+         stream=True
+     ):
+         status = update.get("status")
+         status_name = TaskStatus(status).name if status is not None else "UNKNOWN"
+         print(f"Status: {status_name}")
+ ```
+
+ ## file handling
+
+ the `File` class provides a standardized way to handle files in the inference.sh ecosystem:
+
+ ```python
+ from infsh import File
+
+ # Basic file creation
+ file = File(path="/path/to/file.png")
+
+ # File with explicit metadata
+ file = File(
+     path="/path/to/file.png",
+     content_type="image/png",
+     filename="custom_name.png",
+     size=1024  # in bytes
+ )
+
+ # Create from path (automatically populates metadata)
+ file = File.from_path("/path/to/file.png")
+
+ # Check if file exists
+ exists = file.exists()
+
+ # Access file metadata
+ print(file.content_type)  # automatically detected if not specified
+ print(file.size)  # file size in bytes
+ print(file.filename)  # basename of the file
+
+ # Refresh metadata (useful if file has changed)
+ file.refresh_metadata()
+ ```
+
+ the `File` class automatically handles:
+ - mime type detection
+ - file size calculation
+ - filename extraction from path
+ - file existence checking
+
+ ## creating an app
+
+ to create an inference app, inherit from `BaseApp` and define your input/output types:
+
+ ```python
+ from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
+
+ class AppInput(BaseAppInput):
+     image: str  # URL or file path to image
+     mask: str  # URL or file path to mask
+
+ class AppOutput(BaseAppOutput):
+     image: File
+
+ class MyApp(BaseApp):
+     async def setup(self):
+         # Initialize your model here
+         pass
+
+     async def run(self, app_input: AppInput) -> AppOutput:
+         # Process input and return output
+         result_path = "/tmp/result.png"
+         return AppOutput(image=File(path=result_path))
+
+     async def unload(self):
+         # Clean up resources
+         pass
+ ```
+
+ app lifecycle has three main methods:
+ - `setup()`: called when the app starts, use it to initialize models
+ - `run()`: called for each inference request
+ - `unload()`: called when shutting down, use it to free resources
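The README above drives the stream loop by hand; the `wait_for_completion` helper added in client.py earlier in this diff collapses that boilerplate. A hedged sketch reusing the README's placeholder payload:

```python
from inferencesh import Inference

client = Inference(api_key="your-api-key")

# Create the task without blocking, keep the id around...
task = client.run(
    {"app": "your-app", "input": {"key": "value"}, "infra": "cloud", "variant": "default"},
    wait=False,
)

# ...then block until a terminal state; wait_for_completion raises
# RuntimeError if the task fails or is cancelled.
final = client.wait_for_completion(task["id"])
print(final.get("output"))
```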
inferencesh-0.4.2.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
  inferencesh/__init__.py,sha256=dY3l3yCkWoMtGX0gNXgxFnrprFRl6PPWjH8V7Qedx5g,772
- inferencesh/client.py,sha256=cm7E-8LxP8jyb7JnANmcBtF1ya2i7sMBW2Pq-Oh-mcM,31318
+ inferencesh/client.py,sha256=6wTCLqLq-QapvjCjMg8ZE3BQyg8iTL8hv8UU7t-oxmE,39360
  inferencesh/models/__init__.py,sha256=FDwcdtT6c4hbRitymjmN-hZMlQa8RbKSftkZZyjtUXA,536
  inferencesh/models/base.py,sha256=4gZQRi8J7y9U6PrGD9pRIehd1MJVJAqGakPQDs2AKFM,3251
- inferencesh/models/file.py,sha256=uh1czgk0KFl_9RHTODX0PkdnI42MSU8QMJR_I4lVKI4,10556
+ inferencesh/models/file.py,sha256=V8p5JwCzrXQMSHPsXOf8eeGxTHLFQpQqpG7AL3v0wKo,10374
  inferencesh/models/llm.py,sha256=GLcEkDizBbgcfc-zC719wDe44th3EGf3FpKERjIAPE8,27755
  inferencesh/utils/__init__.py,sha256=-xiD6uo2XzcrPAWFb_fUbaimmnW4KFKc-8IvBzaxNd4,148
  inferencesh/utils/download.py,sha256=DRGBudiPVa5bDS35KfR-DYeGRk7gO03WOelnisecwMo,1815
  inferencesh/utils/storage.py,sha256=E4J8emd4eFKdmdDgAqzz3TpaaDd3n0l8gYlMHuY8yIU,519
- inferencesh-0.4.0.dist-info/licenses/LICENSE,sha256=OsgqEWIh2el_QMj0y8O1A5Q5Dl-dxqqYbFE6fszuR4s,1086
- inferencesh-0.4.0.dist-info/METADATA,sha256=pHnblJABrxy5Iy81hpP7nV-J72Tp2JIUJ6D2UzVbSqo,2964
- inferencesh-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- inferencesh-0.4.0.dist-info/entry_points.txt,sha256=6IC-fyozAqW3ljsMLGCXxJ0_ui2Jb-2fLHtoH1RTnEE,45
- inferencesh-0.4.0.dist-info/top_level.txt,sha256=TSMHg3T1ThMl1HGAWmzBClwOYH1ump5neof9BfHIwaA,12
- inferencesh-0.4.0.dist-info/RECORD,,
+ inferencesh-0.4.2.dist-info/licenses/LICENSE,sha256=OsgqEWIh2el_QMj0y8O1A5Q5Dl-dxqqYbFE6fszuR4s,1086
+ inferencesh-0.4.2.dist-info/METADATA,sha256=7pTRbdMbhqbSc9xGqrwasgYkuHuS1NlcbrN0izdRNvk,5405
+ inferencesh-0.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ inferencesh-0.4.2.dist-info/entry_points.txt,sha256=6IC-fyozAqW3ljsMLGCXxJ0_ui2Jb-2fLHtoH1RTnEE,45
+ inferencesh-0.4.2.dist-info/top_level.txt,sha256=TSMHg3T1ThMl1HGAWmzBClwOYH1ump5neof9BfHIwaA,12
+ inferencesh-0.4.2.dist-info/RECORD,,
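Each RECORD row is `path,sha256=<urlsafe-base64 digest>,<size-in-bytes>` (per PEP 376/427), which is how the METADATA size jump from 2964 to 5405 bytes shows up above. A small sketch of how a row could be checked against an installed file (the path is a placeholder):

```python
import base64
import hashlib

def record_hash(path: str) -> str:
    """Compute the RECORD-style sha256: urlsafe base64, '=' padding stripped."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# e.g. compare against the RECORD row for client.py:
# "inferencesh/client.py,sha256=6wTCLqLq-...,39360"
print(record_hash("inferencesh/client.py"))
```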
inferencesh-0.4.0.dist-info/METADATA DELETED
@@ -1,109 +0,0 @@
- Metadata-Version: 2.4
- Name: inferencesh
- Version: 0.4.0
- Summary: inference.sh Python SDK
- Author: Inference Shell Inc.
- Author-email: "Inference Shell Inc." <hello@inference.sh>
- Project-URL: Homepage, https://github.com/inference-sh/sdk
- Project-URL: Bug Tracker, https://github.com/inference-sh/sdk/issues
- Classifier: Programming Language :: Python :: 3
- Classifier: License :: OSI Approved :: MIT License
- Classifier: Operating System :: OS Independent
- Requires-Python: >=3.7
- Description-Content-Type: text/markdown
- License-File: LICENSE
- Requires-Dist: pydantic>=2.0.0
- Requires-Dist: tqdm>=4.67.0
- Requires-Dist: requests>=2.31.0
- Provides-Extra: test
- Requires-Dist: pytest>=7.0.0; extra == "test"
- Requires-Dist: pytest-cov>=4.0.0; extra == "test"
- Provides-Extra: async
- Requires-Dist: aiohttp>=3.9.0; python_version >= "3.8" and extra == "async"
- Requires-Dist: aiofiles>=23.2.1; python_version >= "3.8" and extra == "async"
- Dynamic: author
- Dynamic: license-file
- Dynamic: requires-python
-
- # inference.sh CLI
-
- Helper package for inference.sh Python applications.
-
- ## Installation
-
- ```bash
- pip install infsh
- ```
-
- ## File Handling
-
- The `File` class provides a standardized way to handle files in the inference.sh ecosystem:
-
- ```python
- from infsh import File
-
- # Basic file creation
- file = File(path="/path/to/file.png")
-
- # File with explicit metadata
- file = File(
-     path="/path/to/file.png",
-     content_type="image/png",
-     filename="custom_name.png",
-     size=1024  # in bytes
- )
-
- # Create from path (automatically populates metadata)
- file = File.from_path("/path/to/file.png")
-
- # Check if file exists
- exists = file.exists()
-
- # Access file metadata
- print(file.content_type)  # automatically detected if not specified
- print(file.size)  # file size in bytes
- print(file.filename)  # basename of the file
-
- # Refresh metadata (useful if file has changed)
- file.refresh_metadata()
- ```
-
- The `File` class automatically handles:
- - MIME type detection
- - File size calculation
- - Filename extraction from path
- - File existence checking
-
- ## Creating an App
-
- To create an inference app, inherit from `BaseApp` and define your input/output types:
-
- ```python
- from infsh import BaseApp, BaseAppInput, BaseAppOutput, File
-
- class AppInput(BaseAppInput):
-     image: str  # URL or file path to image
-     mask: str  # URL or file path to mask
-
- class AppOutput(BaseAppOutput):
-     image: File
-
- class MyApp(BaseApp):
-     async def setup(self):
-         # Initialize your model here
-         pass
-
-     async def run(self, app_input: AppInput) -> AppOutput:
-         # Process input and return output
-         result_path = "/tmp/result.png"
-         return AppOutput(image=File(path=result_path))
-
-     async def unload(self):
-         # Clean up resources
-         pass
- ```
-
- The app lifecycle has three main methods:
- - `setup()`: Called when the app starts, use it to initialize models
- - `run()`: Called for each inference request
- - `unload()`: Called when shutting down, use it to free resources