shared-tensor 0.2.5__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/PKG-INFO +43 -28
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/README.md +38 -25
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/pyproject.toml +8 -6
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/__init__.py +1 -1
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/async_task.py +43 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/client.py +139 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/provider.py +5 -8
- shared_tensor-0.2.7/shared_tensor/runtime.py +30 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/server.py +196 -126
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor.egg-info/SOURCES.txt +1 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/LICENSE +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/MANIFEST.in +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/setup.cfg +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/async_client.py +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/async_provider.py +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/errors.py +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/managed_object.py +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/transport.py +0 -0
- {shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/utils.py +0 -0
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: shared-tensor
-Version: 0.2.5
+Version: 0.2.7
 Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
 Author-email: Athena Team <contact@world-sim-dev.org>
 Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
 Classifier: Operating System :: POSIX :: Linux
 Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
-Requires-Python:
+Requires-Python: <3.14,>=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: cloudpickle>=3.0.0
 Requires-Dist: numpy<2
-Requires-Dist: torch
+Requires-Dist: torch<2.8,>=2.1
 Provides-Extra: dev
 Requires-Dist: pytest>=6.0; extra == "dev"
 Requires-Dist: pytest-cov>=2.0; extra == "dev"
@@ -63,6 +65,7 @@ Supported:
 - sync `call` and task-backed `submit`
 - managed object handles with explicit release
 - server-side caching, `cache_format_key`, and singleflight
+- manual two-process deployment as the primary production path
 - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
 
 Not supported:
@@ -88,46 +91,58 @@ conda activate shared-tensor-dev
 pip install -e ".[dev,test]"
 ```
 
-## Example:
+## Example: Manual Two-Process Deployment
+
+Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
 
-See [examples/
+See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
+
+Server process:
 
 ```python
-import
+from shared_tensor import SharedTensorProvider, SharedTensorServer
 
-
+provider = SharedTensorProvider(execution_mode="server")
 
-provider =
+@provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
+def load_model(hidden_size: int = 4):
+    ...
 
-    execution="task",
-    managed=True,
-    concurrency="serialized",
-    cache_format_key="model:{hidden_size}",
-)
-def load_model(hidden_size: int = 4) -> torch.nn.Module:
-    return torch.nn.Linear(hidden_size, 2, device="cuda")
+server = SharedTensorServer(provider)
+server.start(blocking=True)
+```
 
+Client process:
 
+```python
+import torch
+
+from shared_tensor import SharedObjectHandle, SharedTensorClient
 
+client = SharedTensorClient()
 x = torch.ones(1, 4, device="cuda")
-result = load_model
+result = client.call("load_model", hidden_size=4)
 if isinstance(result, SharedObjectHandle):
     with result as handle:
         y = handle.value(x)
-else:
-    y = result(x)
 ```
 
-
+This keeps the contract explicit:
 
-```
-
+```text
+server process                     client process
+------------------------------     ------------------------------
+owns CUDA allocations               issues local UDS RPC requests
+executes endpoint functions         reopens CUDA objects via torch IPC
+manages cache and refcounts         releases managed handles explicitly
 ```
 
-
+## Example: Same Code, Two Processes
+
+See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
 
 ```bash
+SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
 SHARED_TENSOR_ENABLED=1 python demo.py
 ```
 
@@ -138,7 +153,7 @@ same code
 
 server process                     client process
 ------------------------------     ------------------------------
-provider auto-starts
+provider auto-starts local thread   provider builds client wrappers
 shared function runs locally        shared function becomes RPC call
 CUDA object stays on same GPU       CUDA object is reopened via torch IPC
 ```
@@ -201,19 +216,19 @@ SharedTensorProvider(enabled=None)
 Provider runtime controls:
 
 ```python
-SharedTensorProvider(server_process_start_method="fork")
 SharedTensorProvider(server_startup_timeout=30.0)
 provider.get_runtime_info()
 ```
 
-
-Leave it as `None` to let the library choose a safer default for the current entrypoint.
+Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
 
 `execution_mode="auto"` behaves as follows:
 - disabled: local mode
-- enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
+- enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
 - enabled + role unset: build client wrappers
 
+For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
+
 Socket selection is per CUDA device:
 - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
 - runtime socket path is `<base_path>-<device_index>.sock`
````
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/README.md

````diff
@@ -13,6 +13,7 @@ Supported:
 - sync `call` and task-backed `submit`
 - managed object handles with explicit release
 - server-side caching, `cache_format_key`, and singleflight
+- manual two-process deployment as the primary production path
 - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
 
 Not supported:
@@ -38,46 +39,58 @@ conda activate shared-tensor-dev
 pip install -e ".[dev,test]"
 ```
 
-## Example:
+## Example: Manual Two-Process Deployment
+
+Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
 
-See [examples/
+See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
+
+Server process:
 
 ```python
-import
+from shared_tensor import SharedTensorProvider, SharedTensorServer
 
-
+provider = SharedTensorProvider(execution_mode="server")
 
-provider =
+@provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
+def load_model(hidden_size: int = 4):
+    ...
 
-    execution="task",
-    managed=True,
-    concurrency="serialized",
-    cache_format_key="model:{hidden_size}",
-)
-def load_model(hidden_size: int = 4) -> torch.nn.Module:
-    return torch.nn.Linear(hidden_size, 2, device="cuda")
+server = SharedTensorServer(provider)
+server.start(blocking=True)
+```
 
+Client process:
 
+```python
+import torch
+
+from shared_tensor import SharedObjectHandle, SharedTensorClient
 
+client = SharedTensorClient()
 x = torch.ones(1, 4, device="cuda")
-result = load_model
+result = client.call("load_model", hidden_size=4)
 if isinstance(result, SharedObjectHandle):
     with result as handle:
         y = handle.value(x)
-else:
-    y = result(x)
 ```
 
-
+This keeps the contract explicit:
 
-```
-
+```text
+server process                     client process
+------------------------------     ------------------------------
+owns CUDA allocations               issues local UDS RPC requests
+executes endpoint functions         reopens CUDA objects via torch IPC
+manages cache and refcounts         releases managed handles explicitly
 ```
 
-
+## Example: Same Code, Two Processes
+
+See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
 
 ```bash
+SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
 SHARED_TENSOR_ENABLED=1 python demo.py
 ```
 
@@ -88,7 +101,7 @@ same code
 
 server process                     client process
 ------------------------------     ------------------------------
-provider auto-starts
+provider auto-starts local thread   provider builds client wrappers
 shared function runs locally        shared function becomes RPC call
 CUDA object stays on same GPU       CUDA object is reopened via torch IPC
 ```
@@ -151,19 +164,19 @@ SharedTensorProvider(enabled=None)
 Provider runtime controls:
 
 ```python
-SharedTensorProvider(server_process_start_method="fork")
 SharedTensorProvider(server_startup_timeout=30.0)
 provider.get_runtime_info()
 ```
 
-
-Leave it as `None` to let the library choose a safer default for the current entrypoint.
+Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
 
 `execution_mode="auto"` behaves as follows:
 - disabled: local mode
-- enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
+- enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
 - enabled + role unset: build client wrappers
 
+For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
+
 Socket selection is per CUDA device:
 - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
 - runtime socket path is `<base_path>-<device_index>.sock`
````
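The socket-selection rules at the end of the README are simple enough to restate as code. Below is a minimal sketch of that resolution, assuming only the environment variable and suffix convention documented above; `resolve_socket_path` is an illustrative name, not the library's actual `resolve_runtime_socket_path` helper.

```python
import os

def resolve_socket_path(base_path: str | None = None, device_index: int = 0) -> str:
    # Base path comes from SHARED_TENSOR_BASE_PATH, falling back to /tmp/shared-tensor.
    base = base_path or os.environ.get("SHARED_TENSOR_BASE_PATH", "/tmp/shared-tensor")
    # Runtime socket path is <base_path>-<device_index>.sock, one socket per CUDA device.
    return f"{base}-{device_index}.sock"

assert resolve_socket_path("/tmp/shared-tensor", 1) == "/tmp/shared-tensor-1.sock"
```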
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "shared-tensor"
-version = "0.2.5"
+version = "0.2.7"
 description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
 readme = "README.md"
 license = "Apache-2.0"
@@ -33,18 +33,20 @@ classifiers = [
     "Intended Audience :: Science/Research",
     "Operating System :: POSIX :: Linux",
     "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Topic :: System :: Distributed Computing",
 ]
-requires-python = ">=3.
+requires-python = ">=3.9,<3.14"
 dependencies = [
     "cloudpickle>=3.0.0",
     "numpy<2",
-    "torch>=2.2.
+    "torch>=2.1,<2.8",
 ]
 
 [project.optional-dependencies]
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
 
 [tool.black]
 line-length = 88
-target-version = ['py310', 'py311', 'py312']
+target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
 include = '\.pyi?$'
 extend-exclude = '''
 /(
@@ -115,7 +117,7 @@ use_parentheses = true
 ensure_newline_before_comments = true
 
 [tool.mypy]
-python_version = "3.
+python_version = "3.9"
 warn_return_any = true
 warn_unused_configs = true
 disallow_untyped_defs = true
@@ -180,7 +182,7 @@ exclude_lines = [
 ]
 
 [tool.ruff]
-target-version = "
+target-version = "py39"
 line-length = 88
 
 [tool.ruff.lint]
```
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/async_task.py

```diff
@@ -66,6 +66,7 @@ class TaskInfo:
 class _TaskEntry:
     info: TaskInfo
     future: Future[Any]
+    local_result: Any = None
 
 
 class TaskManager:
@@ -139,6 +140,8 @@ class TaskManager:
             )
             return
 
+        self._store_local_result(task_id, result)
+
         if result is None:
             self._transition(
                 task_id,
@@ -191,6 +194,13 @@ class TaskManager:
             for key, value in updates.items():
                 setattr(entry.info, key, value)
 
+    def _store_local_result(self, task_id: str, value: Any) -> None:
+        with self._lock:
+            entry = self._tasks.get(task_id)
+            if entry is None:
+                return
+            entry.local_result = value
+
     def get(self, task_id: str) -> TaskInfo:
         self._maybe_cleanup()
         with self._lock:
@@ -207,6 +217,24 @@ class TaskManager:
             return None
         return deserialize_payload(encoding, payload_bytes)
 
+    def result_local(self, task_id: str) -> Any:
+        self._maybe_cleanup()
+        with self._lock:
+            entry = self._tasks.get(task_id)
+            if entry is None:
+                raise SharedTensorTaskError(f"Task '{task_id}' was not found")
+            info = copy.deepcopy(entry.info)
+            value = entry.local_result
+        if info.status == TaskStatus.CANCELLED:
+            raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
+        if info.status == TaskStatus.FAILED:
+            raise SharedTensorTaskError(info.error_message or f"Task '{task_id}' failed")
+        if info.status != TaskStatus.COMPLETED:
+            raise SharedTensorTaskError(
+                f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
+            )
+        return value
+
     def wait_result_payload(
         self,
         task_id: str,
@@ -242,6 +270,21 @@ class TaskManager:
             "object_id": info.metadata.get("object_id"),
         }
 
+    def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
+        self._maybe_cleanup()
+        with self._lock:
+            entry = self._tasks.get(task_id)
+            if entry is None:
+                raise SharedTensorTaskError(f"Task '{task_id}' was not found")
+            future = entry.future
+        try:
+            future.result(timeout=timeout)
+        except FutureTimeoutError as exc:
+            raise SharedTensorTaskError(
+                f"Task '{task_id}' did not complete within {timeout} seconds"
+            ) from exc
+        return self.result_local(task_id)
+
     def cancel(self, task_id: str) -> bool:
         self._maybe_cleanup()
         with self._lock:
```
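The thread of changes above keeps each completed task's raw Python object alongside the serialized payload, so a same-process caller can fetch results without a deserialize round trip. Here is a self-contained sketch of that pattern, assuming nothing about `TaskManager` beyond what the diff shows; `LocalResultStore` is an illustrative stand-in, not a shared_tensor class.

```python
from concurrent.futures import Future
from threading import Lock
from typing import Any


class LocalResultStore:
    def __init__(self) -> None:
        self._lock = Lock()
        self._results: dict[str, Any] = {}

    def store(self, task_id: str, value: Any) -> None:
        # Mirrors _store_local_result: record the raw object under the lock.
        with self._lock:
            self._results[task_id] = value

    def wait(self, task_id: str, future: "Future[Any]", timeout: float | None = None) -> Any:
        # Mirrors wait_result_local: block on the future first so failures
        # and timeouts surface as exceptions, then hand back the raw object.
        future.result(timeout=timeout)
        with self._lock:
            return self._results[task_id]


store = LocalResultStore()
done: "Future[Any]" = Future()
store.store("task-1", {"weights": [1.0, 2.0]})
done.set_result(None)
print(store.wait("task-1", done))  # {'weights': [1.0, 2.0]}
```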
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/client.py

```diff
@@ -8,16 +8,25 @@ from dataclasses import dataclass
 from typing import Any, cast
 
 from shared_tensor.errors import (
+    SharedTensorCapabilityError,
     SharedTensorClientError,
+    SharedTensorConfigurationError,
+    SharedTensorError,
+    SharedTensorProviderError,
     SharedTensorProtocolError,
     SharedTensorRemoteError,
+    SharedTensorSerializationError,
+    SharedTensorTaskError,
 )
 from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
+from shared_tensor.runtime import get_local_server
 from shared_tensor.transport import recv_message, send_message
+from shared_tensor.async_task import TaskStatus
 from shared_tensor.utils import (
     deserialize_payload,
     resolve_runtime_socket_path,
     serialize_call_payloads,
+    validate_payload_for_transport,
 )
 
 
@@ -50,6 +59,54 @@ class SharedTensorClient:
         self.timeout = timeout
         self.verbose_debug = verbose_debug
 
+    def _local_server(self):
+        return get_local_server(self.socket_path)
+
+    @staticmethod
+    def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
+        if isinstance(exc, SharedTensorProtocolError):
+            code = 1
+        elif isinstance(exc, SharedTensorProviderError):
+            code = 2
+        elif isinstance(exc, SharedTensorSerializationError):
+            code = 3
+        elif isinstance(exc, SharedTensorCapabilityError):
+            code = 4
+        elif isinstance(exc, SharedTensorTaskError):
+            code = 5
+        elif isinstance(exc, SharedTensorConfigurationError):
+            code = 6
+        else:
+            code = 7
+        return SharedTensorRemoteError(
+            f"Remote error [{code}]: {exc}",
+            code=code,
+            data=None,
+            error_type=type(exc).__name__,
+        )
+
+    def _run_local(self, operation):
+        try:
+            return operation()
+        except SharedTensorError as exc:
+            raise self._remote_error_from_local(exc) from exc
+
+    def _decode_local_result(self, result: Any) -> Any:
+        if result is None:
+            return None
+        value = result.value
+        if value is None:
+            return None
+        validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
+        object_id = result.object_id
+        if object_id is None:
+            return value
+        return SharedObjectHandle(
+            object_id=cast(str, object_id),
+            value=value,
+            _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
+        )
+
     def _send_request(self, request: dict[str, Any]) -> Any:
         method = request.get("method", "<unknown>")
         if self.verbose_debug:
@@ -104,6 +161,13 @@ class SharedTensorClient:
     def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
         if self.verbose_debug:
             logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: self._decode_local_result(
+                    local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
+                )
+            )
         encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
         result = self._request(
             "call",
@@ -119,6 +183,19 @@ class SharedTensorClient:
     def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
         if self.verbose_debug:
             logger.debug("Client submitting task", extra={"endpoint": endpoint})
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: cast(
+                    str,
+                    local_server._submit_endpoint_task(
+                        endpoint,
+                        local_server.provider.get_endpoint(endpoint),
+                        tuple(args),
+                        dict(kwargs),
+                    ).task_id,
+                )
+            )
         encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
         result = self._request(
             "submit",
@@ -134,18 +211,43 @@ class SharedTensorClient:
     def release(self, object_id: str) -> bool:
         if self.verbose_debug:
             logger.debug("Client releasing managed object", extra={"object_id": object_id})
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
+            )
         result = self._request("release_object", {"object_id": object_id})
         return bool(result["released"])
 
     def release_many(self, object_ids: list[str]) -> dict[str, bool]:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: {
+                    object_id: bool(released)
+                    for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
+                        "released"
+                    ].items()
+                }
+            )
         result = self._request("release_objects", {"object_ids": object_ids})
         return {object_id: bool(released) for object_id, released in result["released"].items()}
 
     def get_object_info(self, object_id: str) -> dict[str, Any] | None:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: cast(
+                    dict[str, Any] | None,
+                    local_server._handle_get_object_info({"object_id": object_id}).get("object"),
+                )
+            )
         result = self._request("get_object_info", {"object_id": object_id})
         return cast(dict[str, Any] | None, result.get("object"))
 
     def ping(self) -> bool:
+        if self._local_server() is not None:
+            return True
         try:
             self._request("ping")
         except (SharedTensorClientError, SharedTensorRemoteError):
@@ -153,29 +255,66 @@ class SharedTensorClient:
         return True
 
     def get_server_info(self) -> dict[str, Any]:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
         return cast(dict[str, Any], self._request("get_server_info"))
 
     def list_endpoints(self) -> dict[str, Any]:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
         return cast(dict[str, Any], self._request("list_endpoints"))
 
     def get_task_status(self, task_id: str) -> dict[str, Any]:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
+            )
         return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
 
     def get_task_result(self, task_id: str) -> Any:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
+            )
         return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
 
     def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
         if self.verbose_debug:
             logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
+            )
         params = {"task_id": task_id}
         if timeout is not None:
             params["timeout"] = timeout
         return cast(dict[str, Any], self._request("wait_task", params))
 
     def cancel_task(self, task_id: str) -> bool:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
         return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
 
     def list_tasks(self, status: str | None = None) -> dict[str, Any]:
+        local_server = self._local_server()
+        if local_server is not None:
+            return self._run_local(
+                lambda: cast(
+                    dict[str, Any],
+                    {
                        listed_task_id: info.to_dict()
+                        for listed_task_id, info in local_server._task_manager_instance()
+                        .list(status=None if status is None else TaskStatus(status))
+                        .items()
+                    },
+                )
+            )
         params = {"status": status} if status else None
         return cast(dict[str, Any], self._request("list_tasks", params))
 
```
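Every public client method above now follows the same two-step shape: probe the in-process registry for a server bound to this socket path, short-circuit through `_run_local` if one exists, and otherwise fall back to the UDS request. A toy illustration of that dispatch shape follows; everything here other than the pattern itself is a placeholder, not a shared_tensor API.

```python
from typing import Any, Callable, Optional


class ToyClient:
    """Toy model of the 0.2.7 client dispatch: prefer an in-process server,
    fall back to the socket RPC path otherwise."""

    def __init__(self, local_server: Optional[object] = None) -> None:
        self._local = local_server  # stands in for the registry lookup by socket path

    def _dispatch(self, local_op: Callable[[], Any], rpc_op: Callable[[], Any]) -> Any:
        if self._local is not None:
            return local_op()  # in-process: no serialization, no socket hop
        return rpc_op()


client = ToyClient(local_server=object())
print(client._dispatch(lambda: "local fast path", lambda: "uds rpc"))  # local fast path
```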
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/provider.py

```diff
@@ -92,7 +92,6 @@ class SharedTensorProvider:
         device_index: int | None = None,
         timeout: float = 30.0,
         execution_mode: str = "auto",
-        server_process_start_method: str | None = None,
         server_startup_timeout: float = 30.0,
         verbose_debug: bool = False,
     ) -> None:
@@ -106,7 +105,6 @@ class SharedTensorProvider:
         self.timeout = timeout
         self.execution_mode = resolved_mode
         self.auto_mode = auto_mode
-        self.server_process_start_method = server_process_start_method
         self.server_startup_timeout = server_startup_timeout
         self.verbose_debug = verbose_debug
         self._client: Any | None = None
@@ -165,9 +163,6 @@ class SharedTensorProvider:
         if self._should_autostart_server():
             self._restart_autostart_server()
 
-        if self.execution_mode == "server":
-            return func
-
         @wraps(func)
         def wrapper(*args: Any, **kwargs: Any) -> Any:
             return self.call(endpoint_name, *args, **kwargs)
@@ -215,7 +210,11 @@ class SharedTensorProvider:
     def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
         if self.verbose_debug:
             logger.debug("Provider dispatching call", extra={"endpoint": endpoint, "mode": self.execution_mode})
-        if self.execution_mode
+        if self.execution_mode == "server":
+            if self._server is not None and hasattr(self._server, "invoke_local"):
+                return self._server.invoke_local(endpoint, args=args, kwargs=kwargs)
+            return self.invoke_local(endpoint, args=args, kwargs=kwargs)
+        if self.execution_mode == "local":
             return self.invoke_local(endpoint, args=args, kwargs=kwargs)
         return self._get_client().call(endpoint, *args, **kwargs)
 
@@ -370,7 +369,6 @@ class SharedTensorProvider:
             "Provider restarting autostart server",
             extra={
                 "socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
-                "process_start_method": self.server_process_start_method,
             },
         )
         if self._server is not None:
@@ -378,7 +376,6 @@ class SharedTensorProvider:
         self._server = SharedTensorServer(
             self,
             socket_path=resolve_runtime_socket_path(self.base_path, self.device_index),
-            process_start_method=self.server_process_start_method,
             startup_timeout=self.server_startup_timeout,
             verbose_debug=self.verbose_debug,
         )
```
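With the server-mode early return removed from the decorator, every shared function becomes a wrapper and `call` does the routing. A condensed, self-contained model of that routing under the assumptions visible in the diff; `ToyProvider` is a sketch, not the real class.

```python
from typing import Any, Callable


class ToyProvider:
    """Condensed model of the 0.2.7 dispatch in SharedTensorProvider.call."""

    def __init__(self, execution_mode: str, server: Any = None) -> None:
        self.execution_mode = execution_mode
        self._server = server
        self._funcs: dict[str, Callable[..., Any]] = {}

    def invoke_local(self, endpoint: str, *, args=(), kwargs=None) -> Any:
        return self._funcs[endpoint](*args, **(kwargs or {}))

    def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
        if self.execution_mode == "server":
            # Prefer the attached server so results land in its caches/refcounts.
            if self._server is not None and hasattr(self._server, "invoke_local"):
                return self._server.invoke_local(endpoint, args=args, kwargs=kwargs)
            return self.invoke_local(endpoint, args=args, kwargs=kwargs)
        if self.execution_mode == "local":
            return self.invoke_local(endpoint, args=args, kwargs=kwargs)
        raise RuntimeError("client mode would go over the UDS RPC path")


p = ToyProvider("local")
p._funcs["double"] = lambda x: 2 * x
print(p.call("double", 21))  # 42
```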
shared_tensor-0.2.7/shared_tensor/runtime.py (new file)

```diff
@@ -0,0 +1,30 @@
+"""In-process runtime registry for thread-backed local servers."""
+
+from __future__ import annotations
+
+from threading import RLock
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from shared_tensor.server import SharedTensorServer
+
+
+_LOCK = RLock()
+_SERVERS: dict[str, "SharedTensorServer"] = {}
+
+
+def register_local_server(socket_path: str, server: "SharedTensorServer") -> None:
+    with _LOCK:
+        _SERVERS[socket_path] = server
+
+
+def unregister_local_server(socket_path: str, server: "SharedTensorServer") -> None:
+    with _LOCK:
+        current = _SERVERS.get(socket_path)
+        if current is server:
+            _SERVERS.pop(socket_path, None)
+
+
+def get_local_server(socket_path: str) -> "SharedTensorServer | None":
+    with _LOCK:
+        return _SERVERS.get(socket_path)
```
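The registry's intended lifecycle, pieced together from the server changes below: `_serve_forever` registers the instance once the socket is live, clients probe it by socket path, and `_shutdown_local_resources` unregisters it. A usage sketch, assuming the package is importable; the three registry functions are the real ones from the new module, while `DummyServer` is a hypothetical stand-in for `SharedTensorServer`.

```python
from shared_tensor.runtime import (
    get_local_server,
    register_local_server,
    unregister_local_server,
)


class DummyServer:  # placeholder; type-checked only under TYPE_CHECKING
    pass


sock = "/tmp/shared-tensor-0.sock"
server = DummyServer()

register_local_server(sock, server)      # done in _serve_forever after bind
assert get_local_server(sock) is server  # client probe by socket path
unregister_local_server(sock, server)    # done in _shutdown_local_resources
assert get_local_server(sock) is None
```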
{shared_tensor-0.2.5 → shared_tensor-0.2.7}/shared_tensor/server.py

```diff
@@ -2,16 +2,13 @@
 
 from __future__ import annotations
 
-import cloudpickle
 import logging
-import multiprocessing as mp
 import os
-import sys
 import socket
 import threading
 import time
 from concurrent.futures import Future
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any
 
 from shared_tensor.async_task import TaskManager, TaskStatus
@@ -25,6 +22,7 @@ from shared_tensor.errors import (
 )
 from shared_tensor.managed_object import ManagedObjectRegistry
 from shared_tensor.provider import EndpointDefinition, SharedTensorProvider
+from shared_tensor.runtime import register_local_server, unregister_local_server
 from shared_tensor.transport import recv_message, send_message
 from shared_tensor.utils import (
     CONTROL_ENCODING,
@@ -41,11 +39,33 @@ from shared_tensor.utils import (
 logger = logging.getLogger(__name__)
 
 
+def _server_version() -> str:
+    try:
+        from shared_tensor import __version__
+    except ImportError:
+        return "unknown"
+    return __version__
+
+
 @dataclass(slots=True)
 class _InFlightCall:
     future: Future[dict[str, Any]]
 
 
+@dataclass(slots=True)
+class _ServerThreadState:
+    thread: threading.Thread
+    ready: threading.Event = field(default_factory=threading.Event)
+    stopped: threading.Event = field(default_factory=threading.Event)
+    error: BaseException | None = None
+
+
+@dataclass(slots=True)
+class _EndpointResult:
+    value: Any
+    object_id: str | None = None
+
+
 class SharedTensorServer:
     def __init__(
         self,
@@ -72,6 +92,7 @@ class SharedTensorServer:
         self.startup_timeout = startup_timeout
         self.listener: socket.socket | None = None
         self.server_process: Any | None = None
+        self.server_thread: _ServerThreadState | None = None
         self._resolved_process_start_method: str | None = None
         self.running = False
         self.started_at: float | None = None
@@ -81,10 +102,13 @@ class SharedTensorServer:
         }
         self._task_manager: TaskManager | None = None
         self._cache: dict[str, dict[str, Any]] = {}
+        self._local_cache: dict[str, Any] = {}
         self._managed_objects = ManagedObjectRegistry()
         self._inflight: dict[str, _InFlightCall] = {}
         self._endpoint_locks: dict[str, threading.Lock] = {}
         self._coordination_lock = threading.RLock()
+        if getattr(self.provider, "_server", None) is None:
+            self.provider._server = self
 
     def process_request(self, request: dict[str, Any]) -> dict[str, Any]:
         if self.verbose_debug:
@@ -180,22 +204,22 @@ class SharedTensorServer:
     ) -> Any:
         return self._task_manager_instance().submit(
             endpoint,
-            self.
+            self._execute_endpoint_result,
             (endpoint, definition, args, kwargs),
             {},
-            result_encoder=
+            result_encoder=self._encode_endpoint_result,
         )
 
-    def
+    def _execute_endpoint_result(
         self,
         endpoint: str,
         definition: EndpointDefinition,
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
-    ) ->
+    ) -> _EndpointResult:
         cache_key = self._cache_key(endpoint, definition, args, kwargs)
         if cache_key is not None:
-            cached = self.
+            cached = self._lookup_cached_result_value(definition, cache_key)
             if cached is not None:
                 if self.verbose_debug:
                     logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
@@ -207,20 +231,15 @@ class SharedTensorServer:
         if self.verbose_debug and owner:
             logger.debug("Server created singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
         if not owner:
-
-
-
-
-            object_id = payload.get("object_id")
-            if object_id is not None:
-                self._managed_objects.add_ref(object_id)
-            return payload
-            return future.result()
+            result = future.result()
+            if definition.managed and result.object_id is not None:
+                self._managed_objects.add_ref(result.object_id)
+            return result
         else:
             future = None
 
         try:
-
+            result = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
         except Exception as exc:
             if future is not None:
                 future.set_exception(exc)
@@ -228,9 +247,20 @@ class SharedTensorServer:
             raise
 
         if future is not None:
-            future.set_result(
+            future.set_result(result)
             self._release_inflight(inflight_key, future)
-        return
+        return result
+
+    def _execute_endpoint_call(
+        self,
+        endpoint: str,
+        definition: EndpointDefinition,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> dict[str, Any]:
+        return self._encode_endpoint_result(
+            self._execute_endpoint_result(endpoint, definition, args, kwargs)
+        )
 
     def _run_endpoint_under_policy(
         self,
@@ -239,11 +269,11 @@ class SharedTensorServer:
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
         cache_key: str | None,
-    ) ->
+    ) -> _EndpointResult:
         if definition.concurrency == "serialized":
             lock = self._endpoint_lock(endpoint)
             with lock:
-                cached = self.
+                cached = self._lookup_cached_result_value(definition, cache_key)
                 if cached is not None:
                     return cached
                 return self._materialize_endpoint_result(endpoint, definition, args, kwargs, cache_key)
@@ -256,16 +286,15 @@ class SharedTensorServer:
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
         cache_key: str | None,
-    ) ->
+    ) -> _EndpointResult:
         if definition.managed:
             return self._materialize_managed_result(endpoint, definition, args, kwargs, cache_key)
         value = definition.func(*args, **kwargs)
         if self.verbose_debug:
             logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
-        result = self._encode_result(value)
         if cache_key is not None:
-            self.
-        return
+            self._local_cache[cache_key] = value
+        return _EndpointResult(value=value)
 
     def _materialize_managed_result(
         self,
@@ -274,24 +303,24 @@ class SharedTensorServer:
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
         cache_key: str | None,
-    ) ->
+    ) -> _EndpointResult:
         if cache_key is not None:
             cached = self._managed_objects.get_cached(cache_key)
             if cached is not None:
                 self._managed_objects.add_ref(cached.object_id)
-                return
+                return _EndpointResult(value=cached.value, object_id=cached.object_id)
 
         result = definition.func(*args, **kwargs)
         if self.verbose_debug:
             logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
         entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
-        return
+        return _EndpointResult(value=entry.value, object_id=entry.object_id)
 
-    def
+    def _lookup_cached_result_value(
         self,
         definition: EndpointDefinition,
         cache_key: str | None,
-    ) ->
+    ) -> _EndpointResult | None:
         if cache_key is None:
             return None
         if definition.managed:
@@ -299,8 +328,81 @@ class SharedTensorServer:
             if cached is None:
                 return None
             self._managed_objects.add_ref(cached.object_id)
-            return
-
+            return _EndpointResult(value=cached.value, object_id=cached.object_id)
+        local_value = self._local_cache.get(cache_key)
+        if local_value is None:
+            return None
+        return _EndpointResult(value=local_value)
+
+    def call_local_client(
+        self,
+        endpoint: str,
+        *,
+        args: tuple[Any, ...] = (),
+        kwargs: dict[str, Any] | None = None,
+    ) -> _EndpointResult | None:
+        definition = self.provider.get_endpoint(endpoint)
+        resolved_kwargs = kwargs or {}
+        if definition.execution == "task":
+            task_info = self._submit_endpoint_task(endpoint, definition, args, resolved_kwargs)
+            return self.wait_task_result_local(task_info.task_id)
+        return self._execute_endpoint_result(endpoint, definition, args, resolved_kwargs)
+
+    def get_task_result_local(self, task_id: str) -> _EndpointResult | None:
+        result = self._task_manager_instance().result_local(task_id)
+        if result is None:
+            return None
+        return result
+
+    def wait_task_result_local(self, task_id: str, timeout: float | None = None) -> _EndpointResult | None:
+        result = self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
+        if result is None:
+            return None
+        return result
+
+    def wait_task_local(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
+        try:
+            self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
+        except SharedTensorTaskError:
+            info = self._task_manager_instance().get(task_id)
+            if info.status in {TaskStatus.PENDING, TaskStatus.RUNNING}:
+                return info.to_dict()
+            raise
+        return self._task_manager_instance().get(task_id).to_dict()
+
+    def encode_local_result(self, result: _EndpointResult | None) -> dict[str, Any]:
+        if result is None:
+            return {"encoding": None, "payload_bytes": None, "object_id": None}
+        return self._encode_endpoint_result(result)
+
+    def invoke_local(
+        self,
+        endpoint: str,
+        *,
+        args: tuple[Any, ...] = (),
+        kwargs: dict[str, Any] | None = None,
+    ) -> Any:
+        definition = self.provider.get_endpoint(endpoint)
+        resolved_kwargs = kwargs or {}
+        cache_key = self._cache_key(endpoint, definition, args, resolved_kwargs)
+        if definition.managed:
+            if cache_key is not None:
+                cached = self._managed_objects.get_cached(cache_key)
+                if cached is not None:
+                    return cached.value
+            value = definition.func(*args, **resolved_kwargs)
+            if cache_key is not None:
+                existing = self._managed_objects.get_cached(cache_key)
+                if existing is not None:
+                    return existing.value
+                self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
+            return value
+        if cache_key is not None and cache_key in self._local_cache:
+            return self._local_cache[cache_key]
+        value = definition.func(*args, **resolved_kwargs)
+        if cache_key is not None:
+            self._local_cache[cache_key] = value
+        return value
 
     def _cache_key(
         self,
@@ -400,6 +502,9 @@ class SharedTensorServer:
         encoding, payload = serialize_payload(value)
         return {"encoding": encoding, "payload_bytes": payload, "object_id": object_id}
 
+    def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
+        return self._encode_result(result.value, object_id=result.object_id)
+
     def _task_manager_instance(self) -> TaskManager:
         if self._task_manager is None:
             self._task_manager = TaskManager(
@@ -426,7 +531,7 @@ class SharedTensorServer:
         uptime = 0.0 if self.started_at is None else time.time() - self.started_at
         return {
             "server": "SharedTensorServer",
-            "version":
+            "version": _server_version(),
             "socket_path": self.socket_path,
             "uptime": uptime,
             "running": self.running,
@@ -448,101 +553,66 @@ class SharedTensorServer:
             "data": None,
         }
 
-    def _resolve_process_start_method(self) -> str:
-        if self.process_start_method is not None:
-            allowed = set(mp.get_all_start_methods())
-            if self.process_start_method not in allowed:
-                raise SharedTensorConfigurationError(
-                    f"Unsupported process_start_method '{self.process_start_method}'"
-                )
-            return self.process_start_method
-        if os.name != "posix":
-            return "spawn"
-        try:
-            import torch
-        except ImportError:
-            torch = None
-        if torch is not None and torch.cuda.is_available() and torch.cuda.is_initialized():
-            return "spawn"
-        if not hasattr(sys.modules.get("__main__"), "__file__"):
-            return "fork"
-        return "spawn"
-
     def start(self, blocking: bool = True) -> None:
         if self.verbose_debug:
             logger.info("Server starting", extra={"socket_path": self.socket_path, "blocking": blocking})
-        if self.running:
+        if self.running or self.server_thread is not None:
             raise SharedTensorConfigurationError("Server is already running")
         if blocking:
             self._resolved_process_start_method = None
             self._serve_forever()
             return
-        if
+        if self.process_start_method is not None:
             raise SharedTensorConfigurationError(
-                "
+                "process_start_method is not supported for thread-backed non-blocking servers"
             )
-
-
-
-
-            args=(
-                payload,
-                self.socket_path,
-                self.max_request_bytes,
-                self.max_workers,
-                self.result_ttl,
-                self.verbose_debug,
-                start_method,
-            ),
-            name=f"shared-tensor-daemon:{self.socket_path}",
+        thread = threading.Thread(
+            target=self._serve_forever_in_thread,
+            name=f"shared-tensor-server:{self.socket_path}",
+            daemon=True,
         )
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-        server = SharedTensorServer(
-            provider,
-            socket_path=socket_path,
-            max_request_bytes=max_request_bytes,
-            max_workers=max_workers,
-            result_ttl=result_ttl,
-            process_start_method=process_start_method,
-            verbose_debug=verbose_debug,
-        )
-        server._resolved_process_start_method = process_start_method
-        server._serve_forever()
+        state = _ServerThreadState(thread=thread)
+        self.server_thread = state
+        self._resolved_process_start_method = "thread"
+        thread.start()
+        if not state.ready.wait(timeout=self.startup_timeout):
+            self.stop()
+            raise TimeoutError(f"Timed out waiting for server socket {self.socket_path}")
+        if state.error is not None:
+            error = state.error
+            self.stop()
+            raise SharedTensorConfigurationError(
+                f"Failed to start background server thread for {self.socket_path}: {error}"
+            ) from error
+
+    def _serve_forever_in_thread(self) -> None:
+        state = self.server_thread
+        if state is None:
+            return
+        try:
+            self._serve_forever(started_event=state.ready)
+        except BaseException as exc:  # noqa: BLE001
+            state.error = exc
+            state.ready.set()
+            raise
+        finally:
+            state.stopped.set()
 
-    def _serve_forever(self) -> None:
+    def _serve_forever(self, *, started_event: threading.Event | None = None) -> None:
         self._configure_cuda_runtime()
         unlink_socket_path(self.socket_path)
         listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        listener.bind(self.socket_path)
-        listener.listen()
-        if self.verbose_debug:
-            logger.info("Server listening", extra={"socket_path": self.socket_path})
-        self.listener = listener
-        self.running = True
-        self.started_at = time.time()
         try:
+            listener.bind(self.socket_path)
+            listener.listen()
+            if self.verbose_debug:
+                logger.info("Server listening", extra={"socket_path": self.socket_path})
+            self.listener = listener
+            self.running = True
+            self.started_at = time.time()
+            register_local_server(self.socket_path, self)
+            if started_event is not None:
+                started_event.set()
             while self.running:
                 try:
                     conn, _ = listener.accept()
@@ -553,6 +623,8 @@ class SharedTensorServer:
                 thread = threading.Thread(target=self._handle_connection, args=(conn,), daemon=True)
                 thread.start()
         finally:
+            if started_event is not None and not started_event.is_set():
+                started_event.set()
             self._shutdown_local_resources()
 
     def _handle_connection(self, conn: socket.socket) -> None:
@@ -586,24 +658,20 @@ class SharedTensorServer:
     def stop(self) -> None:
         if self.verbose_debug:
             logger.info("Server stopping", extra={"socket_path": self.socket_path})
-        if not self.running:
-            unlink_socket_path(self.socket_path)
-            return
         self.running = False
-        if self.server_process is not None:
-            self.server_process.terminate()
-            self.server_process.join(timeout=5)
-            if self.server_process.is_alive():
-                self.server_process.kill()
-                self.server_process.join(timeout=5)
-            self.server_process = None
-            unlink_socket_path(self.socket_path)
-            return
         if self.listener is not None:
             self.listener.close()
-        self.
+        state = self.server_thread
+        if state is not None and state.thread.is_alive() and threading.current_thread() is not state.thread:
+            state.stopped.wait(timeout=5)
+            state.thread.join(timeout=5)
+        self.server_thread = None
+        self.server_process = None
+        if self.listener is None:
+            unlink_socket_path(self.socket_path)
 
     def _shutdown_local_resources(self) -> None:
+        self.running = False
         if self.listener is not None:
             self.listener.close()
             self.listener = None
@@ -612,8 +680,10 @@ class SharedTensorServer:
         self._task_manager = None
         self._managed_objects.clear()
         self._cache.clear()
+        self._local_cache.clear()
         self._inflight.clear()
         self._endpoint_locks.clear()
+        unregister_local_server(self.socket_path, self)
         unlink_socket_path(self.socket_path)
 
     def __enter__(self) -> SharedTensorServer:
```
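The non-blocking start path above replaces the old daemon subprocess with a ready/error/stopped handshake between the caller and a background thread: the caller blocks on a ready event, the serving thread records any startup failure before setting the event, and the caller re-raises it. A minimal standalone sketch of that handshake, independent of shared_tensor; the socket path and names here are illustrative.

```python
import os
import socket
import threading


def serve(path: str, ready: threading.Event, box: dict) -> None:
    try:
        if os.path.exists(path):
            os.unlink(path)
        listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        listener.bind(path)  # may fail, e.g. permission denied
        listener.listen()
        ready.set()          # signal readiness only once the socket is live
        listener.close()
    except BaseException as exc:
        box["error"] = exc   # record the failure for the waiting caller
        ready.set()          # then wake the waiter so it can re-raise


ready, box = threading.Event(), {}
thread = threading.Thread(target=serve, args=("/tmp/handshake-demo.sock", ready, box), daemon=True)
thread.start()
if not ready.wait(timeout=5):
    raise TimeoutError("server thread never came up")
if "error" in box:
    raise box["error"]
print("socket is live; clients may connect")
```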
Files without changes: LICENSE, MANIFEST.in, setup.cfg, shared_tensor/async_client.py, shared_tensor/async_provider.py, shared_tensor/errors.py, shared_tensor/managed_object.py, shared_tensor/transport.py, shared_tensor/utils.py.