shared-tensor 0.2.8__tar.gz → 0.2.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/PKG-INFO +87 -5
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/README.md +83 -2
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/pyproject.toml +4 -3
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/__init__.py +3 -3
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/async_client.py +33 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/client.py +62 -2
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/errors.py +18 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/managed_object.py +50 -1
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/provider.py +34 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/server.py +117 -10
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor.egg-info/SOURCES.txt +0 -1
- shared_tensor-0.2.8/shared_tensor/async_provider.py +0 -97
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/LICENSE +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/MANIFEST.in +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/setup.cfg +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/async_task.py +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/runtime.py +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/transport.py +0 -0
- {shared_tensor-0.2.8 → shared_tensor-0.2.10}/shared_tensor/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: shared-tensor
|
|
3
|
-
Version: 0.2.8
|
|
3
|
+
Version: 0.2.10
|
|
4
4
|
Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
|
|
5
5
|
Author-email: Athena Team <contact@world-sim-dev.org>
|
|
6
6
|
Maintainer-email: Athena Team <contact@world-sim-dev.org>
|
|
@@ -21,15 +21,16 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
22
|
Classifier: Programming Language :: Python :: 3.12
|
|
23
23
|
Classifier: Programming Language :: Python :: 3.13
|
|
24
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
24
25
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
25
26
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
26
27
|
Classifier: Topic :: System :: Distributed Computing
|
|
27
|
-
Requires-Python:
|
|
28
|
+
Requires-Python: >=3.9
|
|
28
29
|
Description-Content-Type: text/markdown
|
|
29
30
|
License-File: LICENSE
|
|
30
31
|
Requires-Dist: cloudpickle>=3.0.0
|
|
31
32
|
Requires-Dist: numpy<2
|
|
32
|
-
Requires-Dist: torch
|
|
33
|
+
Requires-Dist: torch>=2.1
|
|
33
34
|
Provides-Extra: dev
|
|
34
35
|
Requires-Dist: pytest>=6.0; extra == "dev"
|
|
35
36
|
Requires-Dist: pytest-cov>=2.0; extra == "dev"
|
|
@@ -64,7 +65,7 @@ Supported:
|
|
|
64
65
|
- explicit endpoint registration
|
|
65
66
|
- sync `call` and task-backed `submit`
|
|
66
67
|
- managed object handles with explicit release
|
|
67
|
-
- server-side caching, `cache_format_key`, and
|
|
68
|
+
- server-side caching, `cache_format_key`, singleflight, and explicit cache invalidation
|
|
68
69
|
- manual two-process deployment as the primary production path
|
|
69
70
|
- zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
|
|
70
71
|
|
|
@@ -107,6 +108,8 @@ Production should prefer two explicitly started processes: one server process th
|
|
|
107
108
|
|
|
108
109
|
See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
|
|
109
110
|
|
|
111
|
+
The server-oriented example modules construct providers with explicit `execution_mode="server"` so importing the module already reflects the intended deployment role.
|
|
112
|
+
|
|
110
113
|
Server process:
|
|
111
114
|
|
|
112
115
|
```python
|
|
@@ -147,10 +150,32 @@ executes endpoint functions reopens CUDA objects via torch IPC
|
|
|
147
150
|
manages cache and refcounts releases managed handles explicitly
|
|
148
151
|
```
|
|
149
152
|
|
|
153
|
+
## Lifetime And Failure Contract
|
|
154
|
+
|
|
155
|
+
`shared_tensor` follows native PyTorch CUDA IPC semantics. It does not virtualize or harden producer lifetime.
|
|
156
|
+
|
|
157
|
+
Core assumption:
|
|
158
|
+
- the server process that owns the original CUDA allocation must stay alive while clients are still using reopened CUDA tensors or modules
|
|
159
|
+
- handle health checks can detect some stale-object conditions, but they do not remove the producer-liveness requirement
|
|
160
|
+
|
|
161
|
+
If the server exits, crashes, or is killed before the client is done with the shared CUDA object, behavior is no longer guaranteed by this library. Depending on PyTorch and CUDA runtime state, the client may see CUDA runtime errors, invalid resource handle failures, broken module execution, or process-level instability.
|
|
162
|
+
|
|
163
|
+
So the production contract is:
|
|
164
|
+
- client-side handles are only valid while the producer process remains alive
|
|
165
|
+
- `handle.release()` is explicit lifecycle cleanup, not durability
|
|
166
|
+
- this library does not promise survivability across producer death
|
|
167
|
+
|
|
168
|
+
Treat producer liveness as a hard requirement, not a soft optimization.
|
|
169
|
+
|
|
150
170
|
## Example: Same Code, Two Processes
|
|
151
171
|
|
|
152
172
|
See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
|
|
153
173
|
|
|
174
|
+
Resolution rule:
|
|
175
|
+
- `SHARED_TENSOR_ENABLED` unset or false: provider stays local
|
|
176
|
+
- `SHARED_TENSOR_ENABLED=1` and `SHARED_TENSOR_ROLE=server`: provider resolves to server and auto-starts the thread-backed local server
|
|
177
|
+
- `SHARED_TENSOR_ENABLED=1` and role unset or `client`: provider resolves to client
|
|
178
|
+
|
|
154
179
|
```bash
|
|
155
180
|
SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
|
|
156
181
|
SHARED_TENSOR_ENABLED=1 python demo.py
|
|
@@ -168,6 +193,27 @@ shared function runs locally shared function becomes RPC call
|
|
|
168
193
|
CUDA object stays on same GPU CUDA object is reopened via torch IPC
|
|
169
194
|
```
|
|
170
195
|
|
|
196
|
+
## Example: Task Submission And Wait
|
|
197
|
+
|
|
198
|
+
See [examples/async_service.py](./examples/async_service.py).
|
|
199
|
+
|
|
200
|
+
```python
|
|
201
|
+
from shared_tensor import AsyncSharedTensorClient, SharedTensorProvider
|
|
202
|
+
|
|
203
|
+
provider = SharedTensorProvider(execution_mode="server")
|
|
204
|
+
|
|
205
|
+
@provider.share(execution="task")
|
|
206
|
+
def build_delayed_model(delay: float = 0.1):
|
|
207
|
+
...
|
|
208
|
+
|
|
209
|
+
client = AsyncSharedTensorClient()
|
|
210
|
+
task_id = client.submit("build_delayed_model", delay=0.1)
|
|
211
|
+
model = client.wait_for_task(task_id, timeout=30)
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
Use `@provider.share(execution="task")` to register task-backed endpoints.
|
|
215
|
+
Use `AsyncSharedTensorClient` when you want a task-oriented waiting interface.
|
|
216
|
+
|
|
171
217
|
## Example: Reusable Model Registry
|
|
172
218
|
|
|
173
219
|
See [examples/model_service.py](./examples/model_service.py).
|
|
@@ -278,11 +324,47 @@ handle.release()
|
|
|
278
324
|
```
|
|
279
325
|
|
|
280
326
|
Use managed mode for cached models or other reusable long-lived CUDA objects.
|
|
327
|
+
Managed object introspection now includes `created_at` and `last_accessed_at` timestamps through `get_object_info()`.
|
|
328
|
+
|
|
329
|
+
## Cache Invalidation
|
|
330
|
+
|
|
331
|
+
The library now exposes explicit cache invalidation instead of forcing process restarts when a cached object becomes stale.
|
|
332
|
+
|
|
333
|
+
```python
|
|
334
|
+
provider.invalidate_call_cache("load_model", hidden_size=4096)
|
|
335
|
+
provider.invalidate_endpoint_cache("load_model")
|
|
336
|
+
```
|
|
337
|
+
|
|
338
|
+
Client-side equivalents are also available:
|
|
339
|
+
|
|
340
|
+
```python
|
|
341
|
+
client.invalidate_call_cache("load_model", hidden_size=4096)
|
|
342
|
+
client.invalidate_endpoint_cache("load_model")
|
|
343
|
+
```
|
|
344
|
+
|
|
345
|
+
Use call-level invalidation when you want to evict one cache key.
|
|
346
|
+
Use endpoint-level invalidation when you want to drop all cached variants for the endpoint.
|
|
347
|
+
Invalidation removes cache lookup entries; it does not guarantee that already-issued client handles remain valid after producer death.
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
## Handle Health Checks
|
|
351
|
+
|
|
352
|
+
Managed handles now carry the producer `server_id` and support lightweight liveness probes:
|
|
353
|
+
|
|
354
|
+
```python
|
|
355
|
+
handle = client.call("load_model", hidden_size=4096)
|
|
356
|
+
info = handle.get_object_info()
|
|
357
|
+
client.ensure_handle_live(handle)
|
|
358
|
+
```
|
|
359
|
+
|
|
360
|
+
If the producer no longer owns the object, `client.ensure_handle_live(handle)` raises `SharedTensorStaleHandleError`.
|
|
361
|
+
This is still advisory, not a durability guarantee: it helps detect stale handles earlier, but it cannot make producer death safe.
|
|
281
362
|
|
|
282
363
|
## Runtime Introspection
|
|
283
364
|
|
|
284
|
-
`client.get_server_info()` now returns readiness and process metadata in addition to endpoint and capability data.
|
|
365
|
+
`client.get_server_info()` now returns readiness, stable `server_id`, cache/task counters, and process metadata in addition to endpoint and capability data.
|
|
285
366
|
In client mode, `provider.get_runtime_info()` wraps that into a provider-oriented view.
|
|
367
|
+
`AsyncSharedTensorClient` exposes the same runtime, cache invalidation, release, and handle-health helper methods as `SharedTensorClient`; the async surface is task-oriented, not capability-reduced.
|
|
286
368
|
|
|
287
369
|
```python
|
|
288
370
|
info = provider.get_runtime_info()
|
|
@@ -12,7 +12,7 @@ Supported:
|
|
|
12
12
|
- explicit endpoint registration
|
|
13
13
|
- sync `call` and task-backed `submit`
|
|
14
14
|
- managed object handles with explicit release
|
|
15
|
-
- server-side caching, `cache_format_key`, and
|
|
15
|
+
- server-side caching, `cache_format_key`, singleflight, and explicit cache invalidation
|
|
16
16
|
- manual two-process deployment as the primary production path
|
|
17
17
|
- zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
|
|
18
18
|
|
|
@@ -55,6 +55,8 @@ Production should prefer two explicitly started processes: one server process th
|
|
|
55
55
|
|
|
56
56
|
See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
|
|
57
57
|
|
|
58
|
+
The server-oriented example modules construct providers with explicit `execution_mode="server"` so importing the module already reflects the intended deployment role.
|
|
59
|
+
|
|
58
60
|
Server process:
|
|
59
61
|
|
|
60
62
|
```python
|
|
@@ -95,10 +97,32 @@ executes endpoint functions reopens CUDA objects via torch IPC
|
|
|
95
97
|
manages cache and refcounts releases managed handles explicitly
|
|
96
98
|
```
|
|
97
99
|
|
|
100
|
+
## Lifetime And Failure Contract
|
|
101
|
+
|
|
102
|
+
`shared_tensor` follows native PyTorch CUDA IPC semantics. It does not virtualize or harden producer lifetime.
|
|
103
|
+
|
|
104
|
+
Core assumption:
|
|
105
|
+
- the server process that owns the original CUDA allocation must stay alive while clients are still using reopened CUDA tensors or modules
|
|
106
|
+
- handle health checks can detect some stale-object conditions, but they do not remove the producer-liveness requirement
|
|
107
|
+
|
|
108
|
+
If the server exits, crashes, or is killed before the client is done with the shared CUDA object, behavior is no longer guaranteed by this library. Depending on PyTorch and CUDA runtime state, the client may see CUDA runtime errors, invalid resource handle failures, broken module execution, or process-level instability.
|
|
109
|
+
|
|
110
|
+
So the production contract is:
|
|
111
|
+
- client-side handles are only valid while the producer process remains alive
|
|
112
|
+
- `handle.release()` is explicit lifecycle cleanup, not durability
|
|
113
|
+
- this library does not promise survivability across producer death
|
|
114
|
+
|
|
115
|
+
Treat producer liveness as a hard requirement, not a soft optimization.
|
|
116
|
+
|
|
98
117
|
## Example: Same Code, Two Processes
|
|
99
118
|
|
|
100
119
|
See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
|
|
101
120
|
|
|
121
|
+
Resolution rule:
|
|
122
|
+
- `SHARED_TENSOR_ENABLED` unset or false: provider stays local
|
|
123
|
+
- `SHARED_TENSOR_ENABLED=1` and `SHARED_TENSOR_ROLE=server`: provider resolves to server and auto-starts the thread-backed local server
|
|
124
|
+
- `SHARED_TENSOR_ENABLED=1` and role unset or `client`: provider resolves to client
|
|
125
|
+
|
|
102
126
|
```bash
|
|
103
127
|
SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
|
|
104
128
|
SHARED_TENSOR_ENABLED=1 python demo.py
|
|
@@ -116,6 +140,27 @@ shared function runs locally shared function becomes RPC call
|
|
|
116
140
|
CUDA object stays on same GPU CUDA object is reopened via torch IPC
|
|
117
141
|
```
|
|
118
142
|
|
|
143
|
+
## Example: Task Submission And Wait
|
|
144
|
+
|
|
145
|
+
See [examples/async_service.py](./examples/async_service.py).
|
|
146
|
+
|
|
147
|
+
```python
|
|
148
|
+
from shared_tensor import AsyncSharedTensorClient, SharedTensorProvider
|
|
149
|
+
|
|
150
|
+
provider = SharedTensorProvider(execution_mode="server")
|
|
151
|
+
|
|
152
|
+
@provider.share(execution="task")
|
|
153
|
+
def build_delayed_model(delay: float = 0.1):
|
|
154
|
+
...
|
|
155
|
+
|
|
156
|
+
client = AsyncSharedTensorClient()
|
|
157
|
+
task_id = client.submit("build_delayed_model", delay=0.1)
|
|
158
|
+
model = client.wait_for_task(task_id, timeout=30)
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
Use `@provider.share(execution="task")` to register task-backed endpoints.
|
|
162
|
+
Use `AsyncSharedTensorClient` when you want a task-oriented waiting interface.
|
|
163
|
+
|
|
119
164
|
## Example: Reusable Model Registry
|
|
120
165
|
|
|
121
166
|
See [examples/model_service.py](./examples/model_service.py).
|
|
@@ -226,11 +271,47 @@ handle.release()
|
|
|
226
271
|
```
|
|
227
272
|
|
|
228
273
|
Use managed mode for cached models or other reusable long-lived CUDA objects.
|
|
274
|
+
Managed object introspection now includes `created_at` and `last_accessed_at` timestamps through `get_object_info()`.
|
|
275
|
+
|
|
276
|
+
## Cache Invalidation
|
|
277
|
+
|
|
278
|
+
The library now exposes explicit cache invalidation instead of forcing process restarts when a cached object becomes stale.
|
|
279
|
+
|
|
280
|
+
```python
|
|
281
|
+
provider.invalidate_call_cache("load_model", hidden_size=4096)
|
|
282
|
+
provider.invalidate_endpoint_cache("load_model")
|
|
283
|
+
```
|
|
284
|
+
|
|
285
|
+
Client-side equivalents are also available:
|
|
286
|
+
|
|
287
|
+
```python
|
|
288
|
+
client.invalidate_call_cache("load_model", hidden_size=4096)
|
|
289
|
+
client.invalidate_endpoint_cache("load_model")
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
Use call-level invalidation when you want to evict one cache key.
|
|
293
|
+
Use endpoint-level invalidation when you want to drop all cached variants for the endpoint.
|
|
294
|
+
Invalidation removes cache lookup entries; it does not guarantee that already-issued client handles remain valid after producer death.
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
## Handle Health Checks
|
|
298
|
+
|
|
299
|
+
Managed handles now carry the producer `server_id` and support lightweight liveness probes:
|
|
300
|
+
|
|
301
|
+
```python
|
|
302
|
+
handle = client.call("load_model", hidden_size=4096)
|
|
303
|
+
info = handle.get_object_info()
|
|
304
|
+
client.ensure_handle_live(handle)
|
|
305
|
+
```
|
|
306
|
+
|
|
307
|
+
If the producer no longer owns the object, `client.ensure_handle_live(handle)` raises `SharedTensorStaleHandleError`.
|
|
308
|
+
This is still advisory, not a durability guarantee: it helps detect stale handles earlier, but it cannot make producer death safe.
|
|
229
309
|
|
|
230
310
|
## Runtime Introspection
|
|
231
311
|
|
|
232
|
-
`client.get_server_info()` now returns readiness and process metadata in addition to endpoint and capability data.
|
|
312
|
+
`client.get_server_info()` now returns readiness, stable `server_id`, cache/task counters, and process metadata in addition to endpoint and capability data.
|
|
233
313
|
In client mode, `provider.get_runtime_info()` wraps that into a provider-oriented view.
|
|
314
|
+
`AsyncSharedTensorClient` exposes the same runtime, cache invalidation, release, and handle-health helper methods as `SharedTensorClient`; the async surface is task-oriented, not capability-reduced.
|
|
234
315
|
|
|
235
316
|
```python
|
|
236
317
|
info = provider.get_runtime_info()
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "shared-tensor"
|
|
7
|
-
version = "0.2.
|
|
7
|
+
version = "0.2.10"
|
|
8
8
|
description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "Apache-2.0"
|
|
@@ -38,15 +38,16 @@ classifiers = [
|
|
|
38
38
|
"Programming Language :: Python :: 3.11",
|
|
39
39
|
"Programming Language :: Python :: 3.12",
|
|
40
40
|
"Programming Language :: Python :: 3.13",
|
|
41
|
+
"Programming Language :: Python :: 3.14",
|
|
41
42
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
42
43
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
43
44
|
"Topic :: System :: Distributed Computing",
|
|
44
45
|
]
|
|
45
|
-
requires-python = ">=3.9
|
|
46
|
+
requires-python = ">=3.9"
|
|
46
47
|
dependencies = [
|
|
47
48
|
"cloudpickle>=3.0.0",
|
|
48
49
|
"numpy<2",
|
|
49
|
-
"torch>=2.1
|
|
50
|
+
"torch>=2.1",
|
|
50
51
|
]
|
|
51
52
|
|
|
52
53
|
[project.optional-dependencies]
|
|
@@ -1,22 +1,22 @@
|
|
|
1
1
|
"""shared_tensor: same-host same-GPU PyTorch CUDA IPC over local UDS RPC."""
|
|
2
2
|
|
|
3
3
|
from shared_tensor.async_client import AsyncSharedTensorClient
|
|
4
|
-
from shared_tensor.async_provider import AsyncSharedTensorProvider
|
|
5
4
|
from shared_tensor.async_task import TaskInfo, TaskStatus
|
|
6
5
|
from shared_tensor.client import SharedTensorClient
|
|
6
|
+
from shared_tensor.errors import SharedTensorStaleHandleError
|
|
7
7
|
from shared_tensor.managed_object import SharedObjectHandle
|
|
8
8
|
from shared_tensor.provider import SharedTensorProvider
|
|
9
9
|
from shared_tensor.server import SharedTensorServer
|
|
10
10
|
|
|
11
11
|
__all__ = [
|
|
12
12
|
"AsyncSharedTensorClient",
|
|
13
|
-
"AsyncSharedTensorProvider",
|
|
14
13
|
"SharedTensorClient",
|
|
15
14
|
"SharedObjectHandle",
|
|
15
|
+
"SharedTensorStaleHandleError",
|
|
16
16
|
"SharedTensorProvider",
|
|
17
17
|
"SharedTensorServer",
|
|
18
18
|
"TaskInfo",
|
|
19
19
|
"TaskStatus",
|
|
20
20
|
]
|
|
21
21
|
|
|
22
|
-
__version__ = "0.2.
|
|
22
|
+
__version__ = "0.2.10"
|
|
@@ -10,6 +10,7 @@ from typing import Any, cast
|
|
|
10
10
|
from shared_tensor.async_task import TaskInfo, TaskStatus
|
|
11
11
|
from shared_tensor.client import SharedTensorClient
|
|
12
12
|
from shared_tensor.errors import SharedTensorRemoteError, SharedTensorTaskError
|
|
13
|
+
from shared_tensor.managed_object import SharedObjectHandle
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
logger = logging.getLogger(__name__)
|
|
@@ -51,6 +52,38 @@ class AsyncSharedTensorClient:
|
|
|
51
52
|
def get_task_result(self, task_id: str) -> Any:
|
|
52
53
|
return self.result(task_id)
|
|
53
54
|
|
|
55
|
+
def ping(self) -> bool:
|
|
56
|
+
return self._client.ping()
|
|
57
|
+
|
|
58
|
+
def get_server_info(self) -> dict[str, Any]:
|
|
59
|
+
return self._client.get_server_info()
|
|
60
|
+
|
|
61
|
+
def list_endpoints(self) -> dict[str, Any]:
|
|
62
|
+
return self._client.list_endpoints()
|
|
63
|
+
|
|
64
|
+
def release(self, object_id: str) -> bool:
|
|
65
|
+
return self._client.release(object_id)
|
|
66
|
+
|
|
67
|
+
def release_many(self, object_ids: list[str]) -> dict[str, bool]:
|
|
68
|
+
return self._client.release_many(object_ids)
|
|
69
|
+
|
|
70
|
+
def get_object_info(self, object_id: str) -> dict[str, Any] | None:
|
|
71
|
+
return self._client.get_object_info(object_id)
|
|
72
|
+
|
|
73
|
+
def ensure_handle_live(
|
|
74
|
+
self,
|
|
75
|
+
handle: SharedObjectHandle[Any],
|
|
76
|
+
*,
|
|
77
|
+
refresh: bool = True,
|
|
78
|
+
) -> dict[str, Any]:
|
|
79
|
+
return self._client.ensure_handle_live(handle, refresh=refresh)
|
|
80
|
+
|
|
81
|
+
def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
|
|
82
|
+
return self._client.invalidate_call_cache(endpoint, *args, **kwargs)
|
|
83
|
+
|
|
84
|
+
def invalidate_endpoint_cache(self, endpoint: str) -> int:
|
|
85
|
+
return self._client.invalidate_endpoint_cache(endpoint)
|
|
86
|
+
|
|
54
87
|
def wait(
|
|
55
88
|
self,
|
|
56
89
|
task_id: str,
|
|
@@ -7,6 +7,7 @@ import socket
|
|
|
7
7
|
from dataclasses import dataclass
|
|
8
8
|
from typing import Any, cast
|
|
9
9
|
|
|
10
|
+
from shared_tensor.async_task import TaskStatus
|
|
10
11
|
from shared_tensor.errors import (
|
|
11
12
|
SharedTensorCapabilityError,
|
|
12
13
|
SharedTensorClientError,
|
|
@@ -16,12 +17,12 @@ from shared_tensor.errors import (
|
|
|
16
17
|
SharedTensorProtocolError,
|
|
17
18
|
SharedTensorRemoteError,
|
|
18
19
|
SharedTensorSerializationError,
|
|
20
|
+
SharedTensorStaleHandleError,
|
|
19
21
|
SharedTensorTaskError,
|
|
20
22
|
)
|
|
21
23
|
from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
|
|
22
24
|
from shared_tensor.runtime import get_local_server
|
|
23
25
|
from shared_tensor.transport import recv_message, send_message
|
|
24
|
-
from shared_tensor.async_task import TaskStatus
|
|
25
26
|
from shared_tensor.utils import (
|
|
26
27
|
deserialize_payload,
|
|
27
28
|
resolve_runtime_socket_path,
|
|
@@ -29,7 +30,6 @@ from shared_tensor.utils import (
|
|
|
29
30
|
validate_payload_for_transport,
|
|
30
31
|
)
|
|
31
32
|
|
|
32
|
-
|
|
33
33
|
logger = logging.getLogger(__name__)
|
|
34
34
|
|
|
35
35
|
|
|
@@ -41,6 +41,9 @@ class _ClientReleaser(ReleaseHandle):
|
|
|
41
41
|
def release(self) -> bool:
|
|
42
42
|
return self.client.release(self.object_id)
|
|
43
43
|
|
|
44
|
+
def get_object_info(self) -> dict[str, Any] | None:
|
|
45
|
+
return self.client.get_object_info(self.object_id)
|
|
46
|
+
|
|
44
47
|
|
|
45
48
|
class SharedTensorClient:
|
|
46
49
|
"""UDS client for endpoint-oriented local RPC execution."""
|
|
@@ -76,6 +79,8 @@ class SharedTensorClient:
|
|
|
76
79
|
code = 5
|
|
77
80
|
elif isinstance(exc, SharedTensorConfigurationError):
|
|
78
81
|
code = 6
|
|
82
|
+
elif isinstance(exc, SharedTensorStaleHandleError):
|
|
83
|
+
code = 8
|
|
79
84
|
else:
|
|
80
85
|
code = 7
|
|
81
86
|
return SharedTensorRemoteError(
|
|
@@ -105,6 +110,7 @@ class SharedTensorClient:
|
|
|
105
110
|
object_id=cast(str, object_id),
|
|
106
111
|
value=value,
|
|
107
112
|
_releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
|
|
113
|
+
server_id=self._infer_server_id(),
|
|
108
114
|
)
|
|
109
115
|
|
|
110
116
|
def _send_request(self, request: dict[str, Any]) -> Any:
|
|
@@ -158,6 +164,15 @@ class SharedTensorClient:
|
|
|
158
164
|
def _request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
|
159
165
|
return self._send_request({"method": method, "params": params or {}})
|
|
160
166
|
|
|
167
|
+
def _infer_server_id(self) -> str | None:
|
|
168
|
+
local_server = self._local_server()
|
|
169
|
+
if local_server is not None:
|
|
170
|
+
return cast(str | None, getattr(local_server, "server_id", None))
|
|
171
|
+
try:
|
|
172
|
+
return cast(str | None, self.get_server_info().get("server_id"))
|
|
173
|
+
except (SharedTensorClientError, SharedTensorRemoteError, SharedTensorProtocolError):
|
|
174
|
+
return None
|
|
175
|
+
|
|
161
176
|
def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
|
|
162
177
|
if self.verbose_debug:
|
|
163
178
|
logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
|
|
@@ -245,6 +260,25 @@ class SharedTensorClient:
|
|
|
245
260
|
result = self._request("get_object_info", {"object_id": object_id})
|
|
246
261
|
return cast(dict[str, Any] | None, result.get("object"))
|
|
247
262
|
|
|
263
|
+
def ensure_handle_live(self, handle: SharedObjectHandle[Any], *, refresh: bool = True) -> dict[str, Any]:
|
|
264
|
+
info = handle.get_object_info(refresh=refresh)
|
|
265
|
+
if info is None:
|
|
266
|
+
raise SharedTensorStaleHandleError(
|
|
267
|
+
f"Managed object '{handle.object_id}' is no longer registered on the producer",
|
|
268
|
+
object_id=handle.object_id,
|
|
269
|
+
server_id=handle.server_id,
|
|
270
|
+
reason="object_missing",
|
|
271
|
+
)
|
|
272
|
+
observed_server_id = cast(str | None, info.get("server_id"))
|
|
273
|
+
if handle.server_id is not None and observed_server_id is not None and observed_server_id != handle.server_id:
|
|
274
|
+
raise SharedTensorStaleHandleError(
|
|
275
|
+
f"Managed object '{handle.object_id}' belongs to server '{handle.server_id}' but producer now reports '{observed_server_id}'",
|
|
276
|
+
object_id=handle.object_id,
|
|
277
|
+
server_id=handle.server_id,
|
|
278
|
+
reason="server_mismatch",
|
|
279
|
+
)
|
|
280
|
+
return info
|
|
281
|
+
|
|
248
282
|
def ping(self) -> bool:
|
|
249
283
|
if self._local_server() is not None:
|
|
250
284
|
return True
|
|
@@ -260,6 +294,31 @@ class SharedTensorClient:
|
|
|
260
294
|
return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
|
|
261
295
|
return cast(dict[str, Any], self._request("get_server_info"))
|
|
262
296
|
|
|
297
|
+
def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
|
|
298
|
+
local_server = self._local_server()
|
|
299
|
+
if local_server is not None:
|
|
300
|
+
return self._run_local(
|
|
301
|
+
lambda: bool(local_server.invalidate_call_cache(endpoint, args=tuple(args), kwargs=dict(kwargs)))
|
|
302
|
+
)
|
|
303
|
+
encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
|
|
304
|
+
result = self._request(
|
|
305
|
+
"invalidate_call_cache",
|
|
306
|
+
{
|
|
307
|
+
"endpoint": endpoint,
|
|
308
|
+
"args_bytes": args_payload,
|
|
309
|
+
"kwargs_bytes": kwargs_payload,
|
|
310
|
+
"encoding": encoding,
|
|
311
|
+
},
|
|
312
|
+
)
|
|
313
|
+
return bool(result["invalidated"])
|
|
314
|
+
|
|
315
|
+
def invalidate_endpoint_cache(self, endpoint: str) -> int:
|
|
316
|
+
local_server = self._local_server()
|
|
317
|
+
if local_server is not None:
|
|
318
|
+
return self._run_local(lambda: int(local_server.invalidate_endpoint_cache(endpoint)))
|
|
319
|
+
result = self._request("invalidate_endpoint_cache", {"endpoint": endpoint})
|
|
320
|
+
return int(result["invalidated"])
|
|
321
|
+
|
|
263
322
|
def list_endpoints(self) -> dict[str, Any]:
|
|
264
323
|
local_server = self._local_server()
|
|
265
324
|
if local_server is not None:
|
|
@@ -344,4 +403,5 @@ class SharedTensorClient:
|
|
|
344
403
|
object_id=cast(str, object_id),
|
|
345
404
|
value=value,
|
|
346
405
|
_releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
|
|
406
|
+
server_id=cast(str | None, result.get("server_id")),
|
|
347
407
|
)
|
|
@@ -15,6 +15,7 @@ __all__ = [
|
|
|
15
15
|
"SharedTensorClientError",
|
|
16
16
|
"SharedTensorServerError",
|
|
17
17
|
"SharedTensorProviderError",
|
|
18
|
+
"SharedTensorStaleHandleError",
|
|
18
19
|
]
|
|
19
20
|
|
|
20
21
|
|
|
@@ -69,3 +70,20 @@ class SharedTensorServerError(SharedTensorError):
|
|
|
69
70
|
|
|
70
71
|
class SharedTensorProviderError(SharedTensorError):
|
|
71
72
|
"""Raised for provider registration or invocation problems."""
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SharedTensorStaleHandleError(SharedTensorError):
|
|
76
|
+
"""Raised when a managed handle can no longer be trusted."""
|
|
77
|
+
|
|
78
|
+
def __init__(
|
|
79
|
+
self,
|
|
80
|
+
message: str,
|
|
81
|
+
*,
|
|
82
|
+
object_id: str | None = None,
|
|
83
|
+
server_id: str | None = None,
|
|
84
|
+
reason: str | None = None,
|
|
85
|
+
) -> None:
|
|
86
|
+
super().__init__(message)
|
|
87
|
+
self.object_id = object_id
|
|
88
|
+
self.server_id = server_id
|
|
89
|
+
self.reason = reason
|
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import time
|
|
5
6
|
import uuid
|
|
6
|
-
from dataclasses import dataclass
|
|
7
|
+
from dataclasses import dataclass, field
|
|
7
8
|
from threading import RLock
|
|
8
9
|
from typing import Any, Generic, TypeVar
|
|
9
10
|
|
|
@@ -17,6 +18,8 @@ class ManagedObjectEntry:
|
|
|
17
18
|
endpoint: str
|
|
18
19
|
cache_key: str | None
|
|
19
20
|
refcount: int = 1
|
|
21
|
+
created_at: float = 0.0
|
|
22
|
+
last_accessed_at: float = 0.0
|
|
20
23
|
|
|
21
24
|
|
|
22
25
|
@dataclass(slots=True)
|
|
@@ -42,15 +45,19 @@ class ManagedObjectRegistry:
|
|
|
42
45
|
if entry is None:
|
|
43
46
|
self._cache_index.pop(cache_key, None)
|
|
44
47
|
return None
|
|
48
|
+
entry.last_accessed_at = time.time()
|
|
45
49
|
return entry
|
|
46
50
|
|
|
47
51
|
def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
|
|
48
52
|
with self._lock:
|
|
53
|
+
now = time.time()
|
|
49
54
|
entry = ManagedObjectEntry(
|
|
50
55
|
object_id=uuid.uuid4().hex,
|
|
51
56
|
value=value,
|
|
52
57
|
endpoint=endpoint,
|
|
53
58
|
cache_key=cache_key,
|
|
59
|
+
created_at=now,
|
|
60
|
+
last_accessed_at=now,
|
|
54
61
|
)
|
|
55
62
|
self._entries[entry.object_id] = entry
|
|
56
63
|
if cache_key is not None:
|
|
@@ -67,6 +74,7 @@ class ManagedObjectRegistry:
|
|
|
67
74
|
if entry is None:
|
|
68
75
|
return None
|
|
69
76
|
entry.refcount += 1
|
|
77
|
+
entry.last_accessed_at = time.time()
|
|
70
78
|
return entry
|
|
71
79
|
|
|
72
80
|
def release(self, object_id: str) -> ManagedReleaseResult:
|
|
@@ -100,6 +108,31 @@ class ManagedObjectRegistry:
|
|
|
100
108
|
"endpoint": entry.endpoint,
|
|
101
109
|
"cache_key": entry.cache_key,
|
|
102
110
|
"refcount": entry.refcount,
|
|
111
|
+
"created_at": entry.created_at,
|
|
112
|
+
"last_accessed_at": entry.last_accessed_at,
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
def invalidate_cache_key(self, cache_key: str) -> bool:
|
|
116
|
+
with self._lock:
|
|
117
|
+
object_id = self._cache_index.pop(cache_key, None)
|
|
118
|
+
return object_id is not None
|
|
119
|
+
|
|
120
|
+
def invalidate_endpoint(self, endpoint: str) -> int:
|
|
121
|
+
with self._lock:
|
|
122
|
+
keys = [
|
|
123
|
+
cache_key
|
|
124
|
+
for cache_key, object_id in self._cache_index.items()
|
|
125
|
+
if (entry := self._entries.get(object_id)) is not None and entry.endpoint == endpoint
|
|
126
|
+
]
|
|
127
|
+
for cache_key in keys:
|
|
128
|
+
self._cache_index.pop(cache_key, None)
|
|
129
|
+
return len(keys)
|
|
130
|
+
|
|
131
|
+
def stats(self) -> dict[str, int]:
|
|
132
|
+
with self._lock:
|
|
133
|
+
return {
|
|
134
|
+
"objects": len(self._entries),
|
|
135
|
+
"cached_objects": len(self._cache_index),
|
|
103
136
|
}
|
|
104
137
|
|
|
105
138
|
def clear(self) -> None:
|
|
@@ -112,6 +145,9 @@ class ReleaseHandle:
|
|
|
112
145
|
def release(self) -> bool:  # pragma: no cover - protocol surface only
    """Release the referenced server-side object.

    Protocol stub: concrete release handles must override this.
    """
    raise NotImplementedError
|
|
114
147
|
|
|
148
|
+
def get_object_info(self) -> dict[str, Any] | None:  # pragma: no cover - protocol surface only
    """Fetch metadata for the referenced object, or None when unavailable.

    Protocol stub: concrete release handles must override this.
    """
    raise NotImplementedError
|
|
150
|
+
|
|
115
151
|
|
|
116
152
|
@dataclass(slots=True)
|
|
117
153
|
class SharedObjectHandle(Generic[T]):
|
|
@@ -119,6 +155,8 @@ class SharedObjectHandle(Generic[T]):
|
|
|
119
155
|
value: T
|
|
120
156
|
_releaser: ReleaseHandle
|
|
121
157
|
released: bool = False
|
|
158
|
+
server_id: str | None = None
|
|
159
|
+
_metadata_cache: dict[str, Any] | None = field(default=None, init=False, repr=False)
|
|
122
160
|
|
|
123
161
|
def release(self) -> bool:
|
|
124
162
|
if self.released:
|
|
@@ -126,8 +164,19 @@ class SharedObjectHandle(Generic[T]):
|
|
|
126
164
|
released = self._releaser.release()
|
|
127
165
|
if released:
|
|
128
166
|
self.released = True
|
|
167
|
+
self._metadata_cache = None
|
|
129
168
|
return released
|
|
130
169
|
|
|
170
|
+
def get_object_info(self, *, refresh: bool = False) -> dict[str, Any] | None:
    """Return metadata about the shared object behind this handle.

    The first successful lookup is memoised in ``_metadata_cache``. Later calls
    reuse it unless *refresh* is set, or the previous lookup yielded None.

    Args:
        refresh: force a new lookup through the releaser.

    Returns:
        A shallow copy of the metadata dict (callers may mutate it freely),
        or None if the handle was released or the releaser returned None.
    """
    if self.released:
        return None
    needs_fetch = refresh or self._metadata_cache is None
    if needs_fetch:
        self._metadata_cache = self._releaser.get_object_info()
    cached = self._metadata_cache
    if cached is None:
        return None
    return dict(cached)
|
|
176
|
+
|
|
177
|
+
def is_stale(self) -> bool:
    """Report whether fresh metadata for this handle is unavailable.

    Always performs a refreshing lookup (``refresh=True``). A None result,
    including the one returned for an already-released handle, counts as stale.
    """
    info = self.get_object_info(refresh=True)
    return info is None
|
|
179
|
+
|
|
131
180
|
def __enter__(self) -> SharedObjectHandle[T]:
|
|
132
181
|
return self
|
|
133
182
|
|
|
@@ -112,6 +112,7 @@ class SharedTensorProvider:
|
|
|
112
112
|
self._async_client: Any | None = None
|
|
113
113
|
self._server: Any | None = None
|
|
114
114
|
self._cache: dict[str, Any] = {}
|
|
115
|
+
self._cache_key_index: dict[str, str] = {}
|
|
115
116
|
self._endpoints: dict[str, EndpointDefinition] = {}
|
|
116
117
|
self._registered_functions = self._endpoints
|
|
117
118
|
self._lock = RLock()
|
|
@@ -263,6 +264,35 @@ class SharedTensorProvider:
|
|
|
263
264
|
def list_tasks(self, status: str | None = None) -> dict[str, Any]:
|
|
264
265
|
return self._get_async_client().list_tasks(status=status)
|
|
265
266
|
|
|
267
|
+
def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
    """Invalidate the cached result of one specific call to *endpoint*.

    Behaviour depends on ``self.execution_mode``:

    * ``"server"``: delegate to the attached in-process server, if it exposes
      ``invalidate_call_cache``; otherwise report that nothing was invalidated.
    * ``"local"``: drop the entry from this provider's own result cache and
      its endpoint index.
    * anything else: forward the request through the client.

    Args:
        endpoint: registered endpoint name.
        *args: positional arguments of the call whose result should be dropped.
        **kwargs: keyword arguments of that call.

    Returns:
        True if a cached result was removed.
    """
    if self.execution_mode == "server":
        server = self._server
        if server is not None and hasattr(server, "invalidate_call_cache"):
            return bool(server.invalidate_call_cache(endpoint, args=args, kwargs=kwargs))
        return False
    if self.execution_mode == "local":
        definition = self.get_endpoint(endpoint)
        cache_key = self._cache_key_for(endpoint, definition, args, kwargs)
        with self._lock:
            # Use a membership test, not the popped value: invoke_local caches
            # raw results, so an endpoint that returned None has a real entry
            # whose removal must still be reported.
            was_cached = cache_key in self._cache
            self._cache.pop(cache_key, None)
            self._cache_key_index.pop(cache_key, None)
        return was_cached
    return self._get_client().invalidate_call_cache(endpoint, *args, **kwargs)
|
|
280
|
+
|
|
281
|
+
def invalidate_endpoint_cache(self, endpoint: str) -> int:
    """Invalidate every cached result produced by *endpoint*.

    * ``"server"`` mode: delegate to the attached server when it supports
      ``invalidate_endpoint_cache``; otherwise report 0.
    * ``"local"`` mode: remove matching keys from ``_cache`` and
      ``_cache_key_index``.
    * any other mode: forward through the client.

    Returns:
        The number of cache entries invalidated.
    """
    mode = self.execution_mode
    if mode == "server":
        server = self._server
        if server is None or not hasattr(server, "invalidate_endpoint_cache"):
            return 0
        return int(server.invalidate_endpoint_cache(endpoint))
    if mode == "local":
        # Endpoint lookup happens before touching the cache; presumably
        # get_endpoint rejects unknown names (TODO confirm).
        self.get_endpoint(endpoint)
        with self._lock:
            matching = [key for key, owner in self._cache_key_index.items() if owner == endpoint]
            for key in matching:
                self._cache.pop(key, None)
                self._cache_key_index.pop(key, None)
            return len(matching)
    return self._get_client().invalidate_endpoint_cache(endpoint)
|
|
295
|
+
|
|
266
296
|
def invoke_local(
|
|
267
297
|
self,
|
|
268
298
|
endpoint: str,
|
|
@@ -283,6 +313,7 @@ class SharedTensorProvider:
|
|
|
283
313
|
result = definition.func(*args, **resolved_kwargs)
|
|
284
314
|
with self._lock:
|
|
285
315
|
self._cache[cache_key] = result
|
|
316
|
+
self._cache_key_index[cache_key] = endpoint
|
|
286
317
|
return result
|
|
287
318
|
|
|
288
319
|
def get_endpoint(self, endpoint: str) -> EndpointDefinition:
|
|
@@ -328,6 +359,8 @@ class SharedTensorProvider:
|
|
|
328
359
|
"device_index": self.device_index,
|
|
329
360
|
"server_socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
|
|
330
361
|
"server_running": bool(server is not None and getattr(server, "running", True)),
|
|
362
|
+
"endpoint_count": len(self._endpoints),
|
|
363
|
+
"cache_entries": len(self._cache),
|
|
331
364
|
}
|
|
332
365
|
server_info = self._get_client().get_server_info()
|
|
333
366
|
return {
|
|
@@ -338,6 +371,7 @@ class SharedTensorProvider:
|
|
|
338
371
|
"server_socket_path": server_info.get("socket_path"),
|
|
339
372
|
"server_running": bool(server_info.get("running")),
|
|
340
373
|
"server_ready": bool(server_info.get("ready")),
|
|
374
|
+
"endpoint_count": len(self._endpoints),
|
|
341
375
|
"server_info": server_info,
|
|
342
376
|
}
|
|
343
377
|
|
|
@@ -7,6 +7,7 @@ import os
|
|
|
7
7
|
import socket
|
|
8
8
|
import threading
|
|
9
9
|
import time
|
|
10
|
+
import uuid
|
|
10
11
|
from concurrent.futures import Future
|
|
11
12
|
from dataclasses import dataclass, field
|
|
12
13
|
from typing import Any
|
|
@@ -18,6 +19,7 @@ from shared_tensor.errors import (
|
|
|
18
19
|
SharedTensorProtocolError,
|
|
19
20
|
SharedTensorProviderError,
|
|
20
21
|
SharedTensorSerializationError,
|
|
22
|
+
SharedTensorStaleHandleError,
|
|
21
23
|
SharedTensorTaskError,
|
|
22
24
|
)
|
|
23
25
|
from shared_tensor.managed_object import ManagedObjectRegistry
|
|
@@ -110,6 +112,7 @@ class SharedTensorServer:
|
|
|
110
112
|
self.startup_timeout = startup_timeout
|
|
111
113
|
self.listener: socket.socket | None = None
|
|
112
114
|
self.server_process: Any | None = None
|
|
115
|
+
self.server_id = uuid.uuid4().hex
|
|
113
116
|
self.server_thread: _ServerThreadState | None = None
|
|
114
117
|
self._resolved_process_start_method: str | None = None
|
|
115
118
|
self.running = False
|
|
@@ -117,9 +120,13 @@ class SharedTensorServer:
|
|
|
117
120
|
self.stats = {
|
|
118
121
|
"requests_processed": 0,
|
|
119
122
|
"errors_encountered": 0,
|
|
123
|
+
"cache_hits": 0,
|
|
124
|
+
"cache_misses": 0,
|
|
125
|
+
"task_submissions": 0,
|
|
126
|
+
"cache_invalidations": 0,
|
|
120
127
|
}
|
|
121
128
|
self._task_manager: TaskManager | None = None
|
|
122
|
-
self._cache: dict[str,
|
|
129
|
+
self._cache: dict[str, str] = {}
|
|
123
130
|
self._local_cache: dict[str, Any] = {}
|
|
124
131
|
self._managed_objects = ManagedObjectRegistry()
|
|
125
132
|
self._inflight: dict[str, _InFlightCall] = {}
|
|
@@ -182,6 +189,10 @@ class SharedTensorServer:
|
|
|
182
189
|
return self._handle_release_objects(params)
|
|
183
190
|
if method == "get_object_info":
|
|
184
191
|
return self._handle_get_object_info(params)
|
|
192
|
+
if method == "invalidate_call_cache":
|
|
193
|
+
return self._handle_invalidate_call_cache(params)
|
|
194
|
+
if method == "invalidate_endpoint_cache":
|
|
195
|
+
return self._handle_invalidate_endpoint_cache(params)
|
|
185
196
|
raise SharedTensorProtocolError(f"Unknown RPC method '{method}'")
|
|
186
197
|
|
|
187
198
|
def _handle_call(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
@@ -224,6 +235,7 @@ class SharedTensorServer:
|
|
|
224
235
|
args: tuple[Any, ...],
|
|
225
236
|
kwargs: dict[str, Any],
|
|
226
237
|
) -> Any:
|
|
238
|
+
self.stats["task_submissions"] += 1
|
|
227
239
|
return self._task_manager_instance().submit(
|
|
228
240
|
endpoint,
|
|
229
241
|
self._execute_endpoint_result,
|
|
@@ -243,9 +255,11 @@ class SharedTensorServer:
|
|
|
243
255
|
if cache_key is not None:
|
|
244
256
|
cached = self._lookup_cached_result_value(definition, cache_key)
|
|
245
257
|
if cached is not None:
|
|
258
|
+
self.stats["cache_hits"] += 1
|
|
246
259
|
if self.verbose_debug:
|
|
247
260
|
logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
|
|
248
261
|
return cached
|
|
262
|
+
self.stats["cache_misses"] += 1
|
|
249
263
|
|
|
250
264
|
inflight_key = cache_key if cache_key is not None and definition.singleflight else None
|
|
251
265
|
if inflight_key is not None:
|
|
@@ -317,6 +331,7 @@ class SharedTensorServer:
|
|
|
317
331
|
if cache_key is not None:
|
|
318
332
|
with self._coordination_lock:
|
|
319
333
|
self._local_cache[cache_key] = value
|
|
334
|
+
self._cache[cache_key] = endpoint
|
|
320
335
|
return _EndpointResult(value=value)
|
|
321
336
|
|
|
322
337
|
def _materialize_managed_result(
|
|
@@ -337,6 +352,9 @@ class SharedTensorServer:
|
|
|
337
352
|
if self.verbose_debug:
|
|
338
353
|
logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
|
|
339
354
|
entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
|
|
355
|
+
if cache_key is not None:
|
|
356
|
+
with self._coordination_lock:
|
|
357
|
+
self._cache[cache_key] = endpoint
|
|
340
358
|
return _EndpointResult(value=entry.value, object_id=entry.object_id)
|
|
341
359
|
|
|
342
360
|
def _lookup_cached_result_value(
|
|
@@ -353,9 +371,9 @@ class SharedTensorServer:
|
|
|
353
371
|
self._managed_objects.add_ref(cached.object_id)
|
|
354
372
|
return _EndpointResult(value=cached.value, object_id=cached.object_id)
|
|
355
373
|
with self._coordination_lock:
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
374
|
+
if cache_key not in self._local_cache:
|
|
375
|
+
return None
|
|
376
|
+
local_value = self._local_cache[cache_key]
|
|
359
377
|
return _EndpointResult(value=local_value)
|
|
360
378
|
|
|
361
379
|
def call_local_client(
|
|
@@ -413,22 +431,30 @@ class SharedTensorServer:
|
|
|
413
431
|
if cache_key is not None:
|
|
414
432
|
cached = self._managed_objects.get_cached(cache_key)
|
|
415
433
|
if cached is not None:
|
|
434
|
+
self.stats["cache_hits"] += 1
|
|
416
435
|
return cached.value
|
|
436
|
+
self.stats["cache_misses"] += 1
|
|
417
437
|
value = definition.func(*args, **resolved_kwargs)
|
|
418
438
|
if cache_key is not None:
|
|
419
439
|
existing = self._managed_objects.get_cached(cache_key)
|
|
420
440
|
if existing is not None:
|
|
441
|
+
self.stats["cache_hits"] += 1
|
|
421
442
|
return existing.value
|
|
422
443
|
self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
|
|
444
|
+
with self._coordination_lock:
|
|
445
|
+
self._cache[cache_key] = endpoint
|
|
423
446
|
return value
|
|
424
447
|
if cache_key is not None:
|
|
425
448
|
with self._coordination_lock:
|
|
426
449
|
if cache_key in self._local_cache:
|
|
450
|
+
self.stats["cache_hits"] += 1
|
|
427
451
|
return self._local_cache[cache_key]
|
|
452
|
+
self.stats["cache_misses"] += 1
|
|
428
453
|
value = definition.func(*args, **resolved_kwargs)
|
|
429
454
|
if cache_key is not None:
|
|
430
455
|
with self._coordination_lock:
|
|
431
456
|
self._local_cache[cache_key] = value
|
|
457
|
+
self._cache[cache_key] = endpoint
|
|
432
458
|
return value
|
|
433
459
|
|
|
434
460
|
def _cache_key(
|
|
@@ -498,7 +524,21 @@ class SharedTensorServer:
|
|
|
498
524
|
|
|
499
525
|
def _handle_get_object_info(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
500
526
|
object_id = self._require_object_id(params)
|
|
501
|
-
|
|
527
|
+
info = self._managed_objects.info(object_id)
|
|
528
|
+
if info is None:
|
|
529
|
+
return {"object": None}
|
|
530
|
+
return {"object": {**info, "server_id": self.server_id}}
|
|
531
|
+
|
|
532
|
+
def _handle_invalidate_call_cache(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
533
|
+
endpoint, args, kwargs = self._decode_call_params(params)
|
|
534
|
+
removed = self.invalidate_call_cache(endpoint, args=args, kwargs=kwargs)
|
|
535
|
+
return {"invalidated": removed}
|
|
536
|
+
|
|
537
|
+
def _handle_invalidate_endpoint_cache(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
538
|
+
endpoint = params.get("endpoint")
|
|
539
|
+
if not isinstance(endpoint, str) or not endpoint:
|
|
540
|
+
raise SharedTensorProtocolError("Missing required parameter 'endpoint'")
|
|
541
|
+
return {"invalidated": self.invalidate_endpoint_cache(endpoint)}
|
|
502
542
|
|
|
503
543
|
def _decode_call_params(self, params: dict[str, Any]) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
|
|
504
544
|
endpoint = params.get("endpoint")
|
|
@@ -523,14 +563,34 @@ class SharedTensorServer:
|
|
|
523
563
|
validate_call_payload_for_transport(kwargs, allow_dict_keys=True)
|
|
524
564
|
return endpoint, args, kwargs
|
|
525
565
|
|
|
526
|
-
def _encode_result(
|
|
566
|
+
def _encode_result(
|
|
567
|
+
self,
|
|
568
|
+
value: Any,
|
|
569
|
+
*,
|
|
570
|
+
object_id: str | None = None,
|
|
571
|
+
server_id: str | None = None,
|
|
572
|
+
) -> dict[str, Any]:
|
|
527
573
|
if value is None:
|
|
528
|
-
return {
|
|
574
|
+
return {
|
|
575
|
+
"encoding": None,
|
|
576
|
+
"payload_bytes": None,
|
|
577
|
+
"object_id": object_id,
|
|
578
|
+
"server_id": server_id,
|
|
579
|
+
}
|
|
529
580
|
encoding, payload = serialize_payload(value)
|
|
530
|
-
return {
|
|
581
|
+
return {
|
|
582
|
+
"encoding": encoding,
|
|
583
|
+
"payload_bytes": payload,
|
|
584
|
+
"object_id": object_id,
|
|
585
|
+
"server_id": server_id,
|
|
586
|
+
}
|
|
531
587
|
|
|
532
588
|
def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
|
|
533
|
-
return self._encode_result(
|
|
589
|
+
return self._encode_result(
|
|
590
|
+
result.value,
|
|
591
|
+
object_id=result.object_id,
|
|
592
|
+
server_id=self.server_id if result.object_id is not None else None,
|
|
593
|
+
)
|
|
534
594
|
|
|
535
595
|
def _task_manager_instance(self) -> TaskManager:
|
|
536
596
|
if self._task_manager is None:
|
|
@@ -540,6 +600,44 @@ class SharedTensorServer:
|
|
|
540
600
|
)
|
|
541
601
|
return self._task_manager
|
|
542
602
|
|
|
603
|
+
def invalidate_call_cache(
    self,
    endpoint: str,
    *,
    args: tuple[Any, ...] = (),
    kwargs: dict[str, Any] | None = None,
) -> bool:
    """Drop the cached result (plain or managed) of one call to *endpoint*.

    Args:
        endpoint: registered endpoint name.
        args: positional arguments identifying the call.
        kwargs: keyword arguments identifying the call (None means none).

    Returns:
        True if anything was invalidated. In that case
        ``stats["cache_invalidations"]`` is incremented by one. Returns False
        when no cache key can be derived for the call.
    """
    definition = self.provider.get_endpoint(endpoint)
    resolved_kwargs = kwargs or {}
    cache_key = self._cache_key(endpoint, definition, args, resolved_kwargs)
    if cache_key is None:
        return False
    invalidated_managed = False
    if definition.managed:
        # Only unlinks the key from the registry's cache index; the registry
        # entry itself is not removed here.
        invalidated_managed = self._managed_objects.invalidate_cache_key(cache_key)
    with self._coordination_lock:
        # Membership test rather than the popped value: None results are cached
        # too (_lookup_cached_result_value checks `not in` for that reason), and
        # dropping one must still count as an invalidation.
        had_local = cache_key in self._local_cache
        self._local_cache.pop(cache_key, None)
        self._cache.pop(cache_key, None)
    invalidated = invalidated_managed or had_local
    if invalidated:
        self.stats["cache_invalidations"] += 1
    return invalidated
|
|
625
|
+
|
|
626
|
+
def invalidate_endpoint_cache(self, endpoint: str) -> int:
    """Invalidate every cached result (plain and managed) produced by *endpoint*.

    Every ``_cache`` key owned by the endpoint is removed. A key counts toward
    the total only if it also had a ``_local_cache`` value. To that count is
    added the number of managed cache keys unlinked by the registry.

    Returns:
        The total number invalidated. ``stats["cache_invalidations"]`` grows
        by the same amount.
    """
    # Endpoint lookup happens first; presumably get_endpoint rejects unknown
    # names (TODO confirm).
    self.provider.get_endpoint(endpoint)
    local_hits = 0
    with self._coordination_lock:
        owned_keys = [key for key, owner in self._cache.items() if owner == endpoint]
        for key in owned_keys:
            self._cache.pop(key, None)
            if key not in self._local_cache:
                continue
            self._local_cache.pop(key, None)
            local_hits += 1
    total = local_hits + self._managed_objects.invalidate_endpoint(endpoint)
    if total:
        self.stats["cache_invalidations"] += total
    return total
|
|
640
|
+
|
|
543
641
|
@staticmethod
|
|
544
642
|
def _require_task_id(params: dict[str, Any]) -> str:
|
|
545
643
|
task_id = params.get("task_id")
|
|
@@ -559,6 +657,7 @@ class SharedTensorServer:
|
|
|
559
657
|
return {
|
|
560
658
|
"server": "SharedTensorServer",
|
|
561
659
|
"version": _server_version(),
|
|
660
|
+
"server_id": self.server_id,
|
|
562
661
|
"socket_path": self.socket_path,
|
|
563
662
|
"uptime": uptime,
|
|
564
663
|
"running": self.running,
|
|
@@ -567,7 +666,13 @@ class SharedTensorServer:
|
|
|
567
666
|
"ppid": os.getppid(),
|
|
568
667
|
"device_index": resolve_device_index(self.provider.device_index),
|
|
569
668
|
"process_start_method": self._resolved_process_start_method,
|
|
570
|
-
"stats":
|
|
669
|
+
"stats": {
|
|
670
|
+
**dict(self.stats),
|
|
671
|
+
"cache_entries": len(self._local_cache),
|
|
672
|
+
"inflight_calls": len(self._inflight),
|
|
673
|
+
**self._managed_objects.stats(),
|
|
674
|
+
"task_count": 0 if self._task_manager is None else len(self._task_manager.list()),
|
|
675
|
+
},
|
|
571
676
|
"capabilities": capability_snapshot(),
|
|
572
677
|
"endpoints": list(self.provider.list_endpoints().keys()),
|
|
573
678
|
}
|
|
@@ -737,4 +842,6 @@ class SharedTensorServer:
|
|
|
737
842
|
return 5
|
|
738
843
|
if isinstance(exc, SharedTensorConfigurationError):
|
|
739
844
|
return 6
|
|
845
|
+
if isinstance(exc, SharedTensorStaleHandleError):
|
|
846
|
+
return 8
|
|
740
847
|
return 7
|
|
@@ -1,97 +0,0 @@
|
|
|
1
|
-
"""Deprecated compatibility shim for task-oriented provider usage."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from collections.abc import Callable
|
|
6
|
-
from functools import wraps
|
|
7
|
-
from typing import Any, cast
|
|
8
|
-
|
|
9
|
-
from shared_tensor.provider import SharedTensorProvider
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class AsyncSharedTensorProvider(SharedTensorProvider):
|
|
13
|
-
def register(
|
|
14
|
-
self,
|
|
15
|
-
func: Callable[..., Any],
|
|
16
|
-
*,
|
|
17
|
-
cache: bool = True,
|
|
18
|
-
cache_format_key: str | None = None,
|
|
19
|
-
managed: bool = False,
|
|
20
|
-
async_default_wait: bool = True,
|
|
21
|
-
execution: str = "task",
|
|
22
|
-
concurrency: str = "parallel",
|
|
23
|
-
singleflight: bool = True,
|
|
24
|
-
wait: bool | None = None,
|
|
25
|
-
) -> Callable[..., Any]:
|
|
26
|
-
resolved_wait = async_default_wait if wait is None else wait
|
|
27
|
-
registered = super().register(
|
|
28
|
-
func,
|
|
29
|
-
cache=cache,
|
|
30
|
-
cache_format_key=cache_format_key,
|
|
31
|
-
managed=managed,
|
|
32
|
-
async_default_wait=resolved_wait,
|
|
33
|
-
execution=cast(Any, execution),
|
|
34
|
-
concurrency=cast(Any, concurrency),
|
|
35
|
-
singleflight=singleflight,
|
|
36
|
-
)
|
|
37
|
-
if self.execution_mode in {"server", "local"}:
|
|
38
|
-
return registered
|
|
39
|
-
|
|
40
|
-
endpoint_name = func.__name__
|
|
41
|
-
|
|
42
|
-
@wraps(func)
|
|
43
|
-
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
44
|
-
if resolved_wait:
|
|
45
|
-
return self.call(endpoint_name, *args, **kwargs)
|
|
46
|
-
return self.submit(endpoint_name, *args, **kwargs)
|
|
47
|
-
|
|
48
|
-
wrapped = cast(Any, wrapper)
|
|
49
|
-
wrapped.submit_async = lambda *args, **kwargs: self.submit(endpoint_name, *args, **kwargs)
|
|
50
|
-
wrapped.execute_async = lambda *args, wait=resolved_wait, timeout=None, callback=None, **kwargs: self.execute(
|
|
51
|
-
endpoint_name,
|
|
52
|
-
*args,
|
|
53
|
-
wait=wait,
|
|
54
|
-
timeout=timeout,
|
|
55
|
-
callback=callback,
|
|
56
|
-
**kwargs,
|
|
57
|
-
)
|
|
58
|
-
return cast(Callable[..., Any], wrapped)
|
|
59
|
-
|
|
60
|
-
def share(
|
|
61
|
-
self,
|
|
62
|
-
func: Callable[..., Any] | None = None,
|
|
63
|
-
*,
|
|
64
|
-
cache: bool = True,
|
|
65
|
-
cache_format_key: str | None = None,
|
|
66
|
-
managed: bool = False,
|
|
67
|
-
execution: str = "task",
|
|
68
|
-
concurrency: str = "parallel",
|
|
69
|
-
singleflight: bool = True,
|
|
70
|
-
wait: bool | None = None,
|
|
71
|
-
**_: Any,
|
|
72
|
-
) -> Callable[[Callable[..., Any]], Callable[..., Any]] | Callable[..., Any]:
|
|
73
|
-
if func is not None:
|
|
74
|
-
return self.register(
|
|
75
|
-
func,
|
|
76
|
-
cache=cache,
|
|
77
|
-
cache_format_key=cache_format_key,
|
|
78
|
-
managed=managed,
|
|
79
|
-
execution=execution,
|
|
80
|
-
concurrency=concurrency,
|
|
81
|
-
singleflight=singleflight,
|
|
82
|
-
wait=wait,
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
|
|
86
|
-
return self.register(
|
|
87
|
-
inner,
|
|
88
|
-
cache=cache,
|
|
89
|
-
cache_format_key=cache_format_key,
|
|
90
|
-
managed=managed,
|
|
91
|
-
execution=execution,
|
|
92
|
-
concurrency=concurrency,
|
|
93
|
-
singleflight=singleflight,
|
|
94
|
-
wait=wait,
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
return decorator
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|