shared-tensor 0.2.6__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/PKG-INFO +17 -5
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/README.md +11 -1
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/pyproject.toml +9 -7
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/__init__.py +1 -1
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/async_task.py +87 -14
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/client.py +139 -0
- shared_tensor-0.2.8/shared_tensor/managed_object.py +135 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/provider.py +41 -28
- shared_tensor-0.2.8/shared_tensor/runtime.py +37 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/server.py +140 -58
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor.egg-info/SOURCES.txt +1 -0
- shared_tensor-0.2.6/shared_tensor/managed_object.py +0 -126
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/LICENSE +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/MANIFEST.in +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/setup.cfg +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/async_client.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/async_provider.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/errors.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/transport.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.8}/shared_tensor/utils.py +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shared-tensor
|
|
3
|
-
Version: 0.2.6
|
|
3
|
+
Version: 0.2.8
|
|
4
4
|
Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
|
|
5
5
|
Author-email: Athena Team <contact@world-sim-dev.org>
|
|
6
6
|
Maintainer-email: Athena Team <contact@world-sim-dev.org>
|
|
7
7
|
License-Expression: Apache-2.0
|
|
8
8
|
Project-URL: Homepage, https://github.com/world-sim-dev/shared-tensor
|
|
9
9
|
Project-URL: Repository, https://github.com/world-sim-dev/shared-tensor
|
|
10
|
-
Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/
|
|
10
|
+
Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/tree/main/docs
|
|
11
11
|
Project-URL: Bug Reports, https://github.com/world-sim-dev/shared-tensor/issues
|
|
12
12
|
Project-URL: Changelog, https://github.com/world-sim-dev/shared-tensor/releases
|
|
13
13
|
Keywords: gpu,memory,sharing,ipc,inter-process-communication,pytorch,cuda,model-serving,inference,torch,torch-ipc
|
|
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
|
|
|
16
16
|
Classifier: Intended Audience :: Science/Research
|
|
17
17
|
Classifier: Operating System :: POSIX :: Linux
|
|
18
18
|
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
19
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
20
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
21
22
|
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
24
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
25
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
24
26
|
Classifier: Topic :: System :: Distributed Computing
|
|
25
|
-
Requires-Python:
|
|
27
|
+
Requires-Python: <3.14,>=3.9
|
|
26
28
|
Description-Content-Type: text/markdown
|
|
27
29
|
License-File: LICENSE
|
|
28
30
|
Requires-Dist: cloudpickle>=3.0.0
|
|
29
31
|
Requires-Dist: numpy<2
|
|
30
|
-
Requires-Dist: torch
|
|
32
|
+
Requires-Dist: torch<2.8,>=2.1
|
|
31
33
|
Provides-Extra: dev
|
|
32
34
|
Requires-Dist: pytest>=6.0; extra == "dev"
|
|
33
35
|
Requires-Dist: pytest-cov>=2.0; extra == "dev"
|
|
@@ -75,7 +77,7 @@ Not supported:
|
|
|
75
77
|
|
|
76
78
|
## Install
|
|
77
79
|
|
|
78
|
-
Use Python `3.
|
|
80
|
+
Use Python `3.9+` and a CUDA-enabled PyTorch build.
|
|
79
81
|
|
|
80
82
|
```bash
|
|
81
83
|
pip install shared-tensor
|
|
@@ -89,6 +91,16 @@ conda activate shared-tensor-dev
|
|
|
89
91
|
pip install -e ".[dev,test]"
|
|
90
92
|
```
|
|
91
93
|
|
|
94
|
+
## Docs
|
|
95
|
+
|
|
96
|
+
Read the examples first, then the design notes:
|
|
97
|
+
|
|
98
|
+
- `docs/overview.md`
|
|
99
|
+
- `docs/patterns.md`
|
|
100
|
+
- `docs/architecture.md`
|
|
101
|
+
- `docs/lifecycle.md`
|
|
102
|
+
- `docs/diagrams.md`
|
|
103
|
+
|
|
92
104
|
## Example: Manual Two-Process Deployment
|
|
93
105
|
|
|
94
106
|
Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
|
|
@@ -25,7 +25,7 @@ Not supported:
|
|
|
25
25
|
|
|
26
26
|
## Install
|
|
27
27
|
|
|
28
|
-
Use Python `3.
|
|
28
|
+
Use Python `3.9+` and a CUDA-enabled PyTorch build.
|
|
29
29
|
|
|
30
30
|
```bash
|
|
31
31
|
pip install shared-tensor
|
|
@@ -39,6 +39,16 @@ conda activate shared-tensor-dev
|
|
|
39
39
|
pip install -e ".[dev,test]"
|
|
40
40
|
```
|
|
41
41
|
|
|
42
|
+
## Docs
|
|
43
|
+
|
|
44
|
+
Read the examples first, then the design notes:
|
|
45
|
+
|
|
46
|
+
- `docs/overview.md`
|
|
47
|
+
- `docs/patterns.md`
|
|
48
|
+
- `docs/architecture.md`
|
|
49
|
+
- `docs/lifecycle.md`
|
|
50
|
+
- `docs/diagrams.md`
|
|
51
|
+
|
|
42
52
|
## Example: Manual Two-Process Deployment
|
|
43
53
|
|
|
44
54
|
Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shared-tensor"
|
|
7
|
-
version = "0.2.6"
|
|
7
|
+
version = "0.2.8"
|
|
8
8
|
description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -33,18 +33,20 @@ classifiers = [
|
|
|
33
33
|
"Intended Audience :: Science/Research",
|
|
34
34
|
"Operating System :: POSIX :: Linux",
|
|
35
35
|
"Programming Language :: Python :: 3",
|
|
36
|
+
"Programming Language :: Python :: 3.9",
|
|
36
37
|
"Programming Language :: Python :: 3.10",
|
|
37
38
|
"Programming Language :: Python :: 3.11",
|
|
38
39
|
"Programming Language :: Python :: 3.12",
|
|
40
|
+
"Programming Language :: Python :: 3.13",
|
|
39
41
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
40
42
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
41
43
|
"Topic :: System :: Distributed Computing",
|
|
42
44
|
]
|
|
43
|
-
requires-python = ">=3.
|
|
45
|
+
requires-python = ">=3.9,<3.14"
|
|
44
46
|
dependencies = [
|
|
45
47
|
"cloudpickle>=3.0.0",
|
|
46
48
|
"numpy<2",
|
|
47
|
-
"torch>=2.2.
|
|
49
|
+
"torch>=2.1,<2.8",
|
|
48
50
|
]
|
|
49
51
|
|
|
50
52
|
[project.optional-dependencies]
|
|
@@ -73,7 +75,7 @@ docs = [
|
|
|
73
75
|
[project.urls]
|
|
74
76
|
Homepage = "https://github.com/world-sim-dev/shared-tensor"
|
|
75
77
|
Repository = "https://github.com/world-sim-dev/shared-tensor"
|
|
76
|
-
Documentation = "https://github.com/world-sim-dev/shared-tensor/
|
|
78
|
+
Documentation = "https://github.com/world-sim-dev/shared-tensor/tree/main/docs"
|
|
77
79
|
"Bug Reports" = "https://github.com/world-sim-dev/shared-tensor/issues"
|
|
78
80
|
Changelog = "https://github.com/world-sim-dev/shared-tensor/releases"
|
|
79
81
|
|
|
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
|
|
|
89
91
|
|
|
90
92
|
[tool.black]
|
|
91
93
|
line-length = 88
|
|
92
|
-
target-version = ['py310', 'py311', 'py312']
|
|
94
|
+
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
|
|
93
95
|
include = '\.pyi?$'
|
|
94
96
|
extend-exclude = '''
|
|
95
97
|
/(
|
|
@@ -115,7 +117,7 @@ use_parentheses = true
|
|
|
115
117
|
ensure_newline_before_comments = true
|
|
116
118
|
|
|
117
119
|
[tool.mypy]
|
|
118
|
-
python_version = "3.
|
|
120
|
+
python_version = "3.9"
|
|
119
121
|
warn_return_any = true
|
|
120
122
|
warn_unused_configs = true
|
|
121
123
|
disallow_untyped_defs = true
|
|
@@ -180,7 +182,7 @@ exclude_lines = [
|
|
|
180
182
|
]
|
|
181
183
|
|
|
182
184
|
[tool.ruff]
|
|
183
|
-
target-version = "
|
|
185
|
+
target-version = "py39"
|
|
184
186
|
line-length = 88
|
|
185
187
|
|
|
186
188
|
[tool.ruff.lint]
|
|
@@ -33,8 +33,6 @@ class TaskInfo:
|
|
|
33
33
|
created_at: float
|
|
34
34
|
started_at: float | None = None
|
|
35
35
|
completed_at: float | None = None
|
|
36
|
-
result_encoding: str | None = None
|
|
37
|
-
result_payload: bytes | None = None
|
|
38
36
|
error_type: str | None = None
|
|
39
37
|
error_message: str | None = None
|
|
40
38
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|
@@ -47,8 +45,6 @@ class TaskInfo:
|
|
|
47
45
|
"created_at": self.created_at,
|
|
48
46
|
"started_at": self.started_at,
|
|
49
47
|
"completed_at": self.completed_at,
|
|
50
|
-
"result_encoding": self.result_encoding,
|
|
51
|
-
"result_payload": self.result_payload,
|
|
52
48
|
"error_type": self.error_type,
|
|
53
49
|
"error_message": self.error_message,
|
|
54
50
|
"metadata": dict(self.metadata),
|
|
@@ -66,6 +62,9 @@ class TaskInfo:
|
|
|
66
62
|
class _TaskEntry:
|
|
67
63
|
info: TaskInfo
|
|
68
64
|
future: Future[Any]
|
|
65
|
+
result_encoding: str | None = None
|
|
66
|
+
result_payload: bytes | None = None
|
|
67
|
+
local_result: Any = None
|
|
69
68
|
|
|
70
69
|
|
|
71
70
|
class TaskManager:
|
|
@@ -86,6 +85,7 @@ class TaskManager:
|
|
|
86
85
|
self._last_cleanup = 0.0
|
|
87
86
|
self._lock = RLock()
|
|
88
87
|
self._tasks: dict[str, _TaskEntry] = {}
|
|
88
|
+
self._accepting_submissions = True
|
|
89
89
|
|
|
90
90
|
def submit(
|
|
91
91
|
self,
|
|
@@ -97,6 +97,8 @@ class TaskManager:
|
|
|
97
97
|
) -> TaskInfo:
|
|
98
98
|
self._maybe_cleanup()
|
|
99
99
|
with self._lock:
|
|
100
|
+
if not self._accepting_submissions:
|
|
101
|
+
raise SharedTensorTaskError("Task manager is shutting down and is not accepting new tasks")
|
|
100
102
|
self._drop_oldest_finished_tasks_if_needed()
|
|
101
103
|
if len(self._tasks) >= self._max_tasks:
|
|
102
104
|
raise SharedTensorTaskError("Task capacity exceeded")
|
|
@@ -139,13 +141,14 @@ class TaskManager:
|
|
|
139
141
|
)
|
|
140
142
|
return
|
|
141
143
|
|
|
144
|
+
self._store_local_result(task_id, result)
|
|
145
|
+
|
|
142
146
|
if result is None:
|
|
147
|
+
self._store_payload(task_id, encoding=None, payload=None, object_id=None)
|
|
143
148
|
self._transition(
|
|
144
149
|
task_id,
|
|
145
150
|
status=TaskStatus.COMPLETED,
|
|
146
151
|
completed_at=time.time(),
|
|
147
|
-
result_encoding=None,
|
|
148
|
-
result_payload=None,
|
|
149
152
|
)
|
|
150
153
|
return
|
|
151
154
|
|
|
@@ -165,13 +168,16 @@ class TaskManager:
|
|
|
165
168
|
)
|
|
166
169
|
return
|
|
167
170
|
|
|
171
|
+
self._store_payload(
|
|
172
|
+
task_id,
|
|
173
|
+
encoding=payload["encoding"],
|
|
174
|
+
payload=payload["payload_bytes"],
|
|
175
|
+
object_id=payload.get("object_id"),
|
|
176
|
+
)
|
|
168
177
|
self._transition(
|
|
169
178
|
task_id,
|
|
170
179
|
status=TaskStatus.COMPLETED,
|
|
171
180
|
completed_at=time.time(),
|
|
172
|
-
result_encoding=payload["encoding"],
|
|
173
|
-
result_payload=payload["payload_bytes"],
|
|
174
|
-
metadata={"object_id": payload.get("object_id")},
|
|
175
181
|
)
|
|
176
182
|
|
|
177
183
|
@staticmethod
|
|
@@ -191,6 +197,31 @@ class TaskManager:
|
|
|
191
197
|
for key, value in updates.items():
|
|
192
198
|
setattr(entry.info, key, value)
|
|
193
199
|
|
|
200
|
+
def _store_local_result(self, task_id: str, value: Any) -> None:
|
|
201
|
+
with self._lock:
|
|
202
|
+
entry = self._tasks.get(task_id)
|
|
203
|
+
if entry is None:
|
|
204
|
+
return
|
|
205
|
+
entry.local_result = value
|
|
206
|
+
|
|
207
|
+
def _store_payload(
|
|
208
|
+
self,
|
|
209
|
+
task_id: str,
|
|
210
|
+
*,
|
|
211
|
+
encoding: str | None,
|
|
212
|
+
payload: bytes | None,
|
|
213
|
+
object_id: str | None,
|
|
214
|
+
) -> None:
|
|
215
|
+
with self._lock:
|
|
216
|
+
entry = self._tasks.get(task_id)
|
|
217
|
+
if entry is None:
|
|
218
|
+
return
|
|
219
|
+
entry.result_encoding = encoding
|
|
220
|
+
entry.result_payload = payload
|
|
221
|
+
metadata = dict(entry.info.metadata)
|
|
222
|
+
metadata["object_id"] = object_id
|
|
223
|
+
entry.info.metadata = metadata
|
|
224
|
+
|
|
194
225
|
def get(self, task_id: str) -> TaskInfo:
|
|
195
226
|
self._maybe_cleanup()
|
|
196
227
|
with self._lock:
|
|
@@ -207,6 +238,24 @@ class TaskManager:
|
|
|
207
238
|
return None
|
|
208
239
|
return deserialize_payload(encoding, payload_bytes)
|
|
209
240
|
|
|
241
|
+
def result_local(self, task_id: str) -> Any:
|
|
242
|
+
self._maybe_cleanup()
|
|
243
|
+
with self._lock:
|
|
244
|
+
entry = self._tasks.get(task_id)
|
|
245
|
+
if entry is None:
|
|
246
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
247
|
+
info = copy.deepcopy(entry.info)
|
|
248
|
+
value = entry.local_result
|
|
249
|
+
if info.status == TaskStatus.CANCELLED:
|
|
250
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
|
|
251
|
+
if info.status == TaskStatus.FAILED:
|
|
252
|
+
raise SharedTensorTaskError(info.error_message or f"Task '{task_id}' failed")
|
|
253
|
+
if info.status != TaskStatus.COMPLETED:
|
|
254
|
+
raise SharedTensorTaskError(
|
|
255
|
+
f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
|
|
256
|
+
)
|
|
257
|
+
return value
|
|
258
|
+
|
|
210
259
|
def wait_result_payload(
|
|
211
260
|
self,
|
|
212
261
|
task_id: str,
|
|
@@ -227,7 +276,14 @@ class TaskManager:
|
|
|
227
276
|
return self.result_payload(task_id)
|
|
228
277
|
|
|
229
278
|
def result_payload(self, task_id: str) -> dict[str, str | bytes | None]:
|
|
230
|
-
|
|
279
|
+
self._maybe_cleanup()
|
|
280
|
+
with self._lock:
|
|
281
|
+
entry = self._tasks.get(task_id)
|
|
282
|
+
if entry is None:
|
|
283
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
284
|
+
info = copy.deepcopy(entry.info)
|
|
285
|
+
encoding = entry.result_encoding
|
|
286
|
+
payload = entry.result_payload
|
|
231
287
|
if info.status == TaskStatus.CANCELLED:
|
|
232
288
|
raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
|
|
233
289
|
if info.status == TaskStatus.FAILED:
|
|
@@ -237,11 +293,26 @@ class TaskManager:
|
|
|
237
293
|
f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
|
|
238
294
|
)
|
|
239
295
|
return {
|
|
240
|
-
"encoding":
|
|
241
|
-
"payload_bytes":
|
|
296
|
+
"encoding": encoding,
|
|
297
|
+
"payload_bytes": payload,
|
|
242
298
|
"object_id": info.metadata.get("object_id"),
|
|
243
299
|
}
|
|
244
300
|
|
|
301
|
+
def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
|
|
302
|
+
self._maybe_cleanup()
|
|
303
|
+
with self._lock:
|
|
304
|
+
entry = self._tasks.get(task_id)
|
|
305
|
+
if entry is None:
|
|
306
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
307
|
+
future = entry.future
|
|
308
|
+
try:
|
|
309
|
+
future.result(timeout=timeout)
|
|
310
|
+
except FutureTimeoutError as exc:
|
|
311
|
+
raise SharedTensorTaskError(
|
|
312
|
+
f"Task '{task_id}' did not complete within {timeout} seconds"
|
|
313
|
+
) from exc
|
|
314
|
+
return self.result_local(task_id)
|
|
315
|
+
|
|
245
316
|
def cancel(self, task_id: str) -> bool:
|
|
246
317
|
self._maybe_cleanup()
|
|
247
318
|
with self._lock:
|
|
@@ -266,8 +337,10 @@ class TaskManager:
|
|
|
266
337
|
}
|
|
267
338
|
return items
|
|
268
339
|
|
|
269
|
-
def shutdown(self, *, wait: bool = True) -> None:
|
|
270
|
-
self.
|
|
340
|
+
def shutdown(self, *, wait: bool = True, cancel_futures: bool = True) -> None:
|
|
341
|
+
with self._lock:
|
|
342
|
+
self._accepting_submissions = False
|
|
343
|
+
self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
|
|
271
344
|
|
|
272
345
|
def _maybe_cleanup(self) -> None:
|
|
273
346
|
now = time.time()
|
|
@@ -8,16 +8,25 @@ from dataclasses import dataclass
|
|
|
8
8
|
from typing import Any, cast
|
|
9
9
|
|
|
10
10
|
from shared_tensor.errors import (
|
|
11
|
+
SharedTensorCapabilityError,
|
|
11
12
|
SharedTensorClientError,
|
|
13
|
+
SharedTensorConfigurationError,
|
|
14
|
+
SharedTensorError,
|
|
15
|
+
SharedTensorProviderError,
|
|
12
16
|
SharedTensorProtocolError,
|
|
13
17
|
SharedTensorRemoteError,
|
|
18
|
+
SharedTensorSerializationError,
|
|
19
|
+
SharedTensorTaskError,
|
|
14
20
|
)
|
|
15
21
|
from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
|
|
22
|
+
from shared_tensor.runtime import get_local_server
|
|
16
23
|
from shared_tensor.transport import recv_message, send_message
|
|
24
|
+
from shared_tensor.async_task import TaskStatus
|
|
17
25
|
from shared_tensor.utils import (
|
|
18
26
|
deserialize_payload,
|
|
19
27
|
resolve_runtime_socket_path,
|
|
20
28
|
serialize_call_payloads,
|
|
29
|
+
validate_payload_for_transport,
|
|
21
30
|
)
|
|
22
31
|
|
|
23
32
|
|
|
@@ -50,6 +59,54 @@ class SharedTensorClient:
|
|
|
50
59
|
self.timeout = timeout
|
|
51
60
|
self.verbose_debug = verbose_debug
|
|
52
61
|
|
|
62
|
+
def _local_server(self):
|
|
63
|
+
return get_local_server(self.socket_path)
|
|
64
|
+
|
|
65
|
+
@staticmethod
|
|
66
|
+
def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
|
|
67
|
+
if isinstance(exc, SharedTensorProtocolError):
|
|
68
|
+
code = 1
|
|
69
|
+
elif isinstance(exc, SharedTensorProviderError):
|
|
70
|
+
code = 2
|
|
71
|
+
elif isinstance(exc, SharedTensorSerializationError):
|
|
72
|
+
code = 3
|
|
73
|
+
elif isinstance(exc, SharedTensorCapabilityError):
|
|
74
|
+
code = 4
|
|
75
|
+
elif isinstance(exc, SharedTensorTaskError):
|
|
76
|
+
code = 5
|
|
77
|
+
elif isinstance(exc, SharedTensorConfigurationError):
|
|
78
|
+
code = 6
|
|
79
|
+
else:
|
|
80
|
+
code = 7
|
|
81
|
+
return SharedTensorRemoteError(
|
|
82
|
+
f"Remote error [{code}]: {exc}",
|
|
83
|
+
code=code,
|
|
84
|
+
data=None,
|
|
85
|
+
error_type=type(exc).__name__,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
def _run_local(self, operation):
|
|
89
|
+
try:
|
|
90
|
+
return operation()
|
|
91
|
+
except SharedTensorError as exc:
|
|
92
|
+
raise self._remote_error_from_local(exc) from exc
|
|
93
|
+
|
|
94
|
+
def _decode_local_result(self, result: Any) -> Any:
|
|
95
|
+
if result is None:
|
|
96
|
+
return None
|
|
97
|
+
value = result.value
|
|
98
|
+
if value is None:
|
|
99
|
+
return None
|
|
100
|
+
validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
|
|
101
|
+
object_id = result.object_id
|
|
102
|
+
if object_id is None:
|
|
103
|
+
return value
|
|
104
|
+
return SharedObjectHandle(
|
|
105
|
+
object_id=cast(str, object_id),
|
|
106
|
+
value=value,
|
|
107
|
+
_releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
|
|
108
|
+
)
|
|
109
|
+
|
|
53
110
|
def _send_request(self, request: dict[str, Any]) -> Any:
|
|
54
111
|
method = request.get("method", "<unknown>")
|
|
55
112
|
if self.verbose_debug:
|
|
@@ -104,6 +161,13 @@ class SharedTensorClient:
|
|
|
104
161
|
def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
|
|
105
162
|
if self.verbose_debug:
|
|
106
163
|
logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
|
|
164
|
+
local_server = self._local_server()
|
|
165
|
+
if local_server is not None:
|
|
166
|
+
return self._run_local(
|
|
167
|
+
lambda: self._decode_local_result(
|
|
168
|
+
local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
|
|
169
|
+
)
|
|
170
|
+
)
|
|
107
171
|
encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
|
|
108
172
|
result = self._request(
|
|
109
173
|
"call",
|
|
@@ -119,6 +183,19 @@ class SharedTensorClient:
|
|
|
119
183
|
def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
|
|
120
184
|
if self.verbose_debug:
|
|
121
185
|
logger.debug("Client submitting task", extra={"endpoint": endpoint})
|
|
186
|
+
local_server = self._local_server()
|
|
187
|
+
if local_server is not None:
|
|
188
|
+
return self._run_local(
|
|
189
|
+
lambda: cast(
|
|
190
|
+
str,
|
|
191
|
+
local_server._submit_endpoint_task(
|
|
192
|
+
endpoint,
|
|
193
|
+
local_server.provider.get_endpoint(endpoint),
|
|
194
|
+
tuple(args),
|
|
195
|
+
dict(kwargs),
|
|
196
|
+
).task_id,
|
|
197
|
+
)
|
|
198
|
+
)
|
|
122
199
|
encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
|
|
123
200
|
result = self._request(
|
|
124
201
|
"submit",
|
|
@@ -134,18 +211,43 @@ class SharedTensorClient:
|
|
|
134
211
|
def release(self, object_id: str) -> bool:
|
|
135
212
|
if self.verbose_debug:
|
|
136
213
|
logger.debug("Client releasing managed object", extra={"object_id": object_id})
|
|
214
|
+
local_server = self._local_server()
|
|
215
|
+
if local_server is not None:
|
|
216
|
+
return self._run_local(
|
|
217
|
+
lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
|
|
218
|
+
)
|
|
137
219
|
result = self._request("release_object", {"object_id": object_id})
|
|
138
220
|
return bool(result["released"])
|
|
139
221
|
|
|
140
222
|
def release_many(self, object_ids: list[str]) -> dict[str, bool]:
|
|
223
|
+
local_server = self._local_server()
|
|
224
|
+
if local_server is not None:
|
|
225
|
+
return self._run_local(
|
|
226
|
+
lambda: {
|
|
227
|
+
object_id: bool(released)
|
|
228
|
+
for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
|
|
229
|
+
"released"
|
|
230
|
+
].items()
|
|
231
|
+
}
|
|
232
|
+
)
|
|
141
233
|
result = self._request("release_objects", {"object_ids": object_ids})
|
|
142
234
|
return {object_id: bool(released) for object_id, released in result["released"].items()}
|
|
143
235
|
|
|
144
236
|
def get_object_info(self, object_id: str) -> dict[str, Any] | None:
|
|
237
|
+
local_server = self._local_server()
|
|
238
|
+
if local_server is not None:
|
|
239
|
+
return self._run_local(
|
|
240
|
+
lambda: cast(
|
|
241
|
+
dict[str, Any] | None,
|
|
242
|
+
local_server._handle_get_object_info({"object_id": object_id}).get("object"),
|
|
243
|
+
)
|
|
244
|
+
)
|
|
145
245
|
result = self._request("get_object_info", {"object_id": object_id})
|
|
146
246
|
return cast(dict[str, Any] | None, result.get("object"))
|
|
147
247
|
|
|
148
248
|
def ping(self) -> bool:
|
|
249
|
+
if self._local_server() is not None:
|
|
250
|
+
return True
|
|
149
251
|
try:
|
|
150
252
|
self._request("ping")
|
|
151
253
|
except (SharedTensorClientError, SharedTensorRemoteError):
|
|
@@ -153,29 +255,66 @@ class SharedTensorClient:
|
|
|
153
255
|
return True
|
|
154
256
|
|
|
155
257
|
def get_server_info(self) -> dict[str, Any]:
|
|
258
|
+
local_server = self._local_server()
|
|
259
|
+
if local_server is not None:
|
|
260
|
+
return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
|
|
156
261
|
return cast(dict[str, Any], self._request("get_server_info"))
|
|
157
262
|
|
|
158
263
|
def list_endpoints(self) -> dict[str, Any]:
|
|
264
|
+
local_server = self._local_server()
|
|
265
|
+
if local_server is not None:
|
|
266
|
+
return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
|
|
159
267
|
return cast(dict[str, Any], self._request("list_endpoints"))
|
|
160
268
|
|
|
161
269
|
def get_task_status(self, task_id: str) -> dict[str, Any]:
|
|
270
|
+
local_server = self._local_server()
|
|
271
|
+
if local_server is not None:
|
|
272
|
+
return self._run_local(
|
|
273
|
+
lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
|
|
274
|
+
)
|
|
162
275
|
return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
|
|
163
276
|
|
|
164
277
|
def get_task_result(self, task_id: str) -> Any:
|
|
278
|
+
local_server = self._local_server()
|
|
279
|
+
if local_server is not None:
|
|
280
|
+
return self._run_local(
|
|
281
|
+
lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
|
|
282
|
+
)
|
|
165
283
|
return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
|
|
166
284
|
|
|
167
285
|
def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
|
|
168
286
|
if self.verbose_debug:
|
|
169
287
|
logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
|
|
288
|
+
local_server = self._local_server()
|
|
289
|
+
if local_server is not None:
|
|
290
|
+
return self._run_local(
|
|
291
|
+
lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
|
|
292
|
+
)
|
|
170
293
|
params = {"task_id": task_id}
|
|
171
294
|
if timeout is not None:
|
|
172
295
|
params["timeout"] = timeout
|
|
173
296
|
return cast(dict[str, Any], self._request("wait_task", params))
|
|
174
297
|
|
|
175
298
|
def cancel_task(self, task_id: str) -> bool:
|
|
299
|
+
local_server = self._local_server()
|
|
300
|
+
if local_server is not None:
|
|
301
|
+
return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
|
|
176
302
|
return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
|
|
177
303
|
|
|
178
304
|
def list_tasks(self, status: str | None = None) -> dict[str, Any]:
|
|
305
|
+
local_server = self._local_server()
|
|
306
|
+
if local_server is not None:
|
|
307
|
+
return self._run_local(
|
|
308
|
+
lambda: cast(
|
|
309
|
+
dict[str, Any],
|
|
310
|
+
{
|
|
311
|
+
listed_task_id: info.to_dict()
|
|
312
|
+
for listed_task_id, info in local_server._task_manager_instance()
|
|
313
|
+
.list(status=None if status is None else TaskStatus(status))
|
|
314
|
+
.items()
|
|
315
|
+
},
|
|
316
|
+
)
|
|
317
|
+
)
|
|
179
318
|
params = {"status": status} if status else None
|
|
180
319
|
return cast(dict[str, Any], self._request("list_tasks", params))
|
|
181
320
|
|