shared-tensor 0.1.2__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. shared_tensor-0.2.1/PKG-INFO +177 -0
  2. shared_tensor-0.2.1/README.md +126 -0
  3. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/pyproject.toml +21 -51
  4. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/shared_tensor/__init__.py +12 -19
  5. shared_tensor-0.2.1/shared_tensor/async_client.py +166 -0
  6. shared_tensor-0.2.1/shared_tensor/async_provider.py +138 -0
  7. shared_tensor-0.2.1/shared_tensor/async_task.py +254 -0
  8. shared_tensor-0.2.1/shared_tensor/client.py +154 -0
  9. shared_tensor-0.2.1/shared_tensor/errors.py +56 -0
  10. shared_tensor-0.2.1/shared_tensor/jsonrpc.py +116 -0
  11. shared_tensor-0.2.1/shared_tensor/provider.py +139 -0
  12. shared_tensor-0.2.1/shared_tensor/server.py +381 -0
  13. shared_tensor-0.2.1/shared_tensor/utils.py +272 -0
  14. shared_tensor-0.1.2/PKG-INFO +0 -432
  15. shared_tensor-0.1.2/README.md +0 -379
  16. shared_tensor-0.1.2/shared_tensor/async_client.py +0 -302
  17. shared_tensor-0.1.2/shared_tensor/async_provider.py +0 -177
  18. shared_tensor-0.1.2/shared_tensor/async_task.py +0 -361
  19. shared_tensor-0.1.2/shared_tensor/client.py +0 -265
  20. shared_tensor-0.1.2/shared_tensor/errors.py +0 -16
  21. shared_tensor-0.1.2/shared_tensor/jsonrpc.py +0 -163
  22. shared_tensor-0.1.2/shared_tensor/provider.py +0 -160
  23. shared_tensor-0.1.2/shared_tensor/server.py +0 -458
  24. shared_tensor-0.1.2/shared_tensor/utils.py +0 -122
  25. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/LICENSE +0 -0
  26. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/MANIFEST.in +0 -0
  27. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/setup.cfg +0 -0
  28. {shared_tensor-0.1.2 → shared_tensor-0.2.1}/shared_tensor.egg-info/SOURCES.txt +0 -0
