shared-tensor 0.2.7__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/PKG-INFO +13 -3
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/README.md +11 -1
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/pyproject.toml +2 -2
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/__init__.py +1 -1
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/async_task.py +44 -14
- shared_tensor-0.2.8/shared_tensor/managed_object.py +135 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/provider.py +41 -28
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/runtime.py +7 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/server.py +53 -23
- shared_tensor-0.2.7/shared_tensor/managed_object.py +0 -126
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/LICENSE +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/MANIFEST.in +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/setup.cfg +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/async_client.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/async_provider.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/client.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/errors.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/transport.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor/utils.py +0 -0
- {shared_tensor-0.2.7 → shared_tensor-0.2.8}/shared_tensor.egg-info/SOURCES.txt +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shared-tensor
|
|
3
|
-
Version: 0.2.7
|
|
3
|
+
Version: 0.2.8
|
|
4
4
|
Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
|
|
5
5
|
Author-email: Athena Team <contact@world-sim-dev.org>
|
|
6
6
|
Maintainer-email: Athena Team <contact@world-sim-dev.org>
|
|
7
7
|
License-Expression: Apache-2.0
|
|
8
8
|
Project-URL: Homepage, https://github.com/world-sim-dev/shared-tensor
|
|
9
9
|
Project-URL: Repository, https://github.com/world-sim-dev/shared-tensor
|
|
10
|
-
Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/
|
|
10
|
+
Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/tree/main/docs
|
|
11
11
|
Project-URL: Bug Reports, https://github.com/world-sim-dev/shared-tensor/issues
|
|
12
12
|
Project-URL: Changelog, https://github.com/world-sim-dev/shared-tensor/releases
|
|
13
13
|
Keywords: gpu,memory,sharing,ipc,inter-process-communication,pytorch,cuda,model-serving,inference,torch,torch-ipc
|
|
@@ -77,7 +77,7 @@ Not supported:
|
|
|
77
77
|
|
|
78
78
|
## Install
|
|
79
79
|
|
|
80
|
-
Use Python `3.
|
|
80
|
+
Use Python `3.9+` and a CUDA-enabled PyTorch build.
|
|
81
81
|
|
|
82
82
|
```bash
|
|
83
83
|
pip install shared-tensor
|
|
@@ -91,6 +91,16 @@ conda activate shared-tensor-dev
|
|
|
91
91
|
pip install -e ".[dev,test]"
|
|
92
92
|
```
|
|
93
93
|
|
|
94
|
+
## Docs
|
|
95
|
+
|
|
96
|
+
Read the examples first, then the design notes:
|
|
97
|
+
|
|
98
|
+
- `docs/overview.md`
|
|
99
|
+
- `docs/patterns.md`
|
|
100
|
+
- `docs/architecture.md`
|
|
101
|
+
- `docs/lifecycle.md`
|
|
102
|
+
- `docs/diagrams.md`
|
|
103
|
+
|
|
94
104
|
## Example: Manual Two-Process Deployment
|
|
95
105
|
|
|
96
106
|
Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
|
|
@@ -25,7 +25,7 @@ Not supported:
|
|
|
25
25
|
|
|
26
26
|
## Install
|
|
27
27
|
|
|
28
|
-
Use Python `3.
|
|
28
|
+
Use Python `3.9+` and a CUDA-enabled PyTorch build.
|
|
29
29
|
|
|
30
30
|
```bash
|
|
31
31
|
pip install shared-tensor
|
|
@@ -39,6 +39,16 @@ conda activate shared-tensor-dev
|
|
|
39
39
|
pip install -e ".[dev,test]"
|
|
40
40
|
```
|
|
41
41
|
|
|
42
|
+
## Docs
|
|
43
|
+
|
|
44
|
+
Read the examples first, then the design notes:
|
|
45
|
+
|
|
46
|
+
- `docs/overview.md`
|
|
47
|
+
- `docs/patterns.md`
|
|
48
|
+
- `docs/architecture.md`
|
|
49
|
+
- `docs/lifecycle.md`
|
|
50
|
+
- `docs/diagrams.md`
|
|
51
|
+
|
|
42
52
|
## Example: Manual Two-Process Deployment
|
|
43
53
|
|
|
44
54
|
Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shared-tensor"
|
|
7
|
-
version = "0.2.7"
|
|
7
|
+
version = "0.2.8"
|
|
8
8
|
description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -75,7 +75,7 @@ docs = [
|
|
|
75
75
|
[project.urls]
|
|
76
76
|
Homepage = "https://github.com/world-sim-dev/shared-tensor"
|
|
77
77
|
Repository = "https://github.com/world-sim-dev/shared-tensor"
|
|
78
|
-
Documentation = "https://github.com/world-sim-dev/shared-tensor/
|
|
78
|
+
Documentation = "https://github.com/world-sim-dev/shared-tensor/tree/main/docs"
|
|
79
79
|
"Bug Reports" = "https://github.com/world-sim-dev/shared-tensor/issues"
|
|
80
80
|
Changelog = "https://github.com/world-sim-dev/shared-tensor/releases"
|
|
81
81
|
|
|
@@ -33,8 +33,6 @@ class TaskInfo:
|
|
|
33
33
|
created_at: float
|
|
34
34
|
started_at: float | None = None
|
|
35
35
|
completed_at: float | None = None
|
|
36
|
-
result_encoding: str | None = None
|
|
37
|
-
result_payload: bytes | None = None
|
|
38
36
|
error_type: str | None = None
|
|
39
37
|
error_message: str | None = None
|
|
40
38
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|
@@ -47,8 +45,6 @@ class TaskInfo:
|
|
|
47
45
|
"created_at": self.created_at,
|
|
48
46
|
"started_at": self.started_at,
|
|
49
47
|
"completed_at": self.completed_at,
|
|
50
|
-
"result_encoding": self.result_encoding,
|
|
51
|
-
"result_payload": self.result_payload,
|
|
52
48
|
"error_type": self.error_type,
|
|
53
49
|
"error_message": self.error_message,
|
|
54
50
|
"metadata": dict(self.metadata),
|
|
@@ -66,6 +62,8 @@ class TaskInfo:
|
|
|
66
62
|
class _TaskEntry:
|
|
67
63
|
info: TaskInfo
|
|
68
64
|
future: Future[Any]
|
|
65
|
+
result_encoding: str | None = None
|
|
66
|
+
result_payload: bytes | None = None
|
|
69
67
|
local_result: Any = None
|
|
70
68
|
|
|
71
69
|
|
|
@@ -87,6 +85,7 @@ class TaskManager:
|
|
|
87
85
|
self._last_cleanup = 0.0
|
|
88
86
|
self._lock = RLock()
|
|
89
87
|
self._tasks: dict[str, _TaskEntry] = {}
|
|
88
|
+
self._accepting_submissions = True
|
|
90
89
|
|
|
91
90
|
def submit(
|
|
92
91
|
self,
|
|
@@ -98,6 +97,8 @@ class TaskManager:
|
|
|
98
97
|
) -> TaskInfo:
|
|
99
98
|
self._maybe_cleanup()
|
|
100
99
|
with self._lock:
|
|
100
|
+
if not self._accepting_submissions:
|
|
101
|
+
raise SharedTensorTaskError("Task manager is shutting down and is not accepting new tasks")
|
|
101
102
|
self._drop_oldest_finished_tasks_if_needed()
|
|
102
103
|
if len(self._tasks) >= self._max_tasks:
|
|
103
104
|
raise SharedTensorTaskError("Task capacity exceeded")
|
|
@@ -143,12 +144,11 @@ class TaskManager:
|
|
|
143
144
|
self._store_local_result(task_id, result)
|
|
144
145
|
|
|
145
146
|
if result is None:
|
|
147
|
+
self._store_payload(task_id, encoding=None, payload=None, object_id=None)
|
|
146
148
|
self._transition(
|
|
147
149
|
task_id,
|
|
148
150
|
status=TaskStatus.COMPLETED,
|
|
149
151
|
completed_at=time.time(),
|
|
150
|
-
result_encoding=None,
|
|
151
|
-
result_payload=None,
|
|
152
152
|
)
|
|
153
153
|
return
|
|
154
154
|
|
|
@@ -168,13 +168,16 @@ class TaskManager:
|
|
|
168
168
|
)
|
|
169
169
|
return
|
|
170
170
|
|
|
171
|
+
self._store_payload(
|
|
172
|
+
task_id,
|
|
173
|
+
encoding=payload["encoding"],
|
|
174
|
+
payload=payload["payload_bytes"],
|
|
175
|
+
object_id=payload.get("object_id"),
|
|
176
|
+
)
|
|
171
177
|
self._transition(
|
|
172
178
|
task_id,
|
|
173
179
|
status=TaskStatus.COMPLETED,
|
|
174
180
|
completed_at=time.time(),
|
|
175
|
-
result_encoding=payload["encoding"],
|
|
176
|
-
result_payload=payload["payload_bytes"],
|
|
177
|
-
metadata={"object_id": payload.get("object_id")},
|
|
178
181
|
)
|
|
179
182
|
|
|
180
183
|
@staticmethod
|
|
@@ -201,6 +204,24 @@ class TaskManager:
|
|
|
201
204
|
return
|
|
202
205
|
entry.local_result = value
|
|
203
206
|
|
|
207
|
+
def _store_payload(
|
|
208
|
+
self,
|
|
209
|
+
task_id: str,
|
|
210
|
+
*,
|
|
211
|
+
encoding: str | None,
|
|
212
|
+
payload: bytes | None,
|
|
213
|
+
object_id: str | None,
|
|
214
|
+
) -> None:
|
|
215
|
+
with self._lock:
|
|
216
|
+
entry = self._tasks.get(task_id)
|
|
217
|
+
if entry is None:
|
|
218
|
+
return
|
|
219
|
+
entry.result_encoding = encoding
|
|
220
|
+
entry.result_payload = payload
|
|
221
|
+
metadata = dict(entry.info.metadata)
|
|
222
|
+
metadata["object_id"] = object_id
|
|
223
|
+
entry.info.metadata = metadata
|
|
224
|
+
|
|
204
225
|
def get(self, task_id: str) -> TaskInfo:
|
|
205
226
|
self._maybe_cleanup()
|
|
206
227
|
with self._lock:
|
|
@@ -255,7 +276,14 @@ class TaskManager:
|
|
|
255
276
|
return self.result_payload(task_id)
|
|
256
277
|
|
|
257
278
|
def result_payload(self, task_id: str) -> dict[str, str | bytes | None]:
|
|
258
|
-
|
|
279
|
+
self._maybe_cleanup()
|
|
280
|
+
with self._lock:
|
|
281
|
+
entry = self._tasks.get(task_id)
|
|
282
|
+
if entry is None:
|
|
283
|
+
raise SharedTensorTaskError(f"Task '{task_id}' was not found")
|
|
284
|
+
info = copy.deepcopy(entry.info)
|
|
285
|
+
encoding = entry.result_encoding
|
|
286
|
+
payload = entry.result_payload
|
|
259
287
|
if info.status == TaskStatus.CANCELLED:
|
|
260
288
|
raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
|
|
261
289
|
if info.status == TaskStatus.FAILED:
|
|
@@ -265,8 +293,8 @@ class TaskManager:
|
|
|
265
293
|
f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
|
|
266
294
|
)
|
|
267
295
|
return {
|
|
268
|
-
"encoding":
|
|
269
|
-
"payload_bytes":
|
|
296
|
+
"encoding": encoding,
|
|
297
|
+
"payload_bytes": payload,
|
|
270
298
|
"object_id": info.metadata.get("object_id"),
|
|
271
299
|
}
|
|
272
300
|
|
|
@@ -309,8 +337,10 @@ class TaskManager:
|
|
|
309
337
|
}
|
|
310
338
|
return items
|
|
311
339
|
|
|
312
|
-
def shutdown(self, *, wait: bool = True) -> None:
|
|
313
|
-
self.
|
|
340
|
+
def shutdown(self, *, wait: bool = True, cancel_futures: bool = True) -> None:
|
|
341
|
+
with self._lock:
|
|
342
|
+
self._accepting_submissions = False
|
|
343
|
+
self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
|
|
314
344
|
|
|
315
345
|
def _maybe_cleanup(self) -> None:
|
|
316
346
|
now = time.time()
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Managed remote object handles and registry state."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from threading import RLock
|
|
8
|
+
from typing import Any, Generic, TypeVar
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(slots=True)
|
|
14
|
+
class ManagedObjectEntry:
|
|
15
|
+
object_id: str
|
|
16
|
+
value: Any
|
|
17
|
+
endpoint: str
|
|
18
|
+
cache_key: str | None
|
|
19
|
+
refcount: int = 1
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(slots=True)
|
|
23
|
+
class ManagedReleaseResult:
|
|
24
|
+
released: bool
|
|
25
|
+
destroyed: bool
|
|
26
|
+
refcount: int
|
|
27
|
+
cache_key: str | None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ManagedObjectRegistry:
|
|
31
|
+
def __init__(self) -> None:
|
|
32
|
+
self._entries: dict[str, ManagedObjectEntry] = {}
|
|
33
|
+
self._cache_index: dict[str, str] = {}
|
|
34
|
+
self._lock = RLock()
|
|
35
|
+
|
|
36
|
+
def get_cached(self, cache_key: str) -> ManagedObjectEntry | None:
|
|
37
|
+
with self._lock:
|
|
38
|
+
object_id = self._cache_index.get(cache_key)
|
|
39
|
+
if object_id is None:
|
|
40
|
+
return None
|
|
41
|
+
entry = self._entries.get(object_id)
|
|
42
|
+
if entry is None:
|
|
43
|
+
self._cache_index.pop(cache_key, None)
|
|
44
|
+
return None
|
|
45
|
+
return entry
|
|
46
|
+
|
|
47
|
+
def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
|
|
48
|
+
with self._lock:
|
|
49
|
+
entry = ManagedObjectEntry(
|
|
50
|
+
object_id=uuid.uuid4().hex,
|
|
51
|
+
value=value,
|
|
52
|
+
endpoint=endpoint,
|
|
53
|
+
cache_key=cache_key,
|
|
54
|
+
)
|
|
55
|
+
self._entries[entry.object_id] = entry
|
|
56
|
+
if cache_key is not None:
|
|
57
|
+
self._cache_index[cache_key] = entry.object_id
|
|
58
|
+
return entry
|
|
59
|
+
|
|
60
|
+
def get(self, object_id: str) -> ManagedObjectEntry | None:
|
|
61
|
+
with self._lock:
|
|
62
|
+
return self._entries.get(object_id)
|
|
63
|
+
|
|
64
|
+
def add_ref(self, object_id: str) -> ManagedObjectEntry | None:
|
|
65
|
+
with self._lock:
|
|
66
|
+
entry = self._entries.get(object_id)
|
|
67
|
+
if entry is None:
|
|
68
|
+
return None
|
|
69
|
+
entry.refcount += 1
|
|
70
|
+
return entry
|
|
71
|
+
|
|
72
|
+
def release(self, object_id: str) -> ManagedReleaseResult:
|
|
73
|
+
with self._lock:
|
|
74
|
+
entry = self._entries.get(object_id)
|
|
75
|
+
if entry is None:
|
|
76
|
+
return ManagedReleaseResult(released=False, destroyed=False, refcount=0, cache_key=None)
|
|
77
|
+
|
|
78
|
+
entry.refcount -= 1
|
|
79
|
+
destroyed = entry.refcount <= 0
|
|
80
|
+
cache_key = entry.cache_key
|
|
81
|
+
refcount = max(entry.refcount, 0)
|
|
82
|
+
if destroyed:
|
|
83
|
+
self._entries.pop(object_id, None)
|
|
84
|
+
if cache_key is not None and self._cache_index.get(cache_key) == object_id:
|
|
85
|
+
self._cache_index.pop(cache_key, None)
|
|
86
|
+
return ManagedReleaseResult(
|
|
87
|
+
released=True,
|
|
88
|
+
destroyed=destroyed,
|
|
89
|
+
refcount=refcount,
|
|
90
|
+
cache_key=cache_key,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
def info(self, object_id: str) -> dict[str, Any] | None:
|
|
94
|
+
with self._lock:
|
|
95
|
+
entry = self._entries.get(object_id)
|
|
96
|
+
if entry is None:
|
|
97
|
+
return None
|
|
98
|
+
return {
|
|
99
|
+
"object_id": entry.object_id,
|
|
100
|
+
"endpoint": entry.endpoint,
|
|
101
|
+
"cache_key": entry.cache_key,
|
|
102
|
+
"refcount": entry.refcount,
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
def clear(self) -> None:
|
|
106
|
+
with self._lock:
|
|
107
|
+
self._entries.clear()
|
|
108
|
+
self._cache_index.clear()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class ReleaseHandle:
|
|
112
|
+
def release(self) -> bool: # pragma: no cover - protocol surface only
|
|
113
|
+
raise NotImplementedError
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass(slots=True)
|
|
117
|
+
class SharedObjectHandle(Generic[T]):
|
|
118
|
+
object_id: str
|
|
119
|
+
value: T
|
|
120
|
+
_releaser: ReleaseHandle
|
|
121
|
+
released: bool = False
|
|
122
|
+
|
|
123
|
+
def release(self) -> bool:
|
|
124
|
+
if self.released:
|
|
125
|
+
return False
|
|
126
|
+
released = self._releaser.release()
|
|
127
|
+
if released:
|
|
128
|
+
self.released = True
|
|
129
|
+
return released
|
|
130
|
+
|
|
131
|
+
def __enter__(self) -> SharedObjectHandle[T]:
|
|
132
|
+
return self
|
|
133
|
+
|
|
134
|
+
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
|
|
135
|
+
self.release()
|
|
@@ -8,6 +8,7 @@ import os
|
|
|
8
8
|
from collections.abc import Callable
|
|
9
9
|
from dataclasses import dataclass
|
|
10
10
|
from functools import wraps
|
|
11
|
+
from threading import RLock
|
|
11
12
|
from typing import Any, Literal
|
|
12
13
|
|
|
13
14
|
from shared_tensor.errors import (
|
|
@@ -113,7 +114,9 @@ class SharedTensorProvider:
|
|
|
113
114
|
self._cache: dict[str, Any] = {}
|
|
114
115
|
self._endpoints: dict[str, EndpointDefinition] = {}
|
|
115
116
|
self._registered_functions = self._endpoints
|
|
116
|
-
|
|
117
|
+
self._lock = RLock()
|
|
118
|
+
self._atexit_registered = False
|
|
119
|
+
self._register_atexit_once()
|
|
117
120
|
|
|
118
121
|
def register(
|
|
119
122
|
self,
|
|
@@ -129,25 +132,26 @@ class SharedTensorProvider:
|
|
|
129
132
|
) -> Callable[..., Any]:
|
|
130
133
|
_validate_endpoint_options(execution=execution, concurrency=concurrency)
|
|
131
134
|
endpoint_name = func.__name__
|
|
132
|
-
|
|
133
|
-
|
|
135
|
+
with self._lock:
|
|
136
|
+
if endpoint_name in self._endpoints:
|
|
137
|
+
raise SharedTensorProviderError(f"Endpoint '{endpoint_name}' is already registered")
|
|
134
138
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
139
|
+
resolved_cache_format_key = (
|
|
140
|
+
func.__qualname__ if cache_format_key is None else cache_format_key
|
|
141
|
+
)
|
|
138
142
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
143
|
+
definition = EndpointDefinition(
|
|
144
|
+
name=endpoint_name,
|
|
145
|
+
func=func,
|
|
146
|
+
cache=cache,
|
|
147
|
+
cache_format_key=resolved_cache_format_key,
|
|
148
|
+
managed=managed,
|
|
149
|
+
async_default_wait=async_default_wait,
|
|
150
|
+
execution=execution,
|
|
151
|
+
concurrency=concurrency,
|
|
152
|
+
singleflight=singleflight,
|
|
153
|
+
)
|
|
154
|
+
self._endpoints[endpoint_name] = definition
|
|
151
155
|
if self.verbose_debug:
|
|
152
156
|
logger.debug(
|
|
153
157
|
"Provider registered endpoint",
|
|
@@ -161,7 +165,7 @@ class SharedTensorProvider:
|
|
|
161
165
|
)
|
|
162
166
|
|
|
163
167
|
if self._should_autostart_server():
|
|
164
|
-
self.
|
|
168
|
+
self._ensure_autostart_server()
|
|
165
169
|
|
|
166
170
|
@wraps(func)
|
|
167
171
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
@@ -272,11 +276,13 @@ class SharedTensorProvider:
|
|
|
272
276
|
return definition.func(*args, **resolved_kwargs)
|
|
273
277
|
|
|
274
278
|
cache_key = self._cache_key_for(endpoint, definition, args, resolved_kwargs)
|
|
275
|
-
|
|
276
|
-
|
|
279
|
+
with self._lock:
|
|
280
|
+
if cache_key in self._cache:
|
|
281
|
+
return self._cache[cache_key]
|
|
277
282
|
|
|
278
283
|
result = definition.func(*args, **resolved_kwargs)
|
|
279
|
-
self.
|
|
284
|
+
with self._lock:
|
|
285
|
+
self._cache[cache_key] = result
|
|
280
286
|
return result
|
|
281
287
|
|
|
282
288
|
def get_endpoint(self, endpoint: str) -> EndpointDefinition:
|
|
@@ -309,18 +315,19 @@ class SharedTensorProvider:
|
|
|
309
315
|
self._async_client.close()
|
|
310
316
|
self._async_client = None
|
|
311
317
|
if self._server is not None:
|
|
312
|
-
self._server.stop()
|
|
318
|
+
self._server.stop(wait_for_tasks=True)
|
|
313
319
|
self._server = None
|
|
314
320
|
|
|
315
321
|
def get_runtime_info(self) -> dict[str, Any]:
|
|
316
322
|
if self.execution_mode in {"server", "local"}:
|
|
323
|
+
server = self._server
|
|
317
324
|
return {
|
|
318
325
|
"execution_mode": self.execution_mode,
|
|
319
326
|
"auto_mode": self.auto_mode,
|
|
320
327
|
"base_path": self.base_path,
|
|
321
328
|
"device_index": self.device_index,
|
|
322
329
|
"server_socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
|
|
323
|
-
"server_running":
|
|
330
|
+
"server_running": bool(server is not None and getattr(server, "running", True)),
|
|
324
331
|
}
|
|
325
332
|
server_info = self._get_client().get_server_info()
|
|
326
333
|
return {
|
|
@@ -361,18 +368,18 @@ class SharedTensorProvider:
|
|
|
361
368
|
def _should_autostart_server(self) -> bool:
|
|
362
369
|
return self.auto_mode and self.execution_mode == "server"
|
|
363
370
|
|
|
364
|
-
def
|
|
371
|
+
def _ensure_autostart_server(self) -> None:
|
|
365
372
|
from shared_tensor.server import SharedTensorServer
|
|
366
373
|
|
|
374
|
+
if self._server is not None:
|
|
375
|
+
return
|
|
367
376
|
if self.verbose_debug:
|
|
368
377
|
logger.debug(
|
|
369
|
-
"Provider
|
|
378
|
+
"Provider starting autostart server",
|
|
370
379
|
extra={
|
|
371
380
|
"socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
|
|
372
381
|
},
|
|
373
382
|
)
|
|
374
|
-
if self._server is not None:
|
|
375
|
-
self._server.stop()
|
|
376
383
|
self._server = SharedTensorServer(
|
|
377
384
|
self,
|
|
378
385
|
socket_path=resolve_runtime_socket_path(self.base_path, self.device_index),
|
|
@@ -395,3 +402,9 @@ class SharedTensorProvider:
|
|
|
395
402
|
func=definition.func,
|
|
396
403
|
cache_format_key=definition.cache_format_key,
|
|
397
404
|
)
|
|
405
|
+
|
|
406
|
+
def _register_atexit_once(self) -> None:
|
|
407
|
+
if self._atexit_registered:
|
|
408
|
+
return
|
|
409
|
+
atexit.register(self.close)
|
|
410
|
+
self._atexit_registered = True
|
|
@@ -5,6 +5,8 @@ from __future__ import annotations
|
|
|
5
5
|
from threading import RLock
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
7
|
|
|
8
|
+
from shared_tensor.errors import SharedTensorConfigurationError
|
|
9
|
+
|
|
8
10
|
if TYPE_CHECKING:
|
|
9
11
|
from shared_tensor.server import SharedTensorServer
|
|
10
12
|
|
|
@@ -15,6 +17,11 @@ _SERVERS: dict[str, "SharedTensorServer"] = {}
|
|
|
15
17
|
|
|
16
18
|
def register_local_server(socket_path: str, server: "SharedTensorServer") -> None:
|
|
17
19
|
with _LOCK:
|
|
20
|
+
current = _SERVERS.get(socket_path)
|
|
21
|
+
if current is not None and current is not server:
|
|
22
|
+
raise SharedTensorConfigurationError(
|
|
23
|
+
f"Local runtime socket '{socket_path}' is already registered by another server"
|
|
24
|
+
)
|
|
18
25
|
_SERVERS[socket_path] = server
|
|
19
26
|
|
|
20
27
|
|
|
@@ -39,6 +39,24 @@ from shared_tensor.utils import (
|
|
|
39
39
|
logger = logging.getLogger(__name__)
|
|
40
40
|
|
|
41
41
|
|
|
42
|
+
class _ConnectionExecutor:
|
|
43
|
+
def __init__(self, *, max_workers: int) -> None:
|
|
44
|
+
self._semaphore = threading.BoundedSemaphore(max_workers)
|
|
45
|
+
|
|
46
|
+
def submit(self, func, *args, **kwargs) -> threading.Thread:
|
|
47
|
+
self._semaphore.acquire()
|
|
48
|
+
|
|
49
|
+
def runner() -> None:
|
|
50
|
+
try:
|
|
51
|
+
func(*args, **kwargs)
|
|
52
|
+
finally:
|
|
53
|
+
self._semaphore.release()
|
|
54
|
+
|
|
55
|
+
thread = threading.Thread(target=runner, daemon=True)
|
|
56
|
+
thread.start()
|
|
57
|
+
return thread
|
|
58
|
+
|
|
59
|
+
|
|
42
60
|
def _server_version() -> str:
|
|
43
61
|
try:
|
|
44
62
|
from shared_tensor import __version__
|
|
@@ -49,7 +67,7 @@ def _server_version() -> str:
|
|
|
49
67
|
|
|
50
68
|
@dataclass(slots=True)
|
|
51
69
|
class _InFlightCall:
|
|
52
|
-
future: Future
|
|
70
|
+
future: Future
|
|
53
71
|
|
|
54
72
|
|
|
55
73
|
@dataclass(slots=True)
|
|
@@ -107,6 +125,8 @@ class SharedTensorServer:
|
|
|
107
125
|
self._inflight: dict[str, _InFlightCall] = {}
|
|
108
126
|
self._endpoint_locks: dict[str, threading.Lock] = {}
|
|
109
127
|
self._coordination_lock = threading.RLock()
|
|
128
|
+
self._connection_executor = _ConnectionExecutor(max_workers=max_workers)
|
|
129
|
+
self._accepting_requests = True
|
|
110
130
|
if getattr(self.provider, "_server", None) is None:
|
|
111
131
|
self.provider._server = self
|
|
112
132
|
|
|
@@ -114,6 +134,8 @@ class SharedTensorServer:
|
|
|
114
134
|
if self.verbose_debug:
|
|
115
135
|
logger.debug("Server processing request", extra={"method": request.get("method")})
|
|
116
136
|
try:
|
|
137
|
+
if not self._accepting_requests:
|
|
138
|
+
raise SharedTensorConfigurationError("Server is stopping and not accepting new requests")
|
|
117
139
|
method = request.get("method")
|
|
118
140
|
if not isinstance(method, str) or not method:
|
|
119
141
|
raise SharedTensorProtocolError("Missing required field 'method'")
|
|
@@ -293,7 +315,8 @@ class SharedTensorServer:
|
|
|
293
315
|
if self.verbose_debug:
|
|
294
316
|
logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
|
|
295
317
|
if cache_key is not None:
|
|
296
|
-
self.
|
|
318
|
+
with self._coordination_lock:
|
|
319
|
+
self._local_cache[cache_key] = value
|
|
297
320
|
return _EndpointResult(value=value)
|
|
298
321
|
|
|
299
322
|
def _materialize_managed_result(
|
|
@@ -329,7 +352,8 @@ class SharedTensorServer:
|
|
|
329
352
|
return None
|
|
330
353
|
self._managed_objects.add_ref(cached.object_id)
|
|
331
354
|
return _EndpointResult(value=cached.value, object_id=cached.object_id)
|
|
332
|
-
|
|
355
|
+
with self._coordination_lock:
|
|
356
|
+
local_value = self._local_cache.get(cache_key)
|
|
333
357
|
if local_value is None:
|
|
334
358
|
return None
|
|
335
359
|
return _EndpointResult(value=local_value)
|
|
@@ -397,11 +421,14 @@ class SharedTensorServer:
|
|
|
397
421
|
return existing.value
|
|
398
422
|
self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
|
|
399
423
|
return value
|
|
400
|
-
if cache_key is not None
|
|
401
|
-
|
|
424
|
+
if cache_key is not None:
|
|
425
|
+
with self._coordination_lock:
|
|
426
|
+
if cache_key in self._local_cache:
|
|
427
|
+
return self._local_cache[cache_key]
|
|
402
428
|
value = definition.func(*args, **resolved_kwargs)
|
|
403
429
|
if cache_key is not None:
|
|
404
|
-
self.
|
|
430
|
+
with self._coordination_lock:
|
|
431
|
+
self._local_cache[cache_key] = value
|
|
405
432
|
return value
|
|
406
433
|
|
|
407
434
|
def _cache_key(
|
|
@@ -421,16 +448,16 @@ class SharedTensorServer:
|
|
|
421
448
|
cache_format_key=definition.cache_format_key,
|
|
422
449
|
)
|
|
423
450
|
|
|
424
|
-
def _acquire_inflight(self, inflight_key: str) -> tuple[Future
|
|
451
|
+
def _acquire_inflight(self, inflight_key: str) -> tuple[Future, bool]:
|
|
425
452
|
with self._coordination_lock:
|
|
426
453
|
inflight = self._inflight.get(inflight_key)
|
|
427
454
|
if inflight is not None:
|
|
428
455
|
return inflight.future, False
|
|
429
|
-
future
|
|
456
|
+
future = Future()
|
|
430
457
|
self._inflight[inflight_key] = _InFlightCall(future=future)
|
|
431
458
|
return future, True
|
|
432
459
|
|
|
433
|
-
def _release_inflight(self, inflight_key: str | None, future: Future
|
|
460
|
+
def _release_inflight(self, inflight_key: str | None, future: Future) -> None:
|
|
434
461
|
if inflight_key is None:
|
|
435
462
|
return
|
|
436
463
|
with self._coordination_lock:
|
|
@@ -535,7 +562,7 @@ class SharedTensorServer:
|
|
|
535
562
|
"socket_path": self.socket_path,
|
|
536
563
|
"uptime": uptime,
|
|
537
564
|
"running": self.running,
|
|
538
|
-
"ready": self.running and self.listener is not None,
|
|
565
|
+
"ready": self.running and self.listener is not None and self._accepting_requests,
|
|
539
566
|
"pid": os.getpid(),
|
|
540
567
|
"ppid": os.getppid(),
|
|
541
568
|
"device_index": resolve_device_index(self.provider.device_index),
|
|
@@ -558,6 +585,7 @@ class SharedTensorServer:
|
|
|
558
585
|
logger.info("Server starting", extra={"socket_path": self.socket_path, "blocking": blocking})
|
|
559
586
|
if self.running or self.server_thread is not None:
|
|
560
587
|
raise SharedTensorConfigurationError("Server is already running")
|
|
588
|
+
self._accepting_requests = True
|
|
561
589
|
if blocking:
|
|
562
590
|
self._resolved_process_start_method = None
|
|
563
591
|
self._serve_forever()
|
|
@@ -576,11 +604,11 @@ class SharedTensorServer:
|
|
|
576
604
|
self._resolved_process_start_method = "thread"
|
|
577
605
|
thread.start()
|
|
578
606
|
if not state.ready.wait(timeout=self.startup_timeout):
|
|
579
|
-
self.stop()
|
|
607
|
+
self.stop(wait_for_tasks=False)
|
|
580
608
|
raise TimeoutError(f"Timed out waiting for server socket {self.socket_path}")
|
|
581
609
|
if state.error is not None:
|
|
582
610
|
error = state.error
|
|
583
|
-
self.stop()
|
|
611
|
+
self.stop(wait_for_tasks=False)
|
|
584
612
|
raise SharedTensorConfigurationError(
|
|
585
613
|
f"Failed to start background server thread for {self.socket_path}: {error}"
|
|
586
614
|
) from error
|
|
@@ -620,12 +648,11 @@ class SharedTensorServer:
|
|
|
620
648
|
if self.running:
|
|
621
649
|
raise
|
|
622
650
|
break
|
|
623
|
-
|
|
624
|
-
thread.start()
|
|
651
|
+
self._connection_executor.submit(self._handle_connection, conn)
|
|
625
652
|
finally:
|
|
626
653
|
if started_event is not None and not started_event.is_set():
|
|
627
654
|
started_event.set()
|
|
628
|
-
self._shutdown_local_resources()
|
|
655
|
+
self._shutdown_local_resources(wait_for_tasks=True)
|
|
629
656
|
|
|
630
657
|
def _handle_connection(self, conn: socket.socket) -> None:
|
|
631
658
|
with conn:
|
|
@@ -655,9 +682,10 @@ class SharedTensorServer:
|
|
|
655
682
|
if 0 <= local_rank < torch.cuda.device_count():
|
|
656
683
|
torch.cuda.set_device(local_rank)
|
|
657
684
|
|
|
658
|
-
def stop(self) -> None:
|
|
685
|
+
def stop(self, *, wait_for_tasks: bool = True) -> None:
|
|
659
686
|
if self.verbose_debug:
|
|
660
687
|
logger.info("Server stopping", extra={"socket_path": self.socket_path})
|
|
688
|
+
self._accepting_requests = False
|
|
661
689
|
self.running = False
|
|
662
690
|
if self.listener is not None:
|
|
663
691
|
self.listener.close()
|
|
@@ -668,21 +696,23 @@ class SharedTensorServer:
|
|
|
668
696
|
self.server_thread = None
|
|
669
697
|
self.server_process = None
|
|
670
698
|
if self.listener is None:
|
|
671
|
-
|
|
699
|
+
self._shutdown_local_resources(wait_for_tasks=wait_for_tasks)
|
|
672
700
|
|
|
673
|
-
def _shutdown_local_resources(self) -> None:
|
|
701
|
+
def _shutdown_local_resources(self, *, wait_for_tasks: bool) -> None:
|
|
702
|
+
self._accepting_requests = False
|
|
674
703
|
self.running = False
|
|
675
704
|
if self.listener is not None:
|
|
676
705
|
self.listener.close()
|
|
677
706
|
self.listener = None
|
|
678
707
|
if self._task_manager is not None:
|
|
679
|
-
self._task_manager.shutdown(wait=
|
|
708
|
+
self._task_manager.shutdown(wait=wait_for_tasks, cancel_futures=not wait_for_tasks)
|
|
680
709
|
self._task_manager = None
|
|
681
710
|
self._managed_objects.clear()
|
|
682
|
-
self.
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
711
|
+
with self._coordination_lock:
|
|
712
|
+
self._cache.clear()
|
|
713
|
+
self._local_cache.clear()
|
|
714
|
+
self._inflight.clear()
|
|
715
|
+
self._endpoint_locks.clear()
|
|
686
716
|
unregister_local_server(self.socket_path, self)
|
|
687
717
|
unlink_socket_path(self.socket_path)
|
|
688
718
|
|
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
"""Managed remote object handles and registry state."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import uuid
|
|
6
|
-
from dataclasses import dataclass
|
|
7
|
-
from typing import Any, Generic, TypeVar
|
|
8
|
-
|
|
9
|
-
T = TypeVar("T")
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@dataclass(slots=True)
|
|
13
|
-
class ManagedObjectEntry:
|
|
14
|
-
object_id: str
|
|
15
|
-
value: Any
|
|
16
|
-
endpoint: str
|
|
17
|
-
cache_key: str | None
|
|
18
|
-
refcount: int = 1
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@dataclass(slots=True)
|
|
22
|
-
class ManagedReleaseResult:
|
|
23
|
-
released: bool
|
|
24
|
-
destroyed: bool
|
|
25
|
-
refcount: int
|
|
26
|
-
cache_key: str | None
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class ManagedObjectRegistry:
|
|
30
|
-
def __init__(self) -> None:
|
|
31
|
-
self._entries: dict[str, ManagedObjectEntry] = {}
|
|
32
|
-
self._cache_index: dict[str, str] = {}
|
|
33
|
-
|
|
34
|
-
def get_cached(self, cache_key: str) -> ManagedObjectEntry | None:
|
|
35
|
-
object_id = self._cache_index.get(cache_key)
|
|
36
|
-
if object_id is None:
|
|
37
|
-
return None
|
|
38
|
-
entry = self._entries.get(object_id)
|
|
39
|
-
if entry is None:
|
|
40
|
-
self._cache_index.pop(cache_key, None)
|
|
41
|
-
return None
|
|
42
|
-
return entry
|
|
43
|
-
|
|
44
|
-
def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
|
|
45
|
-
entry = ManagedObjectEntry(
|
|
46
|
-
object_id=uuid.uuid4().hex,
|
|
47
|
-
value=value,
|
|
48
|
-
endpoint=endpoint,
|
|
49
|
-
cache_key=cache_key,
|
|
50
|
-
)
|
|
51
|
-
self._entries[entry.object_id] = entry
|
|
52
|
-
if cache_key is not None:
|
|
53
|
-
self._cache_index[cache_key] = entry.object_id
|
|
54
|
-
return entry
|
|
55
|
-
|
|
56
|
-
def get(self, object_id: str) -> ManagedObjectEntry | None:
|
|
57
|
-
return self._entries.get(object_id)
|
|
58
|
-
|
|
59
|
-
def add_ref(self, object_id: str) -> ManagedObjectEntry | None:
|
|
60
|
-
entry = self._entries.get(object_id)
|
|
61
|
-
if entry is None:
|
|
62
|
-
return None
|
|
63
|
-
entry.refcount += 1
|
|
64
|
-
return entry
|
|
65
|
-
|
|
66
|
-
def release(self, object_id: str) -> ManagedReleaseResult:
|
|
67
|
-
entry = self._entries.get(object_id)
|
|
68
|
-
if entry is None:
|
|
69
|
-
return ManagedReleaseResult(released=False, destroyed=False, refcount=0, cache_key=None)
|
|
70
|
-
|
|
71
|
-
entry.refcount -= 1
|
|
72
|
-
destroyed = entry.refcount <= 0
|
|
73
|
-
cache_key = entry.cache_key
|
|
74
|
-
refcount = max(entry.refcount, 0)
|
|
75
|
-
if destroyed:
|
|
76
|
-
self._entries.pop(object_id, None)
|
|
77
|
-
if cache_key is not None and self._cache_index.get(cache_key) == object_id:
|
|
78
|
-
self._cache_index.pop(cache_key, None)
|
|
79
|
-
return ManagedReleaseResult(
|
|
80
|
-
released=True,
|
|
81
|
-
destroyed=destroyed,
|
|
82
|
-
refcount=refcount,
|
|
83
|
-
cache_key=cache_key,
|
|
84
|
-
)
|
|
85
|
-
|
|
86
|
-
def info(self, object_id: str) -> dict[str, Any] | None:
|
|
87
|
-
entry = self._entries.get(object_id)
|
|
88
|
-
if entry is None:
|
|
89
|
-
return None
|
|
90
|
-
return {
|
|
91
|
-
"object_id": entry.object_id,
|
|
92
|
-
"endpoint": entry.endpoint,
|
|
93
|
-
"cache_key": entry.cache_key,
|
|
94
|
-
"refcount": entry.refcount,
|
|
95
|
-
}
|
|
96
|
-
|
|
97
|
-
def clear(self) -> None:
|
|
98
|
-
self._entries.clear()
|
|
99
|
-
self._cache_index.clear()
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class ReleaseHandle:
|
|
103
|
-
def release(self) -> bool: # pragma: no cover - protocol surface only
|
|
104
|
-
raise NotImplementedError
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@dataclass(slots=True)
|
|
108
|
-
class SharedObjectHandle(Generic[T]):
|
|
109
|
-
object_id: str
|
|
110
|
-
value: T
|
|
111
|
-
_releaser: ReleaseHandle
|
|
112
|
-
released: bool = False
|
|
113
|
-
|
|
114
|
-
def release(self) -> bool:
|
|
115
|
-
if self.released:
|
|
116
|
-
return False
|
|
117
|
-
released = self._releaser.release()
|
|
118
|
-
if released:
|
|
119
|
-
self.released = True
|
|
120
|
-
return released
|
|
121
|
-
|
|
122
|
-
def __enter__(self) -> SharedObjectHandle[T]:
|
|
123
|
-
return self
|
|
124
|
-
|
|
125
|
-
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
|
|
126
|
-
self.release()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|