shared-tensor 0.2.2.tar.gz → 0.2.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/PKG-INFO +143 -111
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/README.md +141 -110
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/pyproject.toml +2 -1
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/__init__.py +1 -1
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/async_provider.py +2 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/provider.py +21 -2
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/server.py +51 -9
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/utils.py +58 -3
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/LICENSE +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/MANIFEST.in +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/setup.cfg +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/async_client.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/async_task.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/client.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/errors.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/jsonrpc.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/managed_object.py +0 -0
- {shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor.egg-info/SOURCES.txt +0 -0
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: shared-tensor
-Version: 0.2.2
+Version: 0.2.4
 Summary: Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC
 Author-email: Athena Team <contact@world-sim-dev.org>
 Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -25,6 +25,7 @@ Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: cloudpickle>=3.0.0
 Requires-Dist: numpy<2
 Requires-Dist: requests>=2.25.0
 Requires-Dist: torch>=2.2.0
@@ -63,6 +64,7 @@ Dynamic: license-file
 - task-backed slow object construction
 - endpoint-level serialization and cache-key singleflight
 - zero-branch auto mode driven by `SHARED_TENSOR_ROLE`
+- auto mode is gated by `SHARED_TENSOR_ENABLED=1`
 - port routing by `base_port + cuda_device_index`
 
 ## What It Does Not Support
@@ -104,9 +106,40 @@ conda activate shared-tensor-dev
 pip install -e ".[dev,test]"
 ```
 
-##
+## Enabling Auto Mode
 
-
+`SharedTensorProvider()` now defaults to safe local mode unless you explicitly enable shared-tensor behavior.
+
+Global default:
+
+```bash
+export SHARED_TENSOR_ENABLED=1
+```
+
+Per-provider override:
+
+```python
+provider = SharedTensorProvider(enabled=True)
+provider = SharedTensorProvider(enabled=False)
+provider = SharedTensorProvider(enabled=None)
+```
+
+`enabled=None` means do not override and keep using the environment variable.
+
+Then `execution_mode="auto"` behaves like this:
+
+- `enabled=False`: provider stays in local mode
+- `enabled=True` and `SHARED_TENSOR_ROLE=server`: auto-start server and execute locally on the server side
+- `enabled=True` and no role set: provider becomes a client wrapper
+- `enabled=None`: fall back to `SHARED_TENSOR_ENABLED`
+
+This makes accidental opt-in much less likely in scripts that import shared endpoints but did not intend to start RPC behavior.
+
+## Example 1: Zero-Branch Auto Mode
+
+See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
+
+One file, two processes, no branch in user code:
 
 ```python
 import torch
@@ -141,55 +174,56 @@ if __name__ == "__main__":
 Run process A as the auto server:
 
 ```bash
-SHARED_TENSOR_ROLE=server python demo.py
+SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
 ```
 
 Run process B as the client with the exact same file:
 
 ```bash
+SHARED_TENSOR_ENABLED=1 python demo.py
+```
+
+Equivalent stepwise form:
+
+```bash
+export SHARED_TENSOR_ENABLED=1
+SHARED_TENSOR_ROLE=server python demo.py
 python demo.py
 ```
 
 Behavior:
 
+- `SHARED_TENSOR_ENABLED=1` enables shared-tensor auto behavior for providers that keep `enabled=None`
 - `SHARED_TENSOR_ROLE=server` makes the provider auto-start a background localhost daemon
 - in the server process, shared functions still execute locally
 - in the client process, the same function names become RPC wrappers
 - no `SHARED_TENSOR_HOST` is used; transport is fixed to `127.0.0.1`
 - the final port is `SHARED_TENSOR_BASE_PORT + current_cuda_device_index`
 
-
-
-Each endpoint is registered once and then supports two client-side call styles.
+Why this works:
 
-
-
-- blocks until the result is ready
-- `fn.submit(...)` or `provider.submit(name, ...)`
-- asynchronous
-- returns a task id
+```text
+same code file
 
-
+Process A                                Process B
+SHARED_TENSOR_ENABLED=1                  SHARED_TENSOR_ENABLED=1
+SHARED_TENSOR_ROLE=server                SHARED_TENSOR_ROLE unset
+----------------------------------       ----------------------------------
+provider.share(...)                      provider.share(...)
+provider auto-starts localhost daemon    provider builds RPC wrappers
+shared fn executes locally               shared fn becomes RPC call
 
-
-
-
-
-
-- async submit is the natural path for slow model construction
-- `concurrency="parallel"`
-- multiple server executions may run at once
-- `concurrency="serialized"`
-- only one execution of that endpoint runs at a time
-- `singleflight=True`
-- identical in-flight cache keys collapse to one execution
-- this is the recommended model-loading default
+load_model(...)                          load_model(...)
+-> local CUDA model                      -> JSON-RPC to localhost daemon
+identity(x)                              -> receives CUDA IPC-backed result
+-> local tensor return
+```
 
-
+Use this mode when you want the cleanest operator experience: one script, one env var difference, server side stays local, client side becomes remote automatically.
 
-
+## Example 2: Fast Tensor Transform
 
-
+See [examples/model_service.py](./examples/model_service.py).
 
 ```python
 @provider.share(execution="direct", cache=False)
@@ -197,6 +231,14 @@ def scale_tensor(tensor: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
     return tensor * factor
 ```
 
+What happens on the wire:
+
+```text
+client tensor -> direct RPC -> server runs function immediately -> CUDA result back
+```
+
+Use this for cheap tensor math, lightweight preprocessing, and request-scoped outputs.
+
 Recommended combination:
 
 - `execution="direct"`
@@ -204,130 +246,120 @@ Recommended combination:
 - `managed=False`
 - `concurrency="parallel"`
 
-
-
-- direct execution has the lowest overhead
-- these calls are request-scoped, so caching is usually wrong
-- parallel execution is usually fine because the work is short-lived
+## Example 3: Reusable Model Service
 
-
-
-Use this for loading or building a CUDA model that may take hundreds of milliseconds or multiple seconds.
+See [examples/model_service.py](./examples/model_service.py).
 
 ```python
 @provider.share(
     execution="task",
     managed=True,
     concurrency="serialized",
-    cache_format_key="model:{
+    cache_format_key="model:{input_dim}:{output_dim}",
 )
-def
-
+def load_linear_model(input_dim: int = 16, output_dim: int = 4) -> torch.nn.Module:
+    ...
 ```
 
-
-
-- `execution="task"`
-- `managed=True`
-- `concurrency="serialized"`
-- `singleflight=True`
-- explicit `cache_format_key`
-
-Why:
-
-- task mode gives you both blocking sync calls and true async submission
-- managed handles let the client release the remote object explicitly
-- serialized execution avoids multiple concurrent heavy loads on one GPU
-- singleflight prevents duplicate in-flight construction for the same model key
-
-### 3. Reusable Shared Model Service
+What happens when two clients ask for the same model key:
 
-
+```text
+Client A                    Server                        Client B
+------------------------    ------------------------      ------------------------
+call("load_model", k) ----->  cache miss                   call("load_model", k)
+                              build object once -------------> same key in flight
+                              object_id = obj-123           wait on same future
+<-------------------------   return handle(obj-123) <---------- return handle(obj-123)
 
-
-
-    execution="task",
-    managed=True,
-    cache_format_key="model:{model_name}:{dtype}",
-)
-def load_model(model_name: str, dtype: str) -> torch.nn.Module:
-    ...
+release(obj-123) ---------->  refcount 2 -> 1
+release(obj-123) ------------------------------------------>  refcount 1 -> 0 -> destroy
 ```
 
-
+Use this for big reusable models. The important mix is:
 
-- `
+- `execution="task"`
 - `managed=True`
+- `concurrency="serialized"`
 - `singleflight=True`
-- explicit
-
-Why:
+- explicit `cache_format_key`
 
-
-- managed handles keep explicit lifecycle control
-- stable cache keys prevent accidental duplication from argument shape changes
+`managed=True` gives explicit lifecycle control. `cache_format_key` turns the endpoint into a model registry. `singleflight=True` ensures duplicate in-flight loads collapse to one build.
 
-
+## Example 4: Fire-And-Poll Warmup
 
-
+This is the same task-backed endpoint style, but the caller chooses async use:
 
 ```python
 task_id = load_model.submit(hidden_size=8192)
 model_handle = provider.wait_for_task(task_id)
 ```
 
-
-
-- endpoint uses `execution="task"`
-- caller uses `.submit(...)`
-- optionally add `managed=True` for long-lived objects
-
-Why:
+Runtime shape:
 
-
-
+```text
+submit now -> task queue -> slow build on server -> poll later -> consume handle/result
+```
 
-
+Use this when the build is slow enough that the caller should not block immediately.
 
-
+## Example 5: Serialized Fragile Path
 
 ```python
-@provider.share(execution="task", cache=False, singleflight=False)
-def
-
+@provider.share(execution="task", concurrency="serialized", cache=False, singleflight=False)
+def compact_memory(tensor: torch.Tensor) -> torch.Tensor:
+    ...
 ```
 
-
-
-- `cache=False`
-- `singleflight=False`
-- choose `execution="direct"` or `execution="task"` based on runtime cost
+Execution model:
 
-
+```text
+request A -> lock -> run -> unlock
+request B -> wait -> lock -> run -> unlock
+```
 
--
-- disabling singleflight ensures independent requests stay independent
+Use this for GPU-heavy paths that must not overlap with themselves.
 
-
+## Endpoint Semantics
 
-
+Each endpoint is registered once and then supports two client-side call styles.
 
-
-
-
-
-
+- `fn(...)` or `provider.call(name, ...)`
+  - synchronous
+  - blocks until the result is ready
+- `fn.submit(...)` or `provider.submit(name, ...)`
+  - asynchronous
+  - returns a task id
 
-
+Endpoint options:
 
+- `execution="direct"`
+  - sync calls run the function directly on the server
+  - use this for fast tensor transforms
 - `execution="task"`
+  - sync calls still block, but they block on the task system
+  - use this for slow construction, warmup, and reusable model loading
+- `concurrency="parallel"`
+  - multiple server executions may run at once
 - `concurrency="serialized"`
--
-
-
+  - only one execution of that endpoint runs at a time
+- `singleflight=True`
+  - identical in-flight cache keys collapse to one execution
+  - this is the recommended model-loading default
 
-
-
+## Scenario Map
+
+- Fast tensor transform:
+  use `execution="direct"`, `cache=False`, `managed=False`
+- Slow model construction:
+  use `execution="task"`, `managed=True`, `concurrency="serialized"`
+- Reusable model registry:
+  add stable `cache_format_key` and keep `singleflight=True`
+- Background warmup:
+  keep endpoint as task-backed and use `.submit(...)`
+- Fragile non-overlapping GPU path:
+  use `concurrency="serialized"`
+- Fresh per-request work:
+  disable cache and usually disable singleflight
 
 ## Parameter Guide
 
````
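The reworked examples above describe the recommended endpoint mix piece by piece. The condensed sketch below ties them together; only `SharedTensorProvider`, `provider.share`, `.submit(...)`, and `provider.wait_for_task` come from the README text, while the import path, the model body, and the availability of a CUDA device are assumptions made for illustration.

```python
# Condensed sketch of the "reusable model service" pattern described above.
# Import path and model body are illustrative assumptions, not code from this release.
import torch
from shared_tensor.provider import SharedTensorProvider

provider = SharedTensorProvider(enabled=None)  # defer to SHARED_TENSOR_ENABLED

@provider.share(
    execution="task",             # slow construction goes through the task system
    managed=True,                 # client receives a releasable handle
    concurrency="serialized",     # never build two models at once on this GPU
    singleflight=True,            # identical in-flight keys collapse to one build
    cache_format_key="model:{input_dim}:{output_dim}",
)
def load_linear_model(input_dim: int = 16, output_dim: int = 4) -> torch.nn.Module:
    return torch.nn.Linear(input_dim, output_dim).cuda()

# Blocking call style:
model_handle = load_linear_model(input_dim=32, output_dim=8)

# Fire-and-poll call style:
task_id = load_linear_model.submit(input_dim=32, output_dim=8)
model_handle = provider.wait_for_task(task_id)
```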
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/README.md

The README.md hunks (@@ -12,6 +12,7 @@, @@ -53,9 +54,40 @@, @@ -90,55 +122,56 @@, @@ -146,6 +179,14 @@, @@ -153,130 +194,120 @@) contain exactly the same text changes as the PKG-INFO long description shown above, shifted by the metadata header offset; they are not repeated here.
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/pyproject.toml

````diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "shared-tensor"
-version = "0.2.2"
+version = "0.2.4"
 description = "Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC"
 readme = "README.md"
 license = "Apache-2.0"
@@ -42,6 +42,7 @@ classifiers = [
 ]
 requires-python = ">=3.10"
 dependencies = [
+    "cloudpickle>=3.0.0",
     "numpy<2",
     "requests>=2.25.0",
     "torch>=2.2.0",
````
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/async_provider.py

````diff
@@ -15,6 +15,7 @@ class AsyncSharedTensorProvider(SharedTensorProvider):
         base_port: int = 2537,
         poll_interval: float = 1.0,
         *,
+        enabled: bool | None = None,
         server_host: str = "127.0.0.1",
         device_index: int | None = None,
         timeout: float = 30.0,
@@ -23,6 +24,7 @@ class AsyncSharedTensorProvider(SharedTensorProvider):
     ) -> None:
         super().__init__(
             base_port=base_port,
+            enabled=enabled,
             server_host=server_host,
             device_index=device_index,
             timeout=timeout,
````
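A minimal usage sketch of the widened constructor follows; the class name and the `enabled` keyword are taken from the hunk above, while the import path is an assumption based on the file location.

```python
# Sketch only: construct the async provider with the new opt-in keyword.
from shared_tensor.async_provider import AsyncSharedTensorProvider

# Explicit opt-in, ignoring SHARED_TENSOR_ENABLED:
provider = AsyncSharedTensorProvider(base_port=2537, enabled=True)

# Defer to the environment variable (the default):
provider = AsyncSharedTensorProvider(enabled=None)
```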
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/provider.py

````diff
@@ -21,6 +21,7 @@ from shared_tensor.utils import (
 
 EndpointExecution = Literal["direct", "task"]
 EndpointConcurrency = Literal["parallel", "serialized"]
+SHARED_TENSOR_ENABLED_ENV = "SHARED_TENSOR_ENABLED"
 
 
 @dataclass(slots=True)
@@ -36,8 +37,14 @@ class EndpointDefinition:
     singleflight: bool = True
 
 
-def _resolve_execution_mode(execution_mode: str) -> tuple[str, bool]:
+def _resolve_execution_mode(
+    execution_mode: str,
+    *,
+    enabled: bool | None = None,
+) -> tuple[str, bool]:
     if execution_mode == "auto":
+        if not _is_shared_tensor_enabled(enabled):
+            return "local", True
         env_role = os.getenv("SHARED_TENSOR_ROLE", "").strip().lower()
         if env_role in {"server", "client", "local"}:
             return env_role, True
@@ -49,6 +56,13 @@ def _resolve_execution_mode(execution_mode: str) -> tuple[str, bool]:
     return execution_mode, False
 
 
+def _is_shared_tensor_enabled(enabled: bool | None) -> bool:
+    if enabled is not None:
+        return enabled
+    raw = os.getenv(SHARED_TENSOR_ENABLED_ENV, "").strip().lower()
+    return raw in {"1", "true", "yes", "on"}
+
+
 def _validate_endpoint_options(
     *,
     execution: EndpointExecution,
@@ -71,15 +85,20 @@ class SharedTensorProvider:
         self,
         base_port: int = 2537,
         *,
+        enabled: bool | None = None,
         server_host: str = "127.0.0.1",
         device_index: int | None = None,
         timeout: float = 30.0,
         execution_mode: str = "auto",
         verbose_debug: bool = False,
     ) -> None:
-        resolved_mode, auto_mode = _resolve_execution_mode(execution_mode)
+        resolved_mode, auto_mode = _resolve_execution_mode(
+            execution_mode,
+            enabled=enabled,
+        )
         self.server_host = server_host
         self.base_port = resolve_server_base_port(base_port)
+        self.enabled = enabled
         self.device_index = device_index
         self.timeout = timeout
         self.execution_mode = resolved_mode
````
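The new gate is small enough to restate on its own. The sketch below reproduces `_is_shared_tensor_enabled` outside the package so the resolution order is visible at a glance; the environment-variable name and the accepted truthy values are copied from the hunk above, and the assertions are illustrative.

```python
# Standalone reproduction of the enable gate: an explicit argument wins,
# otherwise the SHARED_TENSOR_ENABLED environment variable decides.
import os

SHARED_TENSOR_ENABLED_ENV = "SHARED_TENSOR_ENABLED"


def is_shared_tensor_enabled(enabled: bool | None) -> bool:
    if enabled is not None:
        return enabled                      # per-provider override
    raw = os.getenv(SHARED_TENSOR_ENABLED_ENV, "").strip().lower()
    return raw in {"1", "true", "yes", "on"}


assert is_shared_tensor_enabled(True) is True      # constructor argument wins

os.environ[SHARED_TENSOR_ENABLED_ENV] = "yes"
assert is_shared_tensor_enabled(None) is True      # falls back to the env var

os.environ[SHARED_TENSOR_ENABLED_ENV] = "0"
assert is_shared_tensor_enabled(None) is False     # "0" is not a truthy value
```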
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/server.py

````diff
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import cloudpickle
 import multiprocessing as mp
 import os
 import threading
@@ -10,7 +11,6 @@ import time
 from concurrent.futures import Future
 from dataclasses import dataclass
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from multiprocessing.process import BaseProcess
 from socketserver import ThreadingMixIn
 from typing import Any
 
@@ -37,7 +37,7 @@ from shared_tensor.utils import (
     capability_snapshot,
     deserialize_payload,
     serialize_payload,
-
+    validate_call_payload_for_transport,
 )
 
 logger = logging.getLogger(__name__)
@@ -118,7 +118,7 @@ class SharedTensorServer:
         self.max_workers = max_workers
         self.result_ttl = result_ttl
         self.server: ThreadedHTTPServer | None = None
-        self.server_process: BaseProcess | None = None
+        self.server_process: Any | None = None
         self.running = False
         self.started_at: float | None = None
         self.stats = {
@@ -231,6 +231,12 @@ class SharedTensorServer:
         if inflight_key is not None:
             future, owner = self._acquire_inflight(inflight_key)
             if not owner:
+                if definition.managed:
+                    payload = future.result()
+                    object_id = payload.get("object_id")
+                    if object_id is not None:
+                        self._managed_objects.add_ref(object_id)
+                    return payload
                 return future.result()
         else:
             future = None
@@ -401,8 +407,8 @@ class SharedTensorServer:
                 "Control encoding is reserved for empty args/kwargs only"
             )
             return endpoint, args, kwargs
-
-
+        validate_call_payload_for_transport(args)
+        validate_call_payload_for_transport(kwargs, allow_dict_keys=True)
         return endpoint, args, kwargs
 
     def _encode_result(self, value: Any, *, object_id: str | None = None) -> dict[str, str | None]:
@@ -437,7 +443,7 @@ class SharedTensorServer:
         uptime = 0.0 if self.started_at is None else time.time() - self.started_at
         return {
             "server": "SharedTensorServer",
-            "version": "0.2.2",
+            "version": "0.2.4",
             "host": self.host,
             "port": self.port,
             "uptime": uptime,
@@ -454,16 +460,52 @@ class SharedTensorServer:
             return
         if os.name != "posix":
             raise SharedTensorConfigurationError(
-                "Non-blocking shared_tensor servers require POSIX
+                "Non-blocking shared_tensor servers require POSIX multiprocessing support"
             )
-
-        process =
+        payload = cloudpickle.dumps(self.provider)
+        process = mp.get_context("spawn").Process(
+            target=self._serve_forever_from_payload,
+            args=(
+                payload,
+                self.host,
+                self.port,
+                self.max_request_bytes,
+                self.max_workers,
+                self.result_ttl,
+                self.verbose_debug,
+            ),
+            name=f"shared-tensor-daemon:{self.port}",
+        )
         process.start()
         self.server_process = process
         self.running = True
         self.started_at = time.time()
 
+    @staticmethod
+    def _serve_forever_from_payload(
+        payload: bytes,
+        host: str,
+        port: int,
+        max_request_bytes: int,
+        max_workers: int,
+        result_ttl: float,
+        verbose_debug: bool,
+    ) -> None:
+        SharedTensorServer._configure_cuda_runtime()
+        provider = cloudpickle.loads(payload)
+        server = SharedTensorServer(
+            provider,
+            host=host,
+            port=port,
+            max_request_bytes=max_request_bytes,
+            max_workers=max_workers,
+            result_ttl=result_ttl,
+            verbose_debug=verbose_debug,
+        )
+        server._serve_forever()
+
     def _serve_forever(self) -> None:
+        self._configure_cuda_runtime()
         self.server = ThreadedHTTPServer((self.host, self.port), SharedTensorRequestHandler)
         self.server.shared_tensor_server = self  # type: ignore[attr-defined]
         self.running = True
````
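The non-blocking start path now round-trips the provider through `cloudpickle` and a `spawn` process instead of relying on fork inheritance. Below is a standalone sketch of that pattern using only the standard `multiprocessing` API and `cloudpickle` (no shared_tensor code); the lambda stands in for the closure-carrying provider object.

```python
# Demo of the serialize-then-spawn pattern: cloudpickle can carry a closure
# that plain pickle rejects, and the "spawn" child rebuilds it from bytes.
import cloudpickle
import multiprocessing as mp


def _child(payload: bytes) -> None:
    fn = cloudpickle.loads(payload)   # rebuild the callable in a fresh interpreter
    print(fn(21))                     # prints 42


def main() -> None:
    factor = 2
    payload = cloudpickle.dumps(lambda x: x * factor)  # lambda + captured variable
    proc = mp.get_context("spawn").Process(target=_child, args=(payload,))
    proc.start()
    proc.join()


if __name__ == "__main__":
    main()
```

A spawned child starts from a clean interpreter, which is presumably why the new code also re-runs `_configure_cuda_runtime()` inside the child before serving.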
{shared_tensor-0.2.2 → shared_tensor-0.2.4}/shared_tensor/utils.py

````diff
@@ -5,6 +5,7 @@ from __future__ import annotations
 import hashlib
 import inspect
 import io
+import multiprocessing.reduction as mp_reduction
 import os
 import pickle
 from collections.abc import Callable
@@ -35,7 +36,17 @@ _EMPTY_DICT: dict[str, Any] = {}
 def _torch_forking_pickler() -> type | None:
     if TORCH_MODULE is None:
         return None
-
+    reductions = TORCH_MODULE.multiprocessing.reductions
+    init_reductions = getattr(reductions, "init_reductions", None)
+    if callable(init_reductions):
+        try:
+            init_reductions()
+        except Exception:
+            return cast(type, mp_reduction.ForkingPickler)
+    pickler = getattr(reductions, "ForkingPickler", None)
+    if pickler is not None:
+        return cast(type, pickler)
+    return cast(type, mp_reduction.ForkingPickler)
 
 
 def _raise_unsupported_payload(message: str) -> None:
@@ -91,11 +102,55 @@ def _validate_torch_payload(obj: Any, *, allow_dict_keys: bool = False) -> None:
     _raise_unsupported_payload(f"Unsupported payload type: {type(obj).__name__}")
 
 
+def _validate_call_payload(obj: Any, *, allow_dict_keys: bool = False) -> None:
+    if TORCH_MODULE is None:
+        raise SharedTensorCapabilityError("PyTorch is required for shared_tensor")
+
+    if isinstance(obj, (str, int, float, bool, type(None), bytes)):
+        return
+
+    if isinstance(obj, TORCH_MODULE.Tensor):
+        if not obj.is_cuda:
+            _raise_unsupported_payload("CPU torch.Tensor payloads are not supported")
+        return
+
+    if isinstance(obj, TORCH_MODULE.nn.Module):
+        _validate_module_device(obj)
+        return
+
+    if isinstance(obj, tuple):
+        for item in obj:
+            _validate_call_payload(item)
+        return
+
+    if isinstance(obj, list):
+        for item in obj:
+            _validate_call_payload(item)
+        return
+
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            if allow_dict_keys:
+                if not isinstance(key, str):
+                    _raise_unsupported_payload("Dictionary payload keys must be strings")
+            else:
+                _validate_call_payload(key)
+            _validate_call_payload(value)
+        return
+
+    _raise_unsupported_payload(f"Unsupported payload type: {type(obj).__name__}")
+
+
 def validate_payload_for_transport(obj: Any, *, allow_dict_keys: bool = False) -> None:
     """Validate that a payload fits the supported CUDA torch transport contract."""
     _validate_torch_payload(obj, allow_dict_keys=allow_dict_keys)
 
 
+def validate_call_payload_for_transport(obj: Any, *, allow_dict_keys: bool = False) -> None:
+    """Validate RPC call args/kwargs, allowing scalar controls alongside CUDA payloads."""
+    _validate_call_payload(obj, allow_dict_keys=allow_dict_keys)
+
+
 def _torch_serialize(obj: Any) -> bytes:
     pickler_cls = _torch_forking_pickler()
     if pickler_cls is None:
@@ -140,8 +195,8 @@ def serialize_call_payloads(
         return CONTROL_ENCODING, args_payload, serialize_empty_payload(_EMPTY_DICT)[1]
 
     try:
-
-
+        _validate_call_payload(args)
+        _validate_call_payload(kwargs, allow_dict_keys=True)
         return TORCH_ENCODING, pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL), pickle.dumps(
             kwargs, protocol=pickle.HIGHEST_PROTOCOL
         )
````
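A behavior sketch for the new call-payload validator follows, assuming the package is installed and a CUDA device is present; the function name and the accepted and rejected types come from the hunk above, while the concrete example values are illustrative.

```python
# Sketch: the call-payload validator now admits scalar "control" values
# (str, int, float, bool, None, bytes) next to CUDA tensors, but still
# rejects CPU tensors rather than copying them to the GPU.
import torch
from shared_tensor.utils import validate_call_payload_for_transport

# Scalars and strings alongside a CUDA tensor pass:
validate_call_payload_for_transport((torch.ones(4, device="cuda"), 0.5, "fp16"))

# Keyword payloads must use string keys:
validate_call_payload_for_transport({"factor": 2.0}, allow_dict_keys=True)

# CPU tensors are rejected with the library's unsupported-payload error:
try:
    validate_call_payload_for_transport((torch.ones(4),))
except Exception as exc:
    print(f"rejected as expected: {exc}")
```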
The remaining files (LICENSE, MANIFEST.in, setup.cfg, async_client.py, async_task.py, client.py, errors.py, jsonrpc.py, managed_object.py, and SOURCES.txt) are unchanged between 0.2.2 and 0.2.4.