shared-tensor 0.2.6__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/PKG-INFO +5 -3
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/pyproject.toml +8 -6
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/__init__.py +1 -1
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/async_task.py +43 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/client.py +139 -0
- shared_tensor-0.2.7/shared_tensor/runtime.py +30 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/server.py +87 -35
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor.egg-info/SOURCES.txt +1 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/LICENSE +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/MANIFEST.in +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/README.md +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/setup.cfg +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/async_client.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/async_provider.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/errors.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/managed_object.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/provider.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/transport.py +0 -0
- {shared_tensor-0.2.6 → shared_tensor-0.2.7}/shared_tensor/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shared-tensor
|
|
3
|
-
Version: 0.2.6
|
|
3
|
+
Version: 0.2.7
|
|
4
4
|
Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
|
|
5
5
|
Author-email: Athena Team <contact@world-sim-dev.org>
|
|
6
6
|
Maintainer-email: Athena Team <contact@world-sim-dev.org>
|
|
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
|
|
|
16
16
|
Classifier: Intended Audience :: Science/Research
|
|
17
17
|
Classifier: Operating System :: POSIX :: Linux
|
|
18
18
|
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
19
20
|
Classifier: Programming Language :: Python :: 3.10
|
|
20
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
21
22
|
Classifier: Programming Language :: Python :: 3.12
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
24
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
25
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
24
26
|
Classifier: Topic :: System :: Distributed Computing
|
|
25
|
-
Requires-Python:
|
|
27
|
+
Requires-Python: <3.14,>=3.9
|
|
26
28
|
Description-Content-Type: text/markdown
|
|
27
29
|
License-File: LICENSE
|
|
28
30
|
Requires-Dist: cloudpickle>=3.0.0
|
|
29
31
|
Requires-Dist: numpy<2
|
|
30
|
-
Requires-Dist: torch
|
|
32
|
+
Requires-Dist: torch<2.8,>=2.1
|
|
31
33
|
Provides-Extra: dev
|
|
32
34
|
Requires-Dist: pytest>=6.0; extra == "dev"
|
|
33
35
|
Requires-Dist: pytest-cov>=2.0; extra == "dev"
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shared-tensor"
|
|
7
|
-
version = "0.2.6"
|
|
7
|
+
version = "0.2.7"
|
|
8
8
|
description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -33,18 +33,20 @@ classifiers = [
|
|
|
33
33
|
"Intended Audience :: Science/Research",
|
|
34
34
|
"Operating System :: POSIX :: Linux",
|
|
35
35
|
"Programming Language :: Python :: 3",
|
|
36
|
+
"Programming Language :: Python :: 3.9",
|
|
36
37
|
"Programming Language :: Python :: 3.10",
|
|
37
38
|
"Programming Language :: Python :: 3.11",
|
|
38
39
|
"Programming Language :: Python :: 3.12",
|
|
40
|
+
"Programming Language :: Python :: 3.13",
|
|
39
41
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
40
42
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
41
43
|
"Topic :: System :: Distributed Computing",
|
|
42
44
|
]
|
|
43
|
-
requires-python = ">=3.
|
|
45
|
+
requires-python = ">=3.9,<3.14"
|
|
44
46
|
dependencies = [
|
|
45
47
|
"cloudpickle>=3.0.0",
|
|
46
48
|
"numpy<2",
|
|
47
|
-
"torch>=2.2.
|
|
49
|
+
"torch>=2.1,<2.8",
|
|
48
50
|
]
|
|
49
51
|
|
|
50
52
|
[project.optional-dependencies]
|
|
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
|
|
|
89
91
|
|
|
90
92
|
[tool.black]
|
|
91
93
|
line-length = 88
|
|
92
|
-
target-version = ['py310', 'py311', 'py312']
|
|
94
|
+
target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
|
|
93
95
|
include = '\.pyi?$'
|
|
94
96
|
extend-exclude = '''
|
|
95
97
|
/(
|
|
@@ -115,7 +117,7 @@ use_parentheses = true
|
|
|
115
117
|
ensure_newline_before_comments = true
|
|
116
118
|
|
|
117
119
|
[tool.mypy]
|
|
118
|
-
python_version = "3.
|
|
120
|
+
python_version = "3.9"
|
|
119
121
|
warn_return_any = true
|
|
120
122
|
warn_unused_configs = true
|
|
121
123
|
disallow_untyped_defs = true
|
|
@@ -180,7 +182,7 @@ exclude_lines = [
|
|
|
180
182
|
]
|
|
181
183
|
|
|
182
184
|
[tool.ruff]
|
|
183
|
-
target-version = "
|
|
185
|
+
target-version = "py39"
|
|
184
186
|
line-length = 88
|
|
185
187
|
|
|
186
188
|
[tool.ruff.lint]
|
|
@@ -66,6 +66,7 @@ class TaskInfo:
|
|
|
66
66
|
class _TaskEntry:
|
|
67
67
|
info: TaskInfo
|
|
68
68
|
future: Future[Any]
|
|
69
|
+
local_result: Any = None
|
|
69
70
|
|
|
70
71
|
|
|
71
72
|
class TaskManager:
|
|
@@ -139,6 +140,8 @@ class TaskManager:
|
|
|
139
140
|
)
|
|
140
141
|
return
|
|
141
142
|
|
|
143
|
+
self._store_local_result(task_id, result)
|
|
144
|
+
|
|
142
145
|
if result is None:
|
|
143
146
|
self._transition(
|
|
144
147
|
task_id,
|
|
@@ -191,6 +194,13 @@ class TaskManager:
|
|
|
191
194
|
for key, value in updates.items():
|
|
192
195
|
setattr(entry.info, key, value)
|
|
193
196
|
|
|
197
|
+
def _store_local_result(self, task_id: str, value: Any) -> None:
|
|
198
|
+
with self._lock:
|
|
199
|
+
entry = self._tasks.get(task_id)
|
|
200
|
+
if entry is None:
|
|
201
|
+
return
|
|
202
|
+
entry.local_result = value
|
|
203
|
+
|
|
194
204
|
def get(self, task_id: str) -> TaskInfo:
|
|
195
205
|
self._maybe_cleanup()
|
|
196
206
|
with self._lock:
|
|
@@ -207,6 +217,24 @@ class TaskManager:
|
|
|
207
217
|
return None
|
|
208
218
|
return deserialize_payload(encoding, payload_bytes)
|
|
209
219
|
|
|
220
|
+
def result_local(self, task_id: str) -> Any:
|
|
221
|
+
self._maybe_cleanup()
|
|
222
|
+
with self._lock:
|
|
223
|
+
entry = self._tasks.get(task_id)
|
|
224
|
+
if entry is None:
|
|
225
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
226
|
+
info = copy.deepcopy(entry.info)
|
|
227
|
+
value = entry.local_result
|
|
228
|
+
if info.status == TaskStatus.CANCELLED:
|
|
229
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
|
|
230
|
+
if info.status == TaskStatus.FAILED:
|
|
231
|
+
raise SharedTensorTaskError(info.error_message or f"Task '{task_id}' failed")
|
|
232
|
+
if info.status != TaskStatus.COMPLETED:
|
|
233
|
+
raise SharedTensorTaskError(
|
|
234
|
+
f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
|
|
235
|
+
)
|
|
236
|
+
return value
|
|
237
|
+
|
|
210
238
|
def wait_result_payload(
|
|
211
239
|
self,
|
|
212
240
|
task_id: str,
|
|
@@ -242,6 +270,21 @@ class TaskManager:
|
|
|
242
270
|
"object_id": info.metadata.get("object_id"),
|
|
243
271
|
}
|
|
244
272
|
|
|
273
|
+
def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
|
|
274
|
+
self._maybe_cleanup()
|
|
275
|
+
with self._lock:
|
|
276
|
+
entry = self._tasks.get(task_id)
|
|
277
|
+
if entry is None:
|
|
278
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
279
|
+
future = entry.future
|
|
280
|
+
try:
|
|
281
|
+
future.result(timeout=timeout)
|
|
282
|
+
except FutureTimeoutError as exc:
|
|
283
|
+
raise SharedTensorTaskError(
|
|
284
|
+
f"Task '{task_id}' did not complete within {timeout} seconds"
|
|
285
|
+
) from exc
|
|
286
|
+
return self.result_local(task_id)
|
|
287
|
+
|
|
245
288
|
def cancel(self, task_id: str) -> bool:
|
|
246
289
|
self._maybe_cleanup()
|
|
247
290
|
with self._lock:
|
|
@@ -8,16 +8,25 @@ from dataclasses import dataclass
|
|
|
8
8
|
from typing import Any, cast
|
|
9
9
|
|
|
10
10
|
from shared_tensor.errors import (
|
|
11
|
+
SharedTensorCapabilityError,
|
|
11
12
|
SharedTensorClientError,
|
|
13
|
+
SharedTensorConfigurationError,
|
|
14
|
+
SharedTensorError,
|
|
15
|
+
SharedTensorProviderError,
|
|
12
16
|
SharedTensorProtocolError,
|
|
13
17
|
SharedTensorRemoteError,
|
|
18
|
+
SharedTensorSerializationError,
|
|
19
|
+
SharedTensorTaskError,
|
|
14
20
|
)
|
|
15
21
|
from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
|
|
22
|
+
from shared_tensor.runtime import get_local_server
|
|
16
23
|
from shared_tensor.transport import recv_message, send_message
|
|
24
|
+
from shared_tensor.async_task import TaskStatus
|
|
17
25
|
from shared_tensor.utils import (
|
|
18
26
|
deserialize_payload,
|
|
19
27
|
resolve_runtime_socket_path,
|
|
20
28
|
serialize_call_payloads,
|
|
29
|
+
validate_payload_for_transport,
|
|
21
30
|
)
|
|
22
31
|
|
|
23
32
|
|
|
@@ -50,6 +59,54 @@ class SharedTensorClient:
|
|
|
50
59
|
self.timeout = timeout
|
|
51
60
|
self.verbose_debug = verbose_debug
|
|
52
61
|
|
|
62
|
+
def _local_server(self):
|
|
63
|
+
return get_local_server(self.socket_path)
|
|
64
|
+
|
|
65
|
+
@staticmethod
|
|
66
|
+
def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
|
|
67
|
+
if isinstance(exc, SharedTensorProtocolError):
|
|
68
|
+
code = 1
|
|
69
|
+
elif isinstance(exc, SharedTensorProviderError):
|
|
70
|
+
code = 2
|
|
71
|
+
elif isinstance(exc, SharedTensorSerializationError):
|
|
72
|
+
code = 3
|
|
73
|
+
elif isinstance(exc, SharedTensorCapabilityError):
|
|
74
|
+
code = 4
|
|
75
|
+
elif isinstance(exc, SharedTensorTaskError):
|
|
76
|
+
code = 5
|
|
77
|
+
elif isinstance(exc, SharedTensorConfigurationError):
|
|
78
|
+
code = 6
|
|
79
|
+
else:
|
|
80
|
+
code = 7
|
|
81
|
+
return SharedTensorRemoteError(
|
|
82
|
+
f"Remote error [{code}]: {exc}",
|
|
83
|
+
code=code,
|
|
84
|
+
data=None,
|
|
85
|
+
error_type=type(exc).__name__,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
def _run_local(self, operation):
|
|
89
|
+
try:
|
|
90
|
+
return operation()
|
|
91
|
+
except SharedTensorError as exc:
|
|
92
|
+
raise self._remote_error_from_local(exc) from exc
|
|
93
|
+
|
|
94
|
+
def _decode_local_result(self, result: Any) -> Any:
|
|
95
|
+
if result is None:
|
|
96
|
+
return None
|
|
97
|
+
value = result.value
|
|
98
|
+
if value is None:
|
|
99
|
+
return None
|
|
100
|
+
validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
|
|
101
|
+
object_id = result.object_id
|
|
102
|
+
if object_id is None:
|
|
103
|
+
return value
|
|
104
|
+
return SharedObjectHandle(
|
|
105
|
+
object_id=cast(str, object_id),
|
|
106
|
+
value=value,
|
|
107
|
+
_releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
|
|
108
|
+
)
|
|
109
|
+
|
|
53
110
|
def _send_request(self, request: dict[str, Any]) -> Any:
|
|
54
111
|
method = request.get("method", "<unknown>")
|
|
55
112
|
if self.verbose_debug:
|
|
@@ -104,6 +161,13 @@ class SharedTensorClient:
|
|
|
104
161
|
def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
|
|
105
162
|
if self.verbose_debug:
|
|
106
163
|
logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
|
|
164
|
+
local_server = self._local_server()
|
|
165
|
+
if local_server is not None:
|
|
166
|
+
return self._run_local(
|
|
167
|
+
lambda: self._decode_local_result(
|
|
168
|
+
local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
|
|
169
|
+
)
|
|
170
|
+
)
|
|
107
171
|
encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
|
|
108
172
|
result = self._request(
|
|
109
173
|
"call",
|
|
@@ -119,6 +183,19 @@ class SharedTensorClient:
|
|
|
119
183
|
def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
|
|
120
184
|
if self.verbose_debug:
|
|
121
185
|
logger.debug("Client submitting task", extra={"endpoint": endpoint})
|
|
186
|
+
local_server = self._local_server()
|
|
187
|
+
if local_server is not None:
|
|
188
|
+
return self._run_local(
|
|
189
|
+
lambda: cast(
|
|
190
|
+
str,
|
|
191
|
+
local_server._submit_endpoint_task(
|
|
192
|
+
endpoint,
|
|
193
|
+
local_server.provider.get_endpoint(endpoint),
|
|
194
|
+
tuple(args),
|
|
195
|
+
dict(kwargs),
|
|
196
|
+
).task_id,
|
|
197
|
+
)
|
|
198
|
+
)
|
|
122
199
|
encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
|
|
123
200
|
result = self._request(
|
|
124
201
|
"submit",
|
|
@@ -134,18 +211,43 @@ class SharedTensorClient:
|
|
|
134
211
|
def release(self, object_id: str) -> bool:
|
|
135
212
|
if self.verbose_debug:
|
|
136
213
|
logger.debug("Client releasing managed object", extra={"object_id": object_id})
|
|
214
|
+
local_server = self._local_server()
|
|
215
|
+
if local_server is not None:
|
|
216
|
+
return self._run_local(
|
|
217
|
+
lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
|
|
218
|
+
)
|
|
137
219
|
result = self._request("release_object", {"object_id": object_id})
|
|
138
220
|
return bool(result["released"])
|
|
139
221
|
|
|
140
222
|
def release_many(self, object_ids: list[str]) -> dict[str, bool]:
|
|
223
|
+
local_server = self._local_server()
|
|
224
|
+
if local_server is not None:
|
|
225
|
+
return self._run_local(
|
|
226
|
+
lambda: {
|
|
227
|
+
object_id: bool(released)
|
|
228
|
+
for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
|
|
229
|
+
"released"
|
|
230
|
+
].items()
|
|
231
|
+
}
|
|
232
|
+
)
|
|
141
233
|
result = self._request("release_objects", {"object_ids": object_ids})
|
|
142
234
|
return {object_id: bool(released) for object_id, released in result["released"].items()}
|
|
143
235
|
|
|
144
236
|
def get_object_info(self, object_id: str) -> dict[str, Any] | None:
|
|
237
|
+
local_server = self._local_server()
|
|
238
|
+
if local_server is not None:
|
|
239
|
+
return self._run_local(
|
|
240
|
+
lambda: cast(
|
|
241
|
+
dict[str, Any] | None,
|
|
242
|
+
local_server._handle_get_object_info({"object_id": object_id}).get("object"),
|
|
243
|
+
)
|
|
244
|
+
)
|
|
145
245
|
result = self._request("get_object_info", {"object_id": object_id})
|
|
146
246
|
return cast(dict[str, Any] | None, result.get("object"))
|
|
147
247
|
|
|
148
248
|
def ping(self) -> bool:
|
|
249
|
+
if self._local_server() is not None:
|
|
250
|
+
return True
|
|
149
251
|
try:
|
|
150
252
|
self._request("ping")
|
|
151
253
|
except (SharedTensorClientError, SharedTensorRemoteError):
|
|
@@ -153,29 +255,66 @@ class SharedTensorClient:
|
|
|
153
255
|
return True
|
|
154
256
|
|
|
155
257
|
def get_server_info(self) -> dict[str, Any]:
|
|
258
|
+
local_server = self._local_server()
|
|
259
|
+
if local_server is not None:
|
|
260
|
+
return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
|
|
156
261
|
return cast(dict[str, Any], self._request("get_server_info"))
|
|
157
262
|
|
|
158
263
|
def list_endpoints(self) -> dict[str, Any]:
|
|
264
|
+
local_server = self._local_server()
|
|
265
|
+
if local_server is not None:
|
|
266
|
+
return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
|
|
159
267
|
return cast(dict[str, Any], self._request("list_endpoints"))
|
|
160
268
|
|
|
161
269
|
def get_task_status(self, task_id: str) -> dict[str, Any]:
|
|
270
|
+
local_server = self._local_server()
|
|
271
|
+
if local_server is not None:
|
|
272
|
+
return self._run_local(
|
|
273
|
+
lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
|
|
274
|
+
)
|
|
162
275
|
return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
|
|
163
276
|
|
|
164
277
|
def get_task_result(self, task_id: str) -> Any:
|
|
278
|
+
local_server = self._local_server()
|
|
279
|
+
if local_server is not None:
|
|
280
|
+
return self._run_local(
|
|
281
|
+
lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
|
|
282
|
+
)
|
|
165
283
|
return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
|
|
166
284
|
|
|
167
285
|
def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
|
|
168
286
|
if self.verbose_debug:
|
|
169
287
|
logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
|
|
288
|
+
local_server = self._local_server()
|
|
289
|
+
if local_server is not None:
|
|
290
|
+
return self._run_local(
|
|
291
|
+
lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
|
|
292
|
+
)
|
|
170
293
|
params = {"task_id": task_id}
|
|
171
294
|
if timeout is not None:
|
|
172
295
|
params["timeout"] = timeout
|
|
173
296
|
return cast(dict[str, Any], self._request("wait_task", params))
|
|
174
297
|
|
|
175
298
|
def cancel_task(self, task_id: str) -> bool:
|
|
299
|
+
local_server = self._local_server()
|
|
300
|
+
if local_server is not None:
|
|
301
|
+
return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
|
|
176
302
|
return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
|
|
177
303
|
|
|
178
304
|
def list_tasks(self, status: str | None = None) -> dict[str, Any]:
|
|
305
|
+
local_server = self._local_server()
|
|
306
|
+
if local_server is not None:
|
|
307
|
+
return self._run_local(
|
|
308
|
+
lambda: cast(
|
|
309
|
+
dict[str, Any],
|
|
310
|
+
{
|
|
311
|
+
listed_task_id: info.to_dict()
|
|
312
|
+
for listed_task_id, info in local_server._task_manager_instance()
|
|
313
|
+
.list(status=None if status is None else TaskStatus(status))
|
|
314
|
+
.items()
|
|
315
|
+
},
|
|
316
|
+
)
|
|
317
|
+
)
|
|
179
318
|
params = {"status": status} if status else None
|
|
180
319
|
return cast(dict[str, Any], self._request("list_tasks", params))
|
|
181
320
|
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""In-process runtime registry for thread-backed local servers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from threading import RLock
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from shared_tensor.server import SharedTensorServer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
_LOCK = RLock()
|
|
13
|
+
_SERVERS: dict[str, "SharedTensorServer"] = {}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def register_local_server(socket_path: str, server: "SharedTensorServer") -> None:
|
|
17
|
+
with _LOCK:
|
|
18
|
+
_SERVERS[socket_path] = server
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def unregister_local_server(socket_path: str, server: "SharedTensorServer") -> None:
|
|
22
|
+
with _LOCK:
|
|
23
|
+
current = _SERVERS.get(socket_path)
|
|
24
|
+
if current is server:
|
|
25
|
+
_SERVERS.pop(socket_path, None)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_local_server(socket_path: str) -> "SharedTensorServer | None":
|
|
29
|
+
with _LOCK:
|
|
30
|
+
return _SERVERS.get(socket_path)
|
|
@@ -22,6 +22,7 @@ from shared_tensor.errors import (
|
|
|
22
22
|
)
|
|
23
23
|
from shared_tensor.managed_object import ManagedObjectRegistry
|
|
24
24
|
from shared_tensor.provider import EndpointDefinition, SharedTensorProvider
|
|
25
|
+
from shared_tensor.runtime import register_local_server, unregister_local_server
|
|
25
26
|
from shared_tensor.transport import recv_message, send_message
|
|
26
27
|
from shared_tensor.utils import (
|
|
27
28
|
CONTROL_ENCODING,
|
|
@@ -59,6 +60,12 @@ class _ServerThreadState:
|
|
|
59
60
|
error: BaseException | None = None
|
|
60
61
|
|
|
61
62
|
|
|
63
|
+
@dataclass(slots=True)
|
|
64
|
+
class _EndpointResult:
|
|
65
|
+
value: Any
|
|
66
|
+
object_id: str | None = None
|
|
67
|
+
|
|
68
|
+
|
|
62
69
|
class SharedTensorServer:
|
|
63
70
|
def __init__(
|
|
64
71
|
self,
|
|
@@ -197,22 +204,22 @@ class SharedTensorServer:
|
|
|
197
204
|
) -> Any:
|
|
198
205
|
return self._task_manager_instance().submit(
|
|
199
206
|
endpoint,
|
|
200
|
-
self.
|
|
207
|
+
self._execute_endpoint_result,
|
|
201
208
|
(endpoint, definition, args, kwargs),
|
|
202
209
|
{},
|
|
203
|
-
result_encoder=
|
|
210
|
+
result_encoder=self._encode_endpoint_result,
|
|
204
211
|
)
|
|
205
212
|
|
|
206
|
-
def
|
|
213
|
+
def _execute_endpoint_result(
|
|
207
214
|
self,
|
|
208
215
|
endpoint: str,
|
|
209
216
|
definition: EndpointDefinition,
|
|
210
217
|
args: tuple[Any, ...],
|
|
211
218
|
kwargs: dict[str, Any],
|
|
212
|
-
) ->
|
|
219
|
+
) -> _EndpointResult:
|
|
213
220
|
cache_key = self._cache_key(endpoint, definition, args, kwargs)
|
|
214
221
|
if cache_key is not None:
|
|
215
|
-
cached = self.
|
|
222
|
+
cached = self._lookup_cached_result_value(definition, cache_key)
|
|
216
223
|
if cached is not None:
|
|
217
224
|
if self.verbose_debug:
|
|
218
225
|
logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
|
|
@@ -224,20 +231,15 @@ class SharedTensorServer:
|
|
|
224
231
|
if self.verbose_debug and owner:
|
|
225
232
|
logger.debug("Server created singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
|
|
226
233
|
if not owner:
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
object_id = payload.get("object_id")
|
|
232
|
-
if object_id is not None:
|
|
233
|
-
self._managed_objects.add_ref(object_id)
|
|
234
|
-
return payload
|
|
235
|
-
return future.result()
|
|
234
|
+
result = future.result()
|
|
235
|
+
if definition.managed and result.object_id is not None:
|
|
236
|
+
self._managed_objects.add_ref(result.object_id)
|
|
237
|
+
return result
|
|
236
238
|
else:
|
|
237
239
|
future = None
|
|
238
240
|
|
|
239
241
|
try:
|
|
240
|
-
|
|
242
|
+
result = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
|
|
241
243
|
except Exception as exc:
|
|
242
244
|
if future is not None:
|
|
243
245
|
future.set_exception(exc)
|
|
@@ -245,9 +247,20 @@ class SharedTensorServer:
|
|
|
245
247
|
raise
|
|
246
248
|
|
|
247
249
|
if future is not None:
|
|
248
|
-
future.set_result(
|
|
250
|
+
future.set_result(result)
|
|
249
251
|
self._release_inflight(inflight_key, future)
|
|
250
|
-
return
|
|
252
|
+
return result
|
|
253
|
+
|
|
254
|
+
def _execute_endpoint_call(
|
|
255
|
+
self,
|
|
256
|
+
endpoint: str,
|
|
257
|
+
definition: EndpointDefinition,
|
|
258
|
+
args: tuple[Any, ...],
|
|
259
|
+
kwargs: dict[str, Any],
|
|
260
|
+
) -> dict[str, Any]:
|
|
261
|
+
return self._encode_endpoint_result(
|
|
262
|
+
self._execute_endpoint_result(endpoint, definition, args, kwargs)
|
|
263
|
+
)
|
|
251
264
|
|
|
252
265
|
def _run_endpoint_under_policy(
|
|
253
266
|
self,
|
|
@@ -256,11 +269,11 @@ class SharedTensorServer:
|
|
|
256
269
|
args: tuple[Any, ...],
|
|
257
270
|
kwargs: dict[str, Any],
|
|
258
271
|
cache_key: str | None,
|
|
259
|
-
) ->
|
|
272
|
+
) -> _EndpointResult:
|
|
260
273
|
if definition.concurrency == "serialized":
|
|
261
274
|
lock = self._endpoint_lock(endpoint)
|
|
262
275
|
with lock:
|
|
263
|
-
cached = self.
|
|
276
|
+
cached = self._lookup_cached_result_value(definition, cache_key)
|
|
264
277
|
if cached is not None:
|
|
265
278
|
return cached
|
|
266
279
|
return self._materialize_endpoint_result(endpoint, definition, args, kwargs, cache_key)
|
|
@@ -273,17 +286,15 @@ class SharedTensorServer:
|
|
|
273
286
|
args: tuple[Any, ...],
|
|
274
287
|
kwargs: dict[str, Any],
|
|
275
288
|
cache_key: str | None,
|
|
276
|
-
) ->
|
|
289
|
+
) -> _EndpointResult:
|
|
277
290
|
if definition.managed:
|
|
278
291
|
return self._materialize_managed_result(endpoint, definition, args, kwargs, cache_key)
|
|
279
292
|
value = definition.func(*args, **kwargs)
|
|
280
293
|
if self.verbose_debug:
|
|
281
294
|
logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
|
|
282
|
-
result = self._encode_result(value)
|
|
283
295
|
if cache_key is not None:
|
|
284
|
-
self._cache[cache_key] = result
|
|
285
296
|
self._local_cache[cache_key] = value
|
|
286
|
-
return
|
|
297
|
+
return _EndpointResult(value=value)
|
|
287
298
|
|
|
288
299
|
def _materialize_managed_result(
|
|
289
300
|
self,
|
|
@@ -292,24 +303,24 @@ class SharedTensorServer:
|
|
|
292
303
|
args: tuple[Any, ...],
|
|
293
304
|
kwargs: dict[str, Any],
|
|
294
305
|
cache_key: str | None,
|
|
295
|
-
) ->
|
|
306
|
+
) -> _EndpointResult:
|
|
296
307
|
if cache_key is not None:
|
|
297
308
|
cached = self._managed_objects.get_cached(cache_key)
|
|
298
309
|
if cached is not None:
|
|
299
310
|
self._managed_objects.add_ref(cached.object_id)
|
|
300
|
-
return
|
|
311
|
+
return _EndpointResult(value=cached.value, object_id=cached.object_id)
|
|
301
312
|
|
|
302
313
|
result = definition.func(*args, **kwargs)
|
|
303
314
|
if self.verbose_debug:
|
|
304
315
|
logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
|
|
305
316
|
entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
|
|
306
|
-
return
|
|
317
|
+
return _EndpointResult(value=entry.value, object_id=entry.object_id)
|
|
307
318
|
|
|
308
|
-
def
|
|
319
|
+
def _lookup_cached_result_value(
|
|
309
320
|
self,
|
|
310
321
|
definition: EndpointDefinition,
|
|
311
322
|
cache_key: str | None,
|
|
312
|
-
) ->
|
|
323
|
+
) -> _EndpointResult | None:
|
|
313
324
|
if cache_key is None:
|
|
314
325
|
return None
|
|
315
326
|
if definition.managed:
|
|
@@ -317,16 +328,52 @@ class SharedTensorServer:
|
|
|
317
328
|
if cached is None:
|
|
318
329
|
return None
|
|
319
330
|
self._managed_objects.add_ref(cached.object_id)
|
|
320
|
-
return
|
|
321
|
-
cached = self._cache.get(cache_key)
|
|
322
|
-
if cached is not None:
|
|
323
|
-
return cached
|
|
331
|
+
return _EndpointResult(value=cached.value, object_id=cached.object_id)
|
|
324
332
|
local_value = self._local_cache.get(cache_key)
|
|
325
333
|
if local_value is None:
|
|
326
334
|
return None
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
335
|
+
return _EndpointResult(value=local_value)
|
|
336
|
+
|
|
337
|
+
def call_local_client(
|
|
338
|
+
self,
|
|
339
|
+
endpoint: str,
|
|
340
|
+
*,
|
|
341
|
+
args: tuple[Any, ...] = (),
|
|
342
|
+
kwargs: dict[str, Any] | None = None,
|
|
343
|
+
) -> _EndpointResult | None:
|
|
344
|
+
definition = self.provider.get_endpoint(endpoint)
|
|
345
|
+
resolved_kwargs = kwargs or {}
|
|
346
|
+
if definition.execution == "task":
|
|
347
|
+
task_info = self._submit_endpoint_task(endpoint, definition, args, resolved_kwargs)
|
|
348
|
+
return self.wait_task_result_local(task_info.task_id)
|
|
349
|
+
return self._execute_endpoint_result(endpoint, definition, args, resolved_kwargs)
|
|
350
|
+
|
|
351
|
+
def get_task_result_local(self, task_id: str) -> _EndpointResult | None:
|
|
352
|
+
result = self._task_manager_instance().result_local(task_id)
|
|
353
|
+
if result is None:
|
|
354
|
+
return None
|
|
355
|
+
return result
|
|
356
|
+
|
|
357
|
+
def wait_task_result_local(self, task_id: str, timeout: float | None = None) -> _EndpointResult | None:
|
|
358
|
+
result = self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
|
|
359
|
+
if result is None:
|
|
360
|
+
return None
|
|
361
|
+
return result
|
|
362
|
+
|
|
363
|
+
def wait_task_local(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
|
|
364
|
+
try:
|
|
365
|
+
self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
|
|
366
|
+
except SharedTensorTaskError:
|
|
367
|
+
info = self._task_manager_instance().get(task_id)
|
|
368
|
+
if info.status in {TaskStatus.PENDING, TaskStatus.RUNNING}:
|
|
369
|
+
return info.to_dict()
|
|
370
|
+
raise
|
|
371
|
+
return self._task_manager_instance().get(task_id).to_dict()
|
|
372
|
+
|
|
373
|
+
def encode_local_result(self, result: _EndpointResult | None) -> dict[str, Any]:
|
|
374
|
+
if result is None:
|
|
375
|
+
return {"encoding": None, "payload_bytes": None, "object_id": None}
|
|
376
|
+
return self._encode_endpoint_result(result)
|
|
330
377
|
|
|
331
378
|
def invoke_local(
|
|
332
379
|
self,
|
|
@@ -455,6 +502,9 @@ class SharedTensorServer:
|
|
|
455
502
|
encoding, payload = serialize_payload(value)
|
|
456
503
|
return {"encoding": encoding, "payload_bytes": payload, "object_id": object_id}
|
|
457
504
|
|
|
505
|
+
def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
|
|
506
|
+
return self._encode_result(result.value, object_id=result.object_id)
|
|
507
|
+
|
|
458
508
|
def _task_manager_instance(self) -> TaskManager:
|
|
459
509
|
if self._task_manager is None:
|
|
460
510
|
self._task_manager = TaskManager(
|
|
@@ -560,6 +610,7 @@ class SharedTensorServer:
|
|
|
560
610
|
self.listener = listener
|
|
561
611
|
self.running = True
|
|
562
612
|
self.started_at = time.time()
|
|
613
|
+
register_local_server(self.socket_path, self)
|
|
563
614
|
if started_event is not None:
|
|
564
615
|
started_event.set()
|
|
565
616
|
while self.running:
|
|
@@ -632,6 +683,7 @@ class SharedTensorServer:
|
|
|
632
683
|
self._local_cache.clear()
|
|
633
684
|
self._inflight.clear()
|
|
634
685
|
self._endpoint_locks.clear()
|
|
686
|
+
unregister_local_server(self.socket_path, self)
|
|
635
687
|
unlink_socket_path(self.socket_path)
|
|
636
688
|
|
|
637
689
|
def __enter__(self) -> SharedTensorServer:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|