@@ -0,0 +1,177 @@
1
+ Metadata-Version: 2.4
2
+ Name: shared-tensor
3
+ Version: 0.2.1
4
+ Summary: Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC
5
+ Author-email: Athena Team <contact@world-sim-dev.org>
6
+ Maintainer-email: Athena Team <contact@world-sim-dev.org>
7
+ License-Expression: Apache-2.0
8
+ Project-URL: Homepage, https://github.com/world-sim-dev/shared-tensor
9
+ Project-URL: Repository, https://github.com/world-sim-dev/shared-tensor
10
+ Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/wiki
11
+ Project-URL: Bug Reports, https://github.com/world-sim-dev/shared-tensor/issues
12
+ Project-URL: Changelog, https://github.com/world-sim-dev/shared-tensor/releases
13
+ Keywords: gpu,memory,sharing,ipc,inter-process-communication,pytorch,cuda,model-serving,inference,torch,torch-ipc
14
+ Classifier: Development Status :: 3 - Alpha
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Intended Audience :: Science/Research
17
+ Classifier: Operating System :: POSIX :: Linux
18
+ Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
+ Classifier: Topic :: System :: Distributed Computing
25
+ Requires-Python: >=3.10
26
+ Description-Content-Type: text/markdown
27
+ License-File: LICENSE
28
+ Requires-Dist: numpy<2
29
+ Requires-Dist: requests>=2.25.0
30
+ Requires-Dist: torch>=2.2.0
31
+ Provides-Extra: dev
32
+ Requires-Dist: pytest>=6.0; extra == "dev"
33
+ Requires-Dist: pytest-cov>=2.0; extra == "dev"
34
+ Requires-Dist: types-requests>=2.32.0; extra == "dev"
35
+ Requires-Dist: black>=22.0; extra == "dev"
36
+ Requires-Dist: isort>=5.0; extra == "dev"
37
+ Requires-Dist: mypy>=0.950; extra == "dev"
38
+ Requires-Dist: pre-commit>=2.0.0; extra == "dev"
39
+ Requires-Dist: build>=0.8.0; extra == "dev"
40
+ Requires-Dist: twine>=4.0.0; extra == "dev"
41
+ Requires-Dist: ruff>=0.6.0; extra == "dev"
42
+ Provides-Extra: test
43
+ Requires-Dist: pytest>=6.0; extra == "test"
44
+ Requires-Dist: pytest-cov>=2.0; extra == "test"
45
+ Requires-Dist: pytest-asyncio>=0.20.0; extra == "test"
46
+ Provides-Extra: docs
47
+ Requires-Dist: sphinx>=4.0.0; extra == "docs"
48
+ Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
49
+ Requires-Dist: myst-parser>=0.18.0; extra == "docs"
50
+ Dynamic: license-file
51
+
52
+ # Shared Tensor
53
+
54
+ `shared_tensor` is a localhost-only RPC layer for one thing: passing CUDA `torch.Tensor` and CUDA `torch.nn.Module` objects between processes on the same machine and the same GPU with native PyTorch IPC semantics.
55
+
56
+ ## What It Supports
57
+
58
+ - same-host, trusted-process deployment
59
+ - same-GPU CUDA object handoff
60
+ - native `torch` tensors and modules
61
+ - explicit endpoint registration
62
+ - sync calls and async task polling
63
+
64
+ ## What It Does Not Support
65
+
66
+ - CPU tensor transport
67
+ - generic Python object RPC
68
+ - cross-machine transport
69
+ - macOS `mps`
70
+ - silent CPU fallback or implicit device migration
71
+
72
+ ## Payload Contract
73
+
74
+ Allowed payloads:
75
+
76
+ - CUDA `torch.Tensor`
77
+ - CUDA `torch.nn.Module`
78
+ - `tuple`, `list`, and `dict[str, ...]` containers built from those values for `args` and `kwargs`
79
+ - empty `args` / `kwargs` through the control path for no-argument calls only
80
+
81
+ Rejected payloads:
82
+
83
+ - CPU tensors and CPU modules
84
+ - plain Python values such as `int` and `str`, including `dict` and `list` containers that hold such values
85
+ - `mps` tensors and modules
86
+
87
+ ## Install
88
+
89
+ Use Python `3.10+` and a CUDA-enabled PyTorch build.
90
+
91
+ ```bash
92
+ pip install shared-tensor
93
+ ```
94
+
95
+ For development:
96
+
97
+ ```bash
98
+ conda create -y -n shared-tensor-dev python=3.11
99
+ conda activate shared-tensor-dev
100
+ pip install -e ".[dev,test]"
101
+ ```
102
+
103
+ ## Typical Example
104
+
105
+ Provider process:
106
+
107
+ ```python
108
+ import torch
109
+
110
+ from shared_tensor import SharedTensorProvider
111
+
112
+ provider = SharedTensorProvider(execution_mode="server")
113
+
114
+
115
+ @provider.share(name="load_model")
116
+ def load_model() -> torch.nn.Module:
117
+ return torch.nn.Linear(4, 2, device="cuda")
118
+
119
+
120
+ @provider.share(name="identity")
121
+ def identity(tensor: torch.Tensor) -> torch.Tensor:
122
+ return tensor
123
+ ```
124
+
125
+ Run the server:
126
+
127
+ ```bash
128
+ shared-tensor-server --provider my_service:provider --host 127.0.0.1 --port 2537
129
+ ```
130
+
131
+ Consumer process:
132
+
133
+ ```python
134
+ import torch
135
+
136
+ from shared_tensor import SharedTensorClient
137
+
138
+ with SharedTensorClient(port=2537) as client:
139
+ model = client.call("load_model")
140
+ x = torch.ones(1, 4, device="cuda")
141
+ y = model(x)
142
+
143
+ shared = client.call("identity", x)
144
+ ```
145
+
146
+ ## Test Matrix
147
+
148
+ Default local run:
149
+
150
+ ```bash
151
+ python -m pytest -m "not gpu"
152
+ ```
153
+
154
+ CUDA run:
155
+
156
+ ```bash
157
+ python -m pytest -m gpu
158
+ ```
159
+
160
+ `skipped` means the test was intentionally not run because its precondition was missing. In this repo that usually means a `gpu` test was selected on a machine where `torch.cuda.is_available()` was false. It is not a failure.
161
+
162
+ Current validation target:
163
+
164
+ - local non-GPU suite passes
165
+ - H100 CUDA suite passes
166
+
167
+ ## Operational Notes
168
+
169
+ - This library assumes a trusted same-host environment.
170
+ - The server process must be a separate process from the client when using CUDA IPC.
171
+ - If you need cross-machine transport or CPU object RPC, use a different tool.
172
+
173
+ ## Repo Notes
174
+
175
+ - `CLAUDE.md` captures repo maintenance rules.
176
+ - `examples/basic_service.py` shows the minimal sync flow.
177
+ - `examples/model_service.py` shows model handoff.
@@ -0,0 +1,126 @@
1
+ # Shared Tensor
2
+
3
+ `shared_tensor` is a localhost-only RPC layer for one thing: passing CUDA `torch.Tensor` and CUDA `torch.nn.Module` objects between processes on the same machine and the same GPU with native PyTorch IPC semantics.
4
+
5
+ ## What It Supports
6
+
7
+ - same-host, trusted-process deployment
8
+ - same-GPU CUDA object handoff
9
+ - native `torch` tensors and modules
10
+ - explicit endpoint registration
11
+ - sync calls and async task polling
12
+
13
+ ## What It Does Not Support
14
+
15
+ - CPU tensor transport
16
+ - generic Python object RPC
17
+ - cross-machine transport
18
+ - macOS `mps`
19
+ - silent CPU fallback or implicit device migration
20
+
21
+ ## Payload Contract
22
+
23
+ Allowed payloads:
24
+
25
+ - CUDA `torch.Tensor`
26
+ - CUDA `torch.nn.Module`
27
+ - `tuple`, `list`, and `dict[str, ...]` containers built from those values for `args` and `kwargs`
28
+ - empty `args` / `kwargs` through the control path for no-argument calls only
29
+
30
+ Rejected payloads:
31
+
32
+ - CPU tensors and CPU modules
33
+ - plain Python values such as `int` and `str`, including `dict` and `list` containers that hold such values
34
+ - `mps` tensors and modules
35
+
36
+ ## Install
37
+
38
+ Use Python `3.10+` and a CUDA-enabled PyTorch build.
39
+
40
+ ```bash
41
+ pip install shared-tensor
42
+ ```
43
+
44
+ For development:
45
+
46
+ ```bash
47
+ conda create -y -n shared-tensor-dev python=3.11
48
+ conda activate shared-tensor-dev
49
+ pip install -e ".[dev,test]"
50
+ ```
51
+
52
+ ## Typical Example
53
+
54
+ Provider process:
55
+
56
+ ```python
57
+ import torch
58
+
59
+ from shared_tensor import SharedTensorProvider
60
+
61
+ provider = SharedTensorProvider(execution_mode="server")
62
+
63
+
64
+ @provider.share(name="load_model")
65
+ def load_model() -> torch.nn.Module:
66
+ return torch.nn.Linear(4, 2, device="cuda")
67
+
68
+
69
+ @provider.share(name="identity")
70
+ def identity(tensor: torch.Tensor) -> torch.Tensor:
71
+ return tensor
72
+ ```
73
+
74
+ Run the server:
75
+
76
+ ```bash
77
+ shared-tensor-server --provider my_service:provider --host 127.0.0.1 --port 2537
78
+ ```
79
+
80
+ Consumer process:
81
+
82
+ ```python
83
+ import torch
84
+
85
+ from shared_tensor import SharedTensorClient
86
+
87
+ with SharedTensorClient(port=2537) as client:
88
+ model = client.call("load_model")
89
+ x = torch.ones(1, 4, device="cuda")
90
+ y = model(x)
91
+
92
+ shared = client.call("identity", x)
93
+ ```
94
+
95
+ ## Test Matrix
96
+
97
+ Default local run:
98
+
99
+ ```bash
100
+ python -m pytest -m "not gpu"
101
+ ```
102
+
103
+ CUDA run:
104
+
105
+ ```bash
106
+ python -m pytest -m gpu
107
+ ```
108
+
109
+ `skipped` means the test was intentionally not run because its precondition was missing. In this repo that usually means a `gpu` test was selected on a machine where `torch.cuda.is_available()` was false. It is not a failure.
110
+
111
+ Current validation target:
112
+
113
+ - local non-GPU suite passes
114
+ - H100 CUDA suite passes
115
+
116
+ ## Operational Notes
117
+
118
+ - This library assumes a trusted same-host environment.
119
+ - The server process must be a separate process from the client when using CUDA IPC.
120
+ - If you need cross-machine transport or CPU object RPC, use a different tool.
121
+
122
+ ## Repo Notes
123
+
124
+ - `CLAUDE.md` captures repo maintenance rules.
125
+ - `examples/basic_service.py` shows the minimal sync flow.
126
+ - `examples/model_service.py` shows model handoff.
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.1.2"
8
- description = "A library for sharing GPU memory objects across processes using IPC mechanisms"
7
+ version = "0.2.1"
8
+ description = "Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
11
11
  authors = [
@@ -16,16 +16,16 @@ maintainers = [
16
16
  ]
17
17
  keywords = [
18
18
  "gpu",
19
- "memory",
19
+ "memory",
20
20
  "sharing",
21
21
  "ipc",
22
22
  "inter-process-communication",
23
23
  "pytorch",
24
- "tensorflow",
25
24
  "cuda",
26
25
  "model-serving",
27
26
  "inference",
28
- "distributed-computing"
27
+ "torch",
28
+ "torch-ipc"
29
29
  ]
30
30
  classifiers = [
31
31
  "Development Status :: 3 - Alpha",
@@ -33,38 +33,36 @@ classifiers = [
33
33
  "Intended Audience :: Science/Research",
34
34
  "Operating System :: POSIX :: Linux",
35
35
  "Programming Language :: Python :: 3",
36
- "Programming Language :: Python :: 3.8",
37
- "Programming Language :: Python :: 3.9",
38
36
  "Programming Language :: Python :: 3.10",
39
37
  "Programming Language :: Python :: 3.11",
40
38
  "Programming Language :: Python :: 3.12",
41
39
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
42
40
  "Topic :: Software Development :: Libraries :: Python Modules",
43
- "Topic :: System :: Hardware :: Symmetric Multi-processing",
41
+ "Topic :: System :: Distributed Computing",
44
42
  ]
45
- requires-python = ">=3.8"
43
+ requires-python = ">=3.10"
46
44
  dependencies = [
47
- "torch>=1.12.0",
48
- "numpy>=1.20.0",
45
+ "numpy<2",
49
46
  "requests>=2.25.0",
47
+ "torch>=2.2.0",
50
48
  ]
51
49
 
52
50
  [project.optional-dependencies]
53
51
  dev = [
54
52
  "pytest>=6.0",
55
53
  "pytest-cov>=2.0",
54
+ "types-requests>=2.32.0",
56
55
  "black>=22.0",
57
- "flake8>=4.0",
58
56
  "isort>=5.0",
59
57
  "mypy>=0.950",
60
58
  "pre-commit>=2.0.0",
61
59
  "build>=0.8.0",
62
60
  "twine>=4.0.0",
61
+ "ruff>=0.6.0",
63
62
  ]
64
63
  test = [
65
64
  "pytest>=6.0",
66
65
  "pytest-cov>=2.0",
67
- "pytest-benchmark>=3.0",
68
66
  "pytest-asyncio>=0.20.0",
69
67
  ]
70
68
  docs = [
@@ -93,10 +91,9 @@ exclude = ["tests*", "examples*", "docs*"]
93
91
  [tool.setuptools.package-data]
94
92
  shared_tensor = ["*.so", "*.dll", "*.dylib"]
95
93
 
96
- # Black configuration
97
94
  [tool.black]
98
95
  line-length = 88
99
- target-version = ['py38', 'py39', 'py310', 'py311']
96
+ target-version = ['py310', 'py311', 'py312']
100
97
  include = '\.pyi?$'
101
98
  extend-exclude = '''
102
99
  /(
@@ -112,7 +109,6 @@ extend-exclude = '''
112
109
  )/
113
110
  '''
114
111
 
115
- # isort configuration
116
112
  [tool.isort]
117
113
  profile = "black"
118
114
  multi_line_output = 3
@@ -122,9 +118,8 @@ force_grid_wrap = 0
122
118
  use_parentheses = true
123
119
  ensure_newline_before_comments = true
124
120
 
125
- # mypy configuration
126
121
  [tool.mypy]
127
- python_version = "3.8"
122
+ python_version = "3.10"
128
123
  warn_return_any = true
129
124
  warn_unused_configs = true
130
125
  disallow_untyped_defs = true
@@ -141,11 +136,10 @@ strict_equality = true
141
136
  [[tool.mypy.overrides]]
142
137
  module = [
143
138
  "torch.*",
144
- "numpy.*",
139
+ "requests.*",
145
140
  ]
146
141
  ignore_missing_imports = true
147
142
 
148
- # pytest configuration
149
143
  [tool.pytest.ini_options]
150
144
  minversion = "6.0"
151
145
  addopts = [
@@ -168,7 +162,6 @@ markers = [
168
162
  "unit: marks tests as unit tests",
169
163
  ]
170
164
 
171
- # Coverage configuration
172
165
  [tool.coverage.run]
173
166
  source = ["shared_tensor"]
174
167
  omit = [
@@ -191,40 +184,17 @@ exclude_lines = [
191
184
  "@(abc\\.)?abstractmethod",
192
185
  ]
193
186
 
194
- # flake8 configuration (in setup.cfg format within pyproject.toml comments)
195
- # [flake8]
196
- # max-line-length = 88
197
- # extend-ignore = E203, E266, E501, W503
198
- # max-complexity = 10
199
- # select = B,C,E,F,W,T4,B9
200
-
201
- # Ruff configuration (modern alternative to flake8)
202
187
  [tool.ruff]
203
- target-version = "py38"
188
+ target-version = "py310"
204
189
  line-length = 88
205
- select = [
206
- "E", # pycodestyle errors
207
- "W", # pycodestyle warnings
208
- "F", # pyflakes
209
- "I", # isort
210
- "B", # flake8-bugbear
211
- "C4", # flake8-comprehensions
212
- "UP", # pyupgrade
213
- ]
214
- ignore = [
215
- "E501", # line too long, handled by black
216
- "B008", # do not perform function calls in argument defaults
217
- "C901", # too complex
218
- ]
219
190
 
220
- [tool.ruff.per-file-ignores]
191
+ [tool.ruff.lint]
192
+ select = ["E", "W", "F", "I", "B", "C4", "UP"]
193
+ ignore = ["E501", "B008", "C901"]
194
+
195
+ [tool.ruff.lint.per-file-ignores]
221
196
  "__init__.py" = ["F401"]
222
197
  "tests/**/*" = ["B011", "B018"]
223
198
 
224
- [tool.ruff.isort]
199
+ [tool.ruff.lint.isort]
225
200
  known-first-party = ["shared_tensor"]
226
-
227
- # Bandit security linter configuration
228
- [tool.bandit]
229
- exclude_dirs = ["tests", "examples"]
230
- skips = ["B101", "B601"]
@@ -1,27 +1,20 @@
1
- """
2
- Shared Tensor Library
1
+ """shared_tensor: local endpoint-oriented RPC for Python and PyTorch."""
3
2
 
4
- A library for sharing GPU memory objects across processes using IPC mechanisms.
5
- Enables model and inference engine separation architecture using JSON-RPC 2.0 protocol.
6
- """
7
-
8
- from shared_tensor.provider import SharedTensorProvider
3
+ from shared_tensor.async_client import AsyncSharedTensorClient
4
+ from shared_tensor.async_provider import AsyncSharedTensorProvider
5
+ from shared_tensor.async_task import TaskInfo, TaskStatus
9
6
  from shared_tensor.client import SharedTensorClient
7
+ from shared_tensor.provider import SharedTensorProvider
10
8
  from shared_tensor.server import SharedTensorServer
11
- from shared_tensor.async_provider import AsyncSharedTensorProvider
12
- from shared_tensor.async_client import AsyncSharedTensorClient
13
- from shared_tensor.async_task import TaskStatus, TaskInfo
14
9
 
15
- __version__ = "0.1.0"
16
- __author__ = "Athena Team"
17
-
18
- # Export main functionality
19
10
  __all__ = [
20
- "SharedTensorProvider",
11
+ "AsyncSharedTensorClient",
12
+ "AsyncSharedTensorProvider",
21
13
  "SharedTensorClient",
14
+ "SharedTensorProvider",
22
15
  "SharedTensorServer",
23
- "AsyncSharedTensorProvider",
24
- "AsyncSharedTensorClient",
25
- "TaskStatus",
26
16
  "TaskInfo",
27
- ]
17
+ "TaskStatus",
18
+ ]
19
+
20
+ __version__ = "0.2.1"
@@ -0,0 +1,166 @@
1
+ """Async task-oriented client facade built on top of :mod:`shared_tensor.client`."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from collections.abc import Callable
7
+ from typing import Any, cast
8
+
9
+ from shared_tensor.async_task import TaskInfo, TaskStatus
10
+ from shared_tensor.client import SharedTensorClient
11
+ from shared_tensor.errors import SharedTensorTaskError
12
+ from shared_tensor.utils import resolve_legacy_endpoint_name, serialize_call_payloads
13
+
14
+
15
class AsyncSharedTensorClient:
    """Task-oriented facade over :class:`SharedTensorClient`.

    Despite the name, every method is a blocking call: work is submitted to
    the server as a task, and completion is observed by polling the task
    status through the wrapped synchronous client.
    """

    def __init__(
        self,
        port: int = 2537,
        verbose_debug: bool = False,
        poll_interval: float = 1.0,
        *,
        host: str = "127.0.0.1",
        timeout: float = 30.0,
    ) -> None:
        """Create the underlying :class:`SharedTensorClient`.

        Args:
            port: Forwarded to ``SharedTensorClient``.
            verbose_debug: Forwarded to ``SharedTensorClient``.
            poll_interval: Seconds to sleep between status polls in
                :meth:`wait`.
            host: Forwarded to ``SharedTensorClient``.
            timeout: Forwarded to ``SharedTensorClient``; this is not the
                task-wait timeout, which is passed to :meth:`wait`.
        """
        self.poll_interval = poll_interval
        self._client = SharedTensorClient(
            port=port,
            host=host,
            timeout=timeout,
            verbose_debug=verbose_debug,
        )

    def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
        """Submit a call to ``endpoint`` as a task and return its task id.

        The positional and keyword arguments are serialized with
        ``serialize_call_payloads`` and sent hex-encoded in a ``submit``
        request.
        """
        encoding, args_payload, kwargs_payload = serialize_call_payloads(
            tuple(args), dict(kwargs)
        )
        result = self._client._request(
            "submit",
            {
                "endpoint": endpoint,
                "args_hex": args_payload.hex(),
                "kwargs_hex": kwargs_payload.hex(),
                "encoding": encoding,
            },
        )
        return cast(str, result["task_id"])

    def submit_task(
        self,
        function_path: str,
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        options: dict[str, Any] | None = None,
    ) -> str:
        """Legacy-style wrapper around :meth:`submit`.

        ``function_path`` is mapped to an endpoint name with
        ``resolve_legacy_endpoint_name``. ``options`` is accepted for
        signature compatibility only and is ignored.
        """
        del options
        endpoint = resolve_legacy_endpoint_name(function_path)
        return self.submit(endpoint, *(args or ()), **(kwargs or {}))

    def status(self, task_id: str) -> TaskInfo:
        """Fetch the task record via ``get_task`` and parse it into a ``TaskInfo``."""
        return TaskInfo.from_dict(
            self._client._request("get_task", {"task_id": task_id})
        )

    def get_task_status(self, task_id: str) -> TaskInfo:
        """Alias of :meth:`status`."""
        return self.status(task_id)

    def result(self, task_id: str) -> Any:
        """Fetch the task result via ``get_task_result`` and decode the payload."""
        result = self._client._request("get_task_result", {"task_id": task_id})
        return SharedTensorClient._decode_rpc_payload(result)

    def get_task_result(self, task_id: str) -> Any:
        """Alias of :meth:`result`."""
        return self.result(task_id)

    @staticmethod
    def _next_sleep(poll_interval: float, deadline: float | None, now: float) -> float:
        """Return how long to sleep before the next poll.

        Without a deadline this is the full ``poll_interval``; with one it is
        capped at the time remaining so the wait does not overshoot.
        """
        if deadline is None:
            return poll_interval
        return min(poll_interval, deadline - now)

    def wait(
        self,
        task_id: str,
        timeout: float | None = None,
        callback: Callable[[TaskInfo], None] | None = None,
    ) -> Any:
        """Poll the task until it finishes and return its decoded result.

        Args:
            task_id: Identifier returned by :meth:`submit`.
            timeout: Maximum seconds to wait, or ``None`` to wait forever.
            callback: Optional observer invoked with the ``TaskInfo`` of every
                poll, including the final one.

        Raises:
            SharedTensorTaskError: If the task failed, was cancelled, or did
                not complete within ``timeout`` seconds.
        """
        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            info = self.status(task_id)
            if callback is not None:
                callback(info)
            if info.status == TaskStatus.COMPLETED:
                return self.result(task_id)
            if info.status == TaskStatus.FAILED:
                raise SharedTensorTaskError(
                    info.error_message or f"Task '{task_id}' failed"
                )
            if info.status == TaskStatus.CANCELLED:
                raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
            now = time.monotonic()
            if deadline is not None and now >= deadline:
                raise SharedTensorTaskError(
                    f"Task '{task_id}' did not complete within {timeout} seconds"
                )
            # Cap the sleep at the remaining budget so the last poll happens
            # at the deadline instead of up to one poll_interval after it.
            time.sleep(self._next_sleep(self.poll_interval, deadline, now))

    def wait_for_task(
        self,
        task_id: str,
        timeout: float | None = None,
        callback: Callable[[TaskInfo], None] | None = None,
    ) -> Any:
        """Alias of :meth:`wait`."""
        return self.wait(task_id, timeout=timeout, callback=callback)

    def execute_function_async(
        self,
        function_path: str,
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        options: dict[str, Any] | None = None,
        wait: bool = True,
        timeout: float | None = None,
        callback: Callable[[TaskInfo], None] | None = None,
    ) -> Any:
        """Submit a legacy-style call and optionally wait for it.

        Returns the task id when ``wait`` is false, otherwise the decoded
        task result. ``options`` is accepted for signature compatibility
        only and is ignored.
        """
        del options
        task_id = self.submit_task(function_path, args=args, kwargs=kwargs)
        if not wait:
            return task_id
        return self.wait(task_id, timeout=timeout, callback=callback)

    def cancel(self, task_id: str) -> bool:
        """Send ``cancel_task`` and return the truthiness of its ``cancelled`` field."""
        return bool(
            self._client._request("cancel_task", {"task_id": task_id})["cancelled"]
        )

    def cancel_task(self, task_id: str) -> bool:
        """Alias of :meth:`cancel`."""
        return self.cancel(task_id)

    def list_tasks(self, status: str | None = None) -> dict[str, TaskInfo]:
        """Return known tasks keyed by task id.

        A ``status`` filter is sent only when ``status`` is truthy.
        """
        params = {"status": status} if status else None
        result = self._client._request("list_tasks", params)
        return {task_id: TaskInfo.from_dict(data) for task_id, data in result.items()}

    def close(self) -> None:
        """Close the wrapped synchronous client."""
        self._client.close()

    def __enter__(self) -> AsyncSharedTensorClient:
        return self

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        self.close()
136
+
137
+
138
def execute_remote_function_async(
    function_path: str,
    args: tuple[Any, ...] = (),
    kwargs: dict[str, Any] | None = None,
    options: dict[str, Any] | None = None,
    *,
    server_port: int = 2537,
    host: str = "127.0.0.1",
    verbose_debug: bool = False,
    poll_interval: float = 1.0,
    wait: bool = True,
    timeout: float | None = None,
    callback: Callable[[TaskInfo], None] | None = None,
) -> Any:
    """Run one legacy-style call through a short-lived client.

    An ``AsyncSharedTensorClient`` is opened for the duration of the call and
    closed on exit; the return value is whatever its
    ``execute_function_async`` returns for the given arguments.
    """
    connection_settings = {
        "port": server_port,
        "host": host,
        "verbose_debug": verbose_debug,
        "poll_interval": poll_interval,
    }
    call_settings = {
        "args": args,
        "kwargs": kwargs,
        "options": options,
        "wait": wait,
        "timeout": timeout,
        "callback": callback,
    }
    with AsyncSharedTensorClient(**connection_settings) as remote:
        return remote.execute_function_async(function_path, **call_settings)