shared-tensor 0.2.2__tar.gz → 0.2.4__tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: shared-tensor
- Version: 0.2.2
+ Version: 0.2.4
  Summary: Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC
  Author-email: Athena Team <contact@world-sim-dev.org>
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -25,6 +25,7 @@ Classifier: Topic :: System :: Distributed Computing
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
+ Requires-Dist: cloudpickle>=3.0.0
  Requires-Dist: numpy<2
  Requires-Dist: requests>=2.25.0
  Requires-Dist: torch>=2.2.0
@@ -63,6 +64,7 @@ Dynamic: license-file
  - task-backed slow object construction
  - endpoint-level serialization and cache-key singleflight
  - zero-branch auto mode driven by `SHARED_TENSOR_ROLE`
+ - auto mode is gated by `SHARED_TENSOR_ENABLED=1`
  - port routing by `base_port + cuda_device_index`

  ## What It Does Not Support
@@ -104,9 +106,40 @@ conda activate shared-tensor-dev
  pip install -e ".[dev,test]"
  ```

- ## Zero-Branch Example
+ ## Enabling Auto Mode

- One file:
+ `SharedTensorProvider()` now defaults to safe local mode unless you explicitly enable shared-tensor behavior.
+
+ Global default:
+
+ ```bash
+ export SHARED_TENSOR_ENABLED=1
+ ```
+
+ Per-provider override:
+
+ ```python
+ provider = SharedTensorProvider(enabled=True)
+ provider = SharedTensorProvider(enabled=False)
+ provider = SharedTensorProvider(enabled=None)
+ ```
+
+ `enabled=None` means do not override and keep using the environment variable.
+
+ Then `execution_mode="auto"` behaves like this:
+
+ - `enabled=False`: provider stays in local mode
+ - `enabled=True` and `SHARED_TENSOR_ROLE=server`: auto-start server and execute locally on the server side
+ - `enabled=True` and no role set: provider becomes a client wrapper
+ - `enabled=None`: fall back to `SHARED_TENSOR_ENABLED`
+
+ This makes accidental opt-in much less likely in scripts that import shared endpoints but did not intend to start RPC behavior.
+
+ ## Example 1: Zero-Branch Auto Mode
+
+ See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
+
+ One file, two processes, no branch in user code:

  ```python
  import torch
@@ -141,55 +174,56 @@ if __name__ == "__main__":
  Run process A as the auto server:

  ```bash
- SHARED_TENSOR_ROLE=server python demo.py
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
  ```

  Run process B as the client with the exact same file:

  ```bash
+ SHARED_TENSOR_ENABLED=1 python demo.py
+ ```
+
+ Equivalent stepwise form:
+
+ ```bash
+ export SHARED_TENSOR_ENABLED=1
+ SHARED_TENSOR_ROLE=server python demo.py
  python demo.py
  ```

  Behavior:

+ - `SHARED_TENSOR_ENABLED=1` enables shared-tensor auto behavior for providers that keep `enabled=None`
  - `SHARED_TENSOR_ROLE=server` makes the provider auto-start a background localhost daemon
  - in the server process, shared functions still execute locally
  - in the client process, the same function names become RPC wrappers
  - no `SHARED_TENSOR_HOST` is used; transport is fixed to `127.0.0.1`
  - the final port is `SHARED_TENSOR_BASE_PORT + current_cuda_device_index`

- ## Endpoint Semantics
-
- Each endpoint is registered once and then supports two client-side call styles.
+ Why this works:

- - `fn(...)` or `provider.call(name, ...)`
- - synchronous
- - blocks until the result is ready
- - `fn.submit(...)` or `provider.submit(name, ...)`
- - asynchronous
- - returns a task id
+ ```text
+ same code file

- Endpoint options:
+ Process A Process B
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ENABLED=1
+ SHARED_TENSOR_ROLE=server SHARED_TENSOR_ROLE unset
+ ---------------------------------- ----------------------------------
+ provider.share(...) provider.share(...)
+ provider auto-starts localhost daemon provider builds RPC wrappers
+ shared fn executes locally shared fn becomes RPC call

- - `execution="direct"`
- - sync calls run the function directly on the server
- - best for fast tensor transforms
- - `execution="task"`
- - sync calls still block, but they block on the task system
- - async submit is the natural path for slow model construction
- - `concurrency="parallel"`
- - multiple server executions may run at once
- - `concurrency="serialized"`
- - only one execution of that endpoint runs at a time
- - `singleflight=True`
- - identical in-flight cache keys collapse to one execution
- - this is the recommended model-loading default
+ load_model(...) load_model(...)
+ -> local CUDA model -> JSON-RPC to localhost daemon
+ identity(x) -> receives CUDA IPC-backed result
+ -> local tensor return
+ ```

- ## Common Scenarios
+ Use this mode when you want the cleanest operator experience: one script, one env var difference, server side stays local, client side becomes remote automatically.

- ### 1. Fast Tensor Transform
+ ## Example 2: Fast Tensor Transform

- Use this for cheap operations such as clone, view-like transforms, elementwise scaling, or lightweight preprocessing.
+ See [examples/model_service.py](./examples/model_service.py).

  ```python
  @provider.share(execution="direct", cache=False)
@@ -197,6 +231,14 @@ def scale_tensor(tensor: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
  return tensor * factor
  ```

+ What happens on the wire:
+
+ ```text
+ client tensor -> direct RPC -> server runs function immediately -> CUDA result back
+ ```
+
+ Use this for cheap tensor math, lightweight preprocessing, and request-scoped outputs.
+
  Recommended combination:

  - `execution="direct"`
@@ -204,130 +246,120 @@ Recommended combination:
  - `managed=False`
  - `concurrency="parallel"`

- Why:
-
- - direct execution has the lowest overhead
- - these calls are request-scoped, so caching is usually wrong
- - parallel execution is usually fine because the work is short-lived
+ ## Example 3: Reusable Model Service

- ### 2. Slow Model Construction
-
- Use this for loading or building a CUDA model that may take hundreds of milliseconds or multiple seconds.
+ See [examples/model_service.py](./examples/model_service.py).

  ```python
  @provider.share(
  execution="task",
  managed=True,
  concurrency="serialized",
- cache_format_key="model:{hidden_size}",
+ cache_format_key="model:{input_dim}:{output_dim}",
  )
- def load_model(hidden_size: int) -> torch.nn.Module:
- return torch.nn.Linear(hidden_size, 2, device="cuda")
+ def load_linear_model(input_dim: int = 16, output_dim: int = 4) -> torch.nn.Module:
+ ...
  ```

- Recommended combination:
-
- - `execution="task"`
- - `managed=True`
- - `concurrency="serialized"`
- - `singleflight=True`
- - explicit `cache_format_key`
-
- Why:
-
- - task mode gives you both blocking sync calls and true async submission
- - managed handles let the client release the remote object explicitly
- - serialized execution avoids multiple concurrent heavy loads on one GPU
- - singleflight prevents duplicate in-flight construction for the same model key
-
- ### 3. Reusable Shared Model Service
+ What happens when two clients ask for the same model key:

- Use this when the model should be built once and reused by many client calls.
+ ```text
+ Client A Server Client B
+ ------------------------ ------------------------ ------------------------
+ call("load_model", k) -----> cache miss call("load_model", k)
+ build object once -------------> same key in flight
+ object_id = obj-123 wait on same future
+ <------------------------- return handle(obj-123) <------------- return handle(obj-123)

- ```python
- @provider.share(
- execution="task",
- managed=True,
- cache_format_key="model:{model_name}:{dtype}",
- )
- def load_model(model_name: str, dtype: str) -> torch.nn.Module:
- ...
+ release(obj-123) ----------> refcount 2 -> 1
+ release(obj-123) ------------------------------------------> refcount 1 -> 0 -> destroy
  ```

- Recommended combination:
+ Use this for big reusable models. The important mix is:

- - `cache=True`
+ - `execution="task"`
  - `managed=True`
+ - `concurrency="serialized"`
  - `singleflight=True`
- - explicit stable cache key
-
- Why:
+ - explicit `cache_format_key`

- - caching makes the endpoint act like a model registry
- - managed handles keep explicit lifecycle control
- - stable cache keys prevent accidental duplication from argument shape changes
+ `managed=True` gives explicit lifecycle control. `cache_format_key` turns the endpoint into a model registry. `singleflight=True` ensures duplicate in-flight loads collapse to one build.

- ### 4. Fire-and-Poll Background Warmup
+ ## Example 4: Fire-And-Poll Warmup

- Use this when the caller should not block, for example prewarming a model or allocating a large reusable tensor in the background.
+ This is the same task-backed endpoint style, but the caller chooses async use:

  ```python
  task_id = load_model.submit(hidden_size=8192)
  model_handle = provider.wait_for_task(task_id)
  ```

- Recommended combination:
-
- - endpoint uses `execution="task"`
- - caller uses `.submit(...)`
- - optionally add `managed=True` for long-lived objects
-
- Why:
+ Runtime shape:

- - the endpoint stays declarative
- - the caller decides whether to block now or poll later
+ ```text
+ submit now -> task queue -> slow build on server -> poll later -> consume handle/result
+ ```

- ### 5. Strictly Non-Reusable Per-Request Work
+ Use this when the build is slow enough that the caller should not block immediately.

- Use this when every request must create a fresh result and reuse is wrong.
+ ## Example 5: Serialized Fragile Path

  ```python
- @provider.share(execution="task", cache=False, singleflight=False)
- def build_request_tensor(template: torch.Tensor) -> torch.Tensor:
- return template.clone()
+ @provider.share(execution="task", concurrency="serialized", cache=False, singleflight=False)
+ def compact_memory(tensor: torch.Tensor) -> torch.Tensor:
+ ...
  ```

- Recommended combination:
-
- - `cache=False`
- - `singleflight=False`
- - choose `execution="direct"` or `execution="task"` based on runtime cost
+ Execution model:

- Why:
+ ```text
+ request A -> lock -> run -> unlock
+ request B -> wait -> lock -> run -> unlock
+ ```

- - disabling cache avoids cross-request reuse
- - disabling singleflight ensures independent requests stay independent
+ Use this for GPU-heavy paths that must not overlap with themselves.

- ### 6. Endpoint That Must Run One At A Time
+ ## Endpoint Semantics

- Use this when the endpoint mutates shared state, temporarily spikes memory, or must not overlap with itself.
+ Each endpoint is registered once and then supports two client-side call styles.

- ```python
- @provider.share(execution="task", concurrency="serialized", cache=False, singleflight=False)
- def compact_memory(tensor: torch.Tensor) -> torch.Tensor:
- ...
- ```
+ - `fn(...)` or `provider.call(name, ...)`
+ - synchronous
+ - blocks until the result is ready
+ - `fn.submit(...)` or `provider.submit(name, ...)`
+ - asynchronous
+ - returns a task id

- Recommended combination:
+ Endpoint options:

+ - `execution="direct"`
+ - sync calls run the function directly on the server
+ - use this for fast tensor transforms
  - `execution="task"`
+ - sync calls still block, but they block on the task system
+ - use this for slow construction, warmup, and reusable model loading
+ - `concurrency="parallel"`
+ - multiple server executions may run at once
  - `concurrency="serialized"`
- - usually `cache=False`
-
- Why:
+ - only one execution of that endpoint runs at a time
+ - `singleflight=True`
+ - identical in-flight cache keys collapse to one execution
+ - this is the recommended model-loading default

- - serialization is endpoint-wide, not just per cache key
- - useful for fragile GPU-heavy paths where overlap is unsafe or wasteful
+ ## Scenario Map
+
+ - Fast tensor transform:
+ use `execution="direct"`, `cache=False`, `managed=False`
+ - Slow model construction:
+ use `execution="task"`, `managed=True`, `concurrency="serialized"`
+ - Reusable model registry:
+ add stable `cache_format_key` and keep `singleflight=True`
+ - Background warmup:
+ keep endpoint as task-backed and use `.submit(...)`
+ - Fragile non-overlapping GPU path:
+ use `concurrency="serialized"`
+ - Fresh per-request work:
+ disable cache and usually disable singleflight

  ## Parameter Guide

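The README above fixes the transport to `127.0.0.1` and routes the port as `SHARED_TENSOR_BASE_PORT + current_cuda_device_index`. As a minimal illustration of that rule only, under stated assumptions: the helper name below is hypothetical and not a shared-tensor API, and the default base port 2537 is taken from the provider signature later in this diff.

```python
# Illustration of the port-routing rule stated in the README diff above.
# resolve_port is a hypothetical helper, not part of the shared-tensor package.
import os

import torch


def resolve_port(base_port: int = 2537) -> int:
    # SHARED_TENSOR_BASE_PORT overrides the default base port when set.
    base = int(os.getenv("SHARED_TENSOR_BASE_PORT", str(base_port)))
    # Each CUDA device gets its own localhost port: base + device index.
    device_index = torch.cuda.current_device() if torch.cuda.is_available() else 0
    return base + device_index


# Example: with SHARED_TENSOR_BASE_PORT=2537 and cuda:1 active, the daemon port is 2538.
```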
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "shared-tensor"
- version = "0.2.2"
+ version = "0.2.4"
  description = "Local endpoint-oriented RPC for same-host same-GPU PyTorch IPC"
  readme = "README.md"
  license = "Apache-2.0"
@@ -42,6 +42,7 @@ classifiers = [
  ]
  requires-python = ">=3.10"
  dependencies = [
+ "cloudpickle>=3.0.0",
  "numpy<2",
  "requests>=2.25.0",
  "torch>=2.2.0",
@@ -19,4 +19,4 @@ __all__ = [
  "TaskStatus",
  ]

- __version__ = "0.2.2"
+ __version__ = "0.2.4"
@@ -15,6 +15,7 @@ class AsyncSharedTensorProvider(SharedTensorProvider):
  base_port: int = 2537,
  poll_interval: float = 1.0,
  *,
+ enabled: bool | None = None,
  server_host: str = "127.0.0.1",
  device_index: int | None = None,
  timeout: float = 30.0,
@@ -23,6 +24,7 @@ class AsyncSharedTensorProvider(SharedTensorProvider):
  ) -> None:
  super().__init__(
  base_port=base_port,
+ enabled=enabled,
  server_host=server_host,
  device_index=device_index,
  timeout=timeout,
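A tiny usage sketch for the keyword added above; the import path is an assumption, not shown in this diff:

```python
# Usage sketch only: AsyncSharedTensorProvider now forwards `enabled` to the
# SharedTensorProvider base class, so the same opt-in rules apply to it.
from shared_tensor import AsyncSharedTensorProvider  # assumed import path

provider = AsyncSharedTensorProvider(enabled=True)   # force shared-tensor behavior
provider = AsyncSharedTensorProvider(enabled=None)   # defer to SHARED_TENSOR_ENABLED
```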
@@ -21,6 +21,7 @@ from shared_tensor.utils import (

  EndpointExecution = Literal["direct", "task"]
  EndpointConcurrency = Literal["parallel", "serialized"]
+ SHARED_TENSOR_ENABLED_ENV = "SHARED_TENSOR_ENABLED"


  @dataclass(slots=True)
@@ -36,8 +37,14 @@ class EndpointDefinition:
  singleflight: bool = True


- def _resolve_execution_mode(execution_mode: str) -> tuple[str, bool]:
+ def _resolve_execution_mode(
+ execution_mode: str,
+ *,
+ enabled: bool | None = None,
+ ) -> tuple[str, bool]:
  if execution_mode == "auto":
+ if not _is_shared_tensor_enabled(enabled):
+ return "local", True
  env_role = os.getenv("SHARED_TENSOR_ROLE", "").strip().lower()
  if env_role in {"server", "client", "local"}:
  return env_role, True
@@ -49,6 +56,13 @@ def _resolve_execution_mode(execution_mode: str) -> tuple[str, bool]:
  return execution_mode, False


+ def _is_shared_tensor_enabled(enabled: bool | None) -> bool:
+ if enabled is not None:
+ return enabled
+ raw = os.getenv(SHARED_TENSOR_ENABLED_ENV, "").strip().lower()
+ return raw in {"1", "true", "yes", "on"}
+
+
  def _validate_endpoint_options(
  *,
  execution: EndpointExecution,
@@ -71,15 +85,20 @@ class SharedTensorProvider:
  self,
  base_port: int = 2537,
  *,
+ enabled: bool | None = None,
  server_host: str = "127.0.0.1",
  device_index: int | None = None,
  timeout: float = 30.0,
  execution_mode: str = "auto",
  verbose_debug: bool = False,
  ) -> None:
- resolved_mode, auto_mode = _resolve_execution_mode(execution_mode)
+ resolved_mode, auto_mode = _resolve_execution_mode(
+ execution_mode,
+ enabled=enabled,
+ )
  self.server_host = server_host
  self.base_port = resolve_server_base_port(base_port)
+ self.enabled = enabled
  self.device_index = device_index
  self.timeout = timeout
  self.execution_mode = resolved_mode
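To make the gating concrete, here is a standalone sketch of how `execution_mode="auto"` resolves under the change above. It re-states the logic for illustration only; it is not the package's code, and the "no role set means client" default is taken from the README portion of this diff rather than from lines visible in this hunk.

```python
# Standalone sketch of the auto-mode resolution added above (illustration only).
import os


def sketch_resolve_auto(enabled: bool | None) -> str:
    # Gate: an explicit constructor argument wins; None falls back to the env var.
    if enabled is None:
        raw = os.getenv("SHARED_TENSOR_ENABLED", "").strip().lower()
        enabled = raw in {"1", "true", "yes", "on"}
    if not enabled:
        return "local"  # disabled providers stay in safe local mode
    role = os.getenv("SHARED_TENSOR_ROLE", "").strip().lower()
    if role in {"server", "client", "local"}:
        return role
    return "client"  # assumed default when no role is set, per the README above


os.environ.pop("SHARED_TENSOR_ROLE", None)
os.environ.pop("SHARED_TENSOR_ENABLED", None)
assert sketch_resolve_auto(None) == "local"   # nothing opted in -> local
assert sketch_resolve_auto(False) == "local"  # explicit opt-out always wins
assert sketch_resolve_auto(True) == "client"  # opted in, no role -> client wrapper
os.environ["SHARED_TENSOR_ROLE"] = "server"
assert sketch_resolve_auto(True) == "server"  # opted in, server role -> auto server
```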
@@ -3,6 +3,7 @@
  from __future__ import annotations

  import logging
+ import cloudpickle
  import multiprocessing as mp
  import os
  import threading
@@ -10,7 +11,6 @@ import time
  from concurrent.futures import Future
  from dataclasses import dataclass
  from http.server import BaseHTTPRequestHandler, HTTPServer
- from multiprocessing.process import BaseProcess
  from socketserver import ThreadingMixIn
  from typing import Any

@@ -37,7 +37,7 @@ from shared_tensor.utils import (
  capability_snapshot,
  deserialize_payload,
  serialize_payload,
- validate_payload_for_transport,
+ validate_call_payload_for_transport,
  )

  logger = logging.getLogger(__name__)
@@ -118,7 +118,7 @@ class SharedTensorServer:
  self.max_workers = max_workers
  self.result_ttl = result_ttl
  self.server: ThreadedHTTPServer | None = None
- self.server_process: BaseProcess | None = None
+ self.server_process: Any | None = None
  self.running = False
  self.started_at: float | None = None
  self.stats = {
@@ -231,6 +231,12 @@ class SharedTensorServer:
  if inflight_key is not None:
  future, owner = self._acquire_inflight(inflight_key)
  if not owner:
+ if definition.managed:
+ payload = future.result()
+ object_id = payload.get("object_id")
+ if object_id is not None:
+ self._managed_objects.add_ref(object_id)
+ return payload
  return future.result()
  else:
  future = None
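The hunk above closes a lifecycle gap: when singleflight collapses concurrent loads of a managed object, each non-owner waiter now adds a reference to the shared `object_id` before returning, so every client's later release decrements its own count. A toy counter (not the package's `_managed_objects` implementation) sketches the intended arithmetic, matching the refcount diagram in the README portion of this diff.

```python
# Toy stand-in for a managed-object refcount registry (illustration only).
from dataclasses import dataclass, field


@dataclass
class ToyManagedRegistry:
    refs: dict[str, int] = field(default_factory=dict)

    def register(self, object_id: str) -> None:
        self.refs[object_id] = 1      # the owner that built the object holds one ref

    def add_ref(self, object_id: str) -> None:
        self.refs[object_id] += 1     # every singleflight waiter adds one more

    def release(self, object_id: str) -> bool:
        self.refs[object_id] -= 1
        if self.refs[object_id] == 0:
            del self.refs[object_id]  # last release destroys the remote object
            return True
        return False


registry = ToyManagedRegistry()
registry.register("obj-123")                   # Client A's call builds the object
registry.add_ref("obj-123")                    # Client B waited on the same future
assert registry.release("obj-123") is False    # first release: refcount 2 -> 1
assert registry.release("obj-123") is True     # second release: 1 -> 0 -> destroy
```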
@@ -401,8 +407,8 @@ class SharedTensorServer:
  "Control encoding is reserved for empty args/kwargs only"
  )
  return endpoint, args, kwargs
- validate_payload_for_transport(args)
- validate_payload_for_transport(kwargs, allow_dict_keys=True)
+ validate_call_payload_for_transport(args)
+ validate_call_payload_for_transport(kwargs, allow_dict_keys=True)
  return endpoint, args, kwargs

  def _encode_result(self, value: Any, *, object_id: str | None = None) -> dict[str, str | None]:
@@ -437,7 +443,7 @@ class SharedTensorServer:
  uptime = 0.0 if self.started_at is None else time.time() - self.started_at
  return {
  "server": "SharedTensorServer",
- "version": "0.2.2",
+ "version": "0.2.4",
  "host": self.host,
  "port": self.port,
  "uptime": uptime,
@@ -454,16 +460,52 @@ class SharedTensorServer:
  return
  if os.name != "posix":
  raise SharedTensorConfigurationError(
- "Non-blocking shared_tensor servers require POSIX fork semantics"
+ "Non-blocking shared_tensor servers require POSIX multiprocessing support"
  )
- ctx = mp.get_context("fork")
- process = ctx.Process(target=self._serve_forever, name=f"shared-tensor-daemon:{self.port}")
+ payload = cloudpickle.dumps(self.provider)
+ process = mp.get_context("spawn").Process(
+ target=self._serve_forever_from_payload,
+ args=(
+ payload,
+ self.host,
+ self.port,
+ self.max_request_bytes,
+ self.max_workers,
+ self.result_ttl,
+ self.verbose_debug,
+ ),
+ name=f"shared-tensor-daemon:{self.port}",
+ )
  process.start()
  self.server_process = process
  self.running = True
  self.started_at = time.time()

+ @staticmethod
+ def _serve_forever_from_payload(
+ payload: bytes,
+ host: str,
+ port: int,
+ max_request_bytes: int,
+ max_workers: int,
+ result_ttl: float,
+ verbose_debug: bool,
+ ) -> None:
+ SharedTensorServer._configure_cuda_runtime()
+ provider = cloudpickle.loads(payload)
+ server = SharedTensorServer(
+ provider,
+ host=host,
+ port=port,
+ max_request_bytes=max_request_bytes,
+ max_workers=max_workers,
+ result_ttl=result_ttl,
+ verbose_debug=verbose_debug,
+ )
+ server._serve_forever()
+
  def _serve_forever(self) -> None:
+ self._configure_cuda_runtime()
  self.server = ThreadedHTTPServer((self.host, self.port), SharedTensorRequestHandler)
  self.server.shared_tensor_server = self # type: ignore[attr-defined]
  self.running = True
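Replacing the fork-based daemon with a spawn context plus cloudpickle is the standard way to hand a non-importable object (such as a provider built in `__main__`) to a freshly spawned process: serialize it to bytes in the parent, rebuild it in the child. A self-contained sketch of that pattern, unrelated to shared-tensor's classes:

```python
# Generic spawn + cloudpickle hand-off sketch (names here are illustrative only).
import multiprocessing as mp

import cloudpickle


def child_main(payload: bytes) -> None:
    fn = cloudpickle.loads(payload)  # rebuild the callable inside the spawned child
    fn()


def main() -> None:
    message = "hello from the spawned child"

    def greet() -> None:             # nested function: plain pickle cannot ship this
        print(message)

    payload = cloudpickle.dumps(greet)
    process = mp.get_context("spawn").Process(target=child_main, args=(payload,))
    process.start()
    process.join()


if __name__ == "__main__":
    main()
```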
@@ -5,6 +5,7 @@ from __future__ import annotations
  import hashlib
  import inspect
  import io
+ import multiprocessing.reduction as mp_reduction
  import os
  import pickle
  from collections.abc import Callable
@@ -35,7 +36,17 @@ _EMPTY_DICT: dict[str, Any] = {}
  def _torch_forking_pickler() -> type | None:
  if TORCH_MODULE is None:
  return None
- return cast(type, TORCH_MODULE.multiprocessing.reductions.ForkingPickler)
+ reductions = TORCH_MODULE.multiprocessing.reductions
+ init_reductions = getattr(reductions, "init_reductions", None)
+ if callable(init_reductions):
+ try:
+ init_reductions()
+ except Exception:
+ return cast(type, mp_reduction.ForkingPickler)
+ pickler = getattr(reductions, "ForkingPickler", None)
+ if pickler is not None:
+ return cast(type, pickler)
+ return cast(type, mp_reduction.ForkingPickler)


  def _raise_unsupported_payload(message: str) -> None:
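The fallback above covers torch builds where `torch.multiprocessing.reductions` no longer re-exports its own `ForkingPickler`; calling `init_reductions()` registers the CUDA-aware reducers on the standard-library pickler instead, which is then used as the last resort. A short sketch of serializing a CUDA tensor through that path, assuming a CUDA device is available:

```python
# Sketch: CUDA-IPC-aware tensor serialization via ForkingPickler (assumes CUDA).
import io
import pickle
from multiprocessing.reduction import ForkingPickler

import torch
from torch.multiprocessing.reductions import init_reductions

init_reductions()  # register torch's CUDA/shared-memory reducers on ForkingPickler

tensor = torch.ones(4, device="cuda")
buffer = io.BytesIO()
ForkingPickler(buffer, pickle.HIGHEST_PROTOCOL).dump(tensor)
payload = buffer.getvalue()  # bytes carrying CUDA IPC handles, valid on the same host only
```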
@@ -91,11 +102,55 @@ def _validate_torch_payload(obj: Any, *, allow_dict_keys: bool = False) -> None:
  _raise_unsupported_payload(f"Unsupported payload type: {type(obj).__name__}")


+ def _validate_call_payload(obj: Any, *, allow_dict_keys: bool = False) -> None:
+ if TORCH_MODULE is None:
+ raise SharedTensorCapabilityError("PyTorch is required for shared_tensor")
+
+ if isinstance(obj, (str, int, float, bool, type(None), bytes)):
+ return
+
+ if isinstance(obj, TORCH_MODULE.Tensor):
+ if not obj.is_cuda:
+ _raise_unsupported_payload("CPU torch.Tensor payloads are not supported")
+ return
+
+ if isinstance(obj, TORCH_MODULE.nn.Module):
+ _validate_module_device(obj)
+ return
+
+ if isinstance(obj, tuple):
+ for item in obj:
+ _validate_call_payload(item)
+ return
+
+ if isinstance(obj, list):
+ for item in obj:
+ _validate_call_payload(item)
+ return
+
+ if isinstance(obj, dict):
+ for key, value in obj.items():
+ if allow_dict_keys:
+ if not isinstance(key, str):
+ _raise_unsupported_payload("Dictionary payload keys must be strings")
+ else:
+ _validate_call_payload(key)
+ _validate_call_payload(value)
+ return
+
+ _raise_unsupported_payload(f"Unsupported payload type: {type(obj).__name__}")
+
+
  def validate_payload_for_transport(obj: Any, *, allow_dict_keys: bool = False) -> None:
  """Validate that a payload fits the supported CUDA torch transport contract."""
  _validate_torch_payload(obj, allow_dict_keys=allow_dict_keys)


+ def validate_call_payload_for_transport(obj: Any, *, allow_dict_keys: bool = False) -> None:
+ """Validate RPC call args/kwargs, allowing scalar controls alongside CUDA payloads."""
+ _validate_call_payload(obj, allow_dict_keys=allow_dict_keys)
+
+
  def _torch_serialize(obj: Any) -> bytes:
  pickler_cls = _torch_forking_pickler()
  if pickler_cls is None:
@@ -140,8 +195,8 @@ def serialize_call_payloads(
  return CONTROL_ENCODING, args_payload, serialize_empty_payload(_EMPTY_DICT)[1]

  try:
- _validate_torch_payload(args)
- _validate_torch_payload(kwargs, allow_dict_keys=True)
+ _validate_call_payload(args)
+ _validate_call_payload(kwargs, allow_dict_keys=True)
  return TORCH_ENCODING, pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL), pickle.dumps(
  kwargs, protocol=pickle.HIGHEST_PROTOCOL
  )
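The new `validate_call_payload_for_transport` relaxes the original validator for call arguments: scalars, `bytes`, and `None` may ride alongside CUDA tensors, while CPU tensors are still rejected. A short sketch of what the rules above accept and reject, assuming a CUDA-enabled torch install; the import path follows the `shared_tensor.utils` import shown in the server hunk earlier in this diff:

```python
# Sketch of the call-payload rules added above (assumes CUDA-enabled torch).
import torch
from shared_tensor.utils import validate_call_payload_for_transport

cuda_tensor = torch.ones(2, device="cuda")

# Accepted: scalars, bytes, and CUDA tensors, nested in tuples/lists/dicts.
validate_call_payload_for_transport((cuda_tensor, 2.0, "mode", b"raw"))
validate_call_payload_for_transport({"tensor": cuda_tensor, "steps": 3}, allow_dict_keys=True)

# Rejected: CPU tensors raise the unsupported-payload error.
try:
    validate_call_payload_for_transport(torch.ones(2))  # CPU tensor
except Exception as error:
    print(f"rejected as expected: {error}")
```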