shared-tensor 0.2.8__tar.gz → 0.2.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -21,15 +21,16 @@ Classifier: Programming Language :: Python :: 3.10
21
21
  Classifier: Programming Language :: Python :: 3.11
22
22
  Classifier: Programming Language :: Python :: 3.12
23
23
  Classifier: Programming Language :: Python :: 3.13
24
+ Classifier: Programming Language :: Python :: 3.14
24
25
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
25
26
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
26
27
  Classifier: Topic :: System :: Distributed Computing
27
- Requires-Python: <3.14,>=3.9
28
+ Requires-Python: >=3.9
28
29
  Description-Content-Type: text/markdown
29
30
  License-File: LICENSE
30
31
  Requires-Dist: cloudpickle>=3.0.0
31
32
  Requires-Dist: numpy<2
32
- Requires-Dist: torch<2.8,>=2.1
33
+ Requires-Dist: torch>=2.1
33
34
  Provides-Extra: dev
34
35
  Requires-Dist: pytest>=6.0; extra == "dev"
35
36
  Requires-Dist: pytest-cov>=2.0; extra == "dev"
@@ -64,7 +65,7 @@ Supported:
64
65
  - explicit endpoint registration
65
66
  - sync `call` and task-backed `submit`
66
67
  - managed object handles with explicit release
67
- - server-side caching, `cache_format_key`, and singleflight
68
+ - server-side caching, `cache_format_key`, singleflight, and explicit cache invalidation
68
69
  - manual two-process deployment as the primary production path
69
70
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
70
71
 
@@ -107,6 +108,8 @@ Production should prefer two explicitly started processes: one server process th
107
108
 
108
109
  See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
109
110
 
111
+ The server-oriented example modules construct providers with explicit `execution_mode="server"` so importing the module already reflects the intended deployment role.
112
+
110
113
  Server process:
111
114
 
112
115
  ```python
@@ -147,10 +150,32 @@ executes endpoint functions reopens CUDA objects via torch IPC
147
150
  manages cache and refcounts releases managed handles explicitly
148
151
  ```
149
152
 
153
+ ## Lifetime And Failure Contract
154
+
155
+ `shared_tensor` follows native PyTorch CUDA IPC semantics. It does not virtualize or harden producer lifetime.
156
+
157
+ Core assumption:
158
+ - the server process that owns the original CUDA allocation must stay alive while clients are still using reopened CUDA tensors or modules
159
+ - handle health checks can detect some stale-object conditions, but they do not remove the producer-liveness requirement
160
+
161
+ If the server exits, crashes, or is killed before the client is done with the shared CUDA object, behavior is no longer guaranteed by this library. Depending on PyTorch and CUDA runtime state, the client may see CUDA runtime errors, invalid resource handle failures, broken module execution, or process-level instability.
162
+
163
+ So the production contract is:
164
+ - client-side handles are only valid while the producer process remains alive
165
+ - `handle.release()` is explicit lifecycle cleanup, not durability
166
+ - this library does not promise survivability across producer death
167
+
168
+ Treat producer liveness as a hard requirement, not a soft optimization.
169
+
150
170
  ## Example: Same Code, Two Processes
151
171
 
152
172
  See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
153
173
 
174
+ Resolution rule:
175
+ - `SHARED_TENSOR_ENABLED` unset or false: provider stays local
176
+ - `SHARED_TENSOR_ENABLED=1` and `SHARED_TENSOR_ROLE=server`: provider resolves to server and auto-starts the thread-backed local server
177
+ - `SHARED_TENSOR_ENABLED=1` and role unset or `client`: provider resolves to client
178
+
154
179
  ```bash
155
180
  SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
156
181
  SHARED_TENSOR_ENABLED=1 python demo.py
@@ -168,6 +193,27 @@ shared function runs locally shared function becomes RPC call
168
193
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
169
194
  ```
170
195
 
196
+ ## Example: Task Submission And Wait
197
+
198
+ See [examples/async_service.py](./examples/async_service.py).
199
+
200
+ ```python
201
+ from shared_tensor import AsyncSharedTensorClient, SharedTensorProvider
202
+
203
+ provider = SharedTensorProvider(execution_mode="server")
204
+
205
+ @provider.share(execution="task")
206
+ def build_delayed_model(delay: float = 0.1):
207
+ ...
208
+
209
+ client = AsyncSharedTensorClient()
210
+ task_id = client.submit("build_delayed_model", delay=0.1)
211
+ model = client.wait_for_task(task_id, timeout=30)
212
+ ```
213
+
214
+ Use `@provider.share(execution="task")` for task-backed endpoints.
215
+ Use `AsyncSharedTensorClient` when you want a task-oriented waiting interface.
216
+
171
217
  ## Example: Reusable Model Registry
172
218
 
173
219
  See [examples/model_service.py](./examples/model_service.py).
@@ -278,11 +324,47 @@ handle.release()
278
324
  ```
279
325
 
280
326
  Use managed mode for cached models or other reusable long-lived CUDA objects.
327
+ Managed object introspection now includes `created_at` and `last_accessed_at` timestamps through `get_object_info()`.
328
+
329
+ ## Cache Invalidation
330
+
331
+ The library now exposes explicit cache invalidation instead of forcing process restarts when a cached object becomes stale.
332
+
333
+ ```python
334
+ provider.invalidate_call_cache("load_model", hidden_size=4096)
335
+ provider.invalidate_endpoint_cache("load_model")
336
+ ```
337
+
338
+ Client-side equivalents are also available:
339
+
340
+ ```python
341
+ client.invalidate_call_cache("load_model", hidden_size=4096)
342
+ client.invalidate_endpoint_cache("load_model")
343
+ ```
344
+
345
+ Use call-level invalidation when you want to evict one cache key.
346
+ Use endpoint-level invalidation when you want to drop all cached variants for the endpoint.
347
+ Invalidation removes cache lookup entries; it does not guarantee that already-issued client handles remain valid after producer death.
348
+
349
+
350
+ ## Handle Health Checks
351
+
352
+ Managed handles now carry the producer `server_id` and support lightweight liveness probes:
353
+
354
+ ```python
355
+ handle = client.call("load_model", hidden_size=4096)
356
+ info = handle.get_object_info()
357
+ client.ensure_handle_live(handle)
358
+ ```
359
+
360
+ If the producer no longer owns the object, `client.ensure_handle_live(handle)` raises `SharedTensorStaleHandleError`.
361
+ This is still advisory, not a durability guarantee: it helps detect stale handles earlier, but it cannot make producer death safe.
281
362
 
282
363
  ## Runtime Introspection
283
364
 
284
- `client.get_server_info()` now returns readiness and process metadata in addition to endpoint and capability data.
365
+ `client.get_server_info()` now returns readiness, stable `server_id`, cache/task counters, and process metadata in addition to endpoint and capability data.
285
366
  In client mode, `provider.get_runtime_info()` wraps that into a provider-oriented view.
367
+ `AsyncSharedTensorClient` exposes the same runtime, cache invalidation, release, and handle-health helper methods as `SharedTensorClient`; the async surface is task-oriented, not capability-reduced.
286
368
 
287
369
  ```python
288
370
  info = provider.get_runtime_info()
@@ -12,7 +12,7 @@ Supported:
12
12
  - explicit endpoint registration
13
13
  - sync `call` and task-backed `submit`
14
14
  - managed object handles with explicit release
15
- - server-side caching, `cache_format_key`, and singleflight
15
+ - server-side caching, `cache_format_key`, singleflight, and explicit cache invalidation
16
16
  - manual two-process deployment as the primary production path
17
17
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
18
18
 
@@ -55,6 +55,8 @@ Production should prefer two explicitly started processes: one server process th
55
55
 
56
56
  See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
57
57
 
58
+ The server-oriented example modules construct providers with explicit `execution_mode="server"` so importing the module already reflects the intended deployment role.
59
+
58
60
  Server process:
59
61
 
60
62
  ```python
@@ -95,10 +97,32 @@ executes endpoint functions reopens CUDA objects via torch IPC
95
97
  manages cache and refcounts releases managed handles explicitly
96
98
  ```
97
99
 
100
+ ## Lifetime And Failure Contract
101
+
102
+ `shared_tensor` follows native PyTorch CUDA IPC semantics. It does not virtualize or harden producer lifetime.
103
+
104
+ Core assumption:
105
+ - the server process that owns the original CUDA allocation must stay alive while clients are still using reopened CUDA tensors or modules
106
+ - handle health checks can detect some stale-object conditions, but they do not remove the producer-liveness requirement
107
+
108
+ If the server exits, crashes, or is killed before the client is done with the shared CUDA object, behavior is no longer guaranteed by this library. Depending on PyTorch and CUDA runtime state, the client may see CUDA runtime errors, invalid resource handle failures, broken module execution, or process-level instability.
109
+
110
+ So the production contract is:
111
+ - client-side handles are only valid while the producer process remains alive
112
+ - `handle.release()` is explicit lifecycle cleanup, not durability
113
+ - this library does not promise survivability across producer death
114
+
115
+ Treat producer liveness as a hard requirement, not a soft optimization.
116
+
98
117
  ## Example: Same Code, Two Processes
99
118
 
100
119
  See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
101
120
 
121
+ Resolution rule:
122
+ - `SHARED_TENSOR_ENABLED` unset or false: provider stays local
123
+ - `SHARED_TENSOR_ENABLED=1` and `SHARED_TENSOR_ROLE=server`: provider resolves to server and auto-starts the thread-backed local server
124
+ - `SHARED_TENSOR_ENABLED=1` and role unset or `client`: provider resolves to client
125
+
102
126
  ```bash
103
127
  SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
104
128
  SHARED_TENSOR_ENABLED=1 python demo.py
@@ -116,6 +140,27 @@ shared function runs locally shared function becomes RPC call
116
140
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
117
141
  ```
118
142
 
143
+ ## Example: Task Submission And Wait
144
+
145
+ See [examples/async_service.py](./examples/async_service.py).
146
+
147
+ ```python
148
+ from shared_tensor import AsyncSharedTensorClient, SharedTensorProvider
149
+
150
+ provider = SharedTensorProvider(execution_mode="server")
151
+
152
+ @provider.share(execution="task")
153
+ def build_delayed_model(delay: float = 0.1):
154
+ ...
155
+
156
+ client = AsyncSharedTensorClient()
157
+ task_id = client.submit("build_delayed_model", delay=0.1)
158
+ model = client.wait_for_task(task_id, timeout=30)
159
+ ```
160
+
161
+ Use `@provider.share(execution="task")` for task-backed endpoints.
162
+ Use `AsyncSharedTensorClient` when you want a task-oriented waiting interface.
163
+
119
164
  ## Example: Reusable Model Registry
120
165
 
121
166
  See [examples/model_service.py](./examples/model_service.py).
@@ -226,11 +271,47 @@ handle.release()
226
271
  ```
227
272
 
228
273
  Use managed mode for cached models or other reusable long-lived CUDA objects.
274
+ Managed object introspection now includes `created_at` and `last_accessed_at` timestamps through `get_object_info()`.
275
+
276
+ ## Cache Invalidation
277
+
278
+ The library now exposes explicit cache invalidation instead of forcing process restarts when a cached object becomes stale.
279
+
280
+ ```python
281
+ provider.invalidate_call_cache("load_model", hidden_size=4096)
282
+ provider.invalidate_endpoint_cache("load_model")
283
+ ```
284
+
285
+ Client-side equivalents are also available:
286
+
287
+ ```python
288
+ client.invalidate_call_cache("load_model", hidden_size=4096)
289
+ client.invalidate_endpoint_cache("load_model")
290
+ ```
291
+
292
+ Use call-level invalidation when you want to evict one cache key.
293
+ Use endpoint-level invalidation when you want to drop all cached variants for the endpoint.
294
+ Invalidation removes cache lookup entries; it does not guarantee that already-issued client handles remain valid after producer death.
295
+
296
+
297
+ ## Handle Health Checks
298
+
299
+ Managed handles now carry the producer `server_id` and support lightweight liveness probes:
300
+
301
+ ```python
302
+ handle = client.call("load_model", hidden_size=4096)
303
+ info = handle.get_object_info()
304
+ client.ensure_handle_live(handle)
305
+ ```
306
+
307
+ If the producer no longer owns the object, `client.ensure_handle_live(handle)` raises `SharedTensorStaleHandleError`.
308
+ This is still advisory, not a durability guarantee: it helps detect stale handles earlier, but it cannot make producer death safe.
229
309
 
230
310
  ## Runtime Introspection
231
311
 
232
- `client.get_server_info()` now returns readiness and process metadata in addition to endpoint and capability data.
312
+ `client.get_server_info()` now returns readiness, stable `server_id`, cache/task counters, and process metadata in addition to endpoint and capability data.
233
313
  In client mode, `provider.get_runtime_info()` wraps that into a provider-oriented view.
314
+ `AsyncSharedTensorClient` exposes the same runtime, cache invalidation, release, and handle-health helper methods as `SharedTensorClient`; the async surface is task-oriented, not capability-reduced.
234
315
 
235
316
  ```python
236
317
  info = provider.get_runtime_info()
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.8"
7
+ version = "0.2.10"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -38,15 +38,16 @@ classifiers = [
38
38
  "Programming Language :: Python :: 3.11",
39
39
  "Programming Language :: Python :: 3.12",
40
40
  "Programming Language :: Python :: 3.13",
41
+ "Programming Language :: Python :: 3.14",
41
42
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
42
43
  "Topic :: Software Development :: Libraries :: Python Modules",
43
44
  "Topic :: System :: Distributed Computing",
44
45
  ]
45
- requires-python = ">=3.9,<3.14"
46
+ requires-python = ">=3.9"
46
47
  dependencies = [
47
48
  "cloudpickle>=3.0.0",
48
49
  "numpy<2",
49
- "torch>=2.1,<2.8",
50
+ "torch>=2.1",
50
51
  ]
51
52
 
52
53
  [project.optional-dependencies]
@@ -1,22 +1,22 @@
1
1
  """shared_tensor: same-host same-GPU PyTorch CUDA IPC over local UDS RPC."""
2
2
 
3
3
  from shared_tensor.async_client import AsyncSharedTensorClient
4
- from shared_tensor.async_provider import AsyncSharedTensorProvider
5
4
  from shared_tensor.async_task import TaskInfo, TaskStatus
6
5
  from shared_tensor.client import SharedTensorClient
6
+ from shared_tensor.errors import SharedTensorStaleHandleError
7
7
  from shared_tensor.managed_object import SharedObjectHandle
8
8
  from shared_tensor.provider import SharedTensorProvider
9
9
  from shared_tensor.server import SharedTensorServer
10
10
 
11
11
  __all__ = [
12
12
  "AsyncSharedTensorClient",
13
- "AsyncSharedTensorProvider",
14
13
  "SharedTensorClient",
15
14
  "SharedObjectHandle",
15
+ "SharedTensorStaleHandleError",
16
16
  "SharedTensorProvider",
17
17
  "SharedTensorServer",
18
18
  "TaskInfo",
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.8"
22
+ __version__ = "0.2.10"
@@ -10,6 +10,7 @@ from typing import Any, cast
10
10
  from shared_tensor.async_task import TaskInfo, TaskStatus
11
11
  from shared_tensor.client import SharedTensorClient
12
12
  from shared_tensor.errors import SharedTensorRemoteError, SharedTensorTaskError
13
+ from shared_tensor.managed_object import SharedObjectHandle
13
14
 
14
15
 
15
16
  logger = logging.getLogger(__name__)
@@ -51,6 +52,38 @@ class AsyncSharedTensorClient:
51
52
  def get_task_result(self, task_id: str) -> Any:
52
53
  return self.result(task_id)
53
54
 
55
+ def ping(self) -> bool:
56
+ return self._client.ping()
57
+
58
+ def get_server_info(self) -> dict[str, Any]:
59
+ return self._client.get_server_info()
60
+
61
+ def list_endpoints(self) -> dict[str, Any]:
62
+ return self._client.list_endpoints()
63
+
64
+ def release(self, object_id: str) -> bool:
65
+ return self._client.release(object_id)
66
+
67
+ def release_many(self, object_ids: list[str]) -> dict[str, bool]:
68
+ return self._client.release_many(object_ids)
69
+
70
+ def get_object_info(self, object_id: str) -> dict[str, Any] | None:
71
+ return self._client.get_object_info(object_id)
72
+
73
+ def ensure_handle_live(
74
+ self,
75
+ handle: SharedObjectHandle[Any],
76
+ *,
77
+ refresh: bool = True,
78
+ ) -> dict[str, Any]:
79
+ return self._client.ensure_handle_live(handle, refresh=refresh)
80
+
81
+ def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
82
+ return self._client.invalidate_call_cache(endpoint, *args, **kwargs)
83
+
84
+ def invalidate_endpoint_cache(self, endpoint: str) -> int:
85
+ return self._client.invalidate_endpoint_cache(endpoint)
86
+
54
87
  def wait(
55
88
  self,
56
89
  task_id: str,
@@ -7,6 +7,7 @@ import socket
7
7
  from dataclasses import dataclass
8
8
  from typing import Any, cast
9
9
 
10
+ from shared_tensor.async_task import TaskStatus
10
11
  from shared_tensor.errors import (
11
12
  SharedTensorCapabilityError,
12
13
  SharedTensorClientError,
@@ -16,12 +17,12 @@ from shared_tensor.errors import (
16
17
  SharedTensorProtocolError,
17
18
  SharedTensorRemoteError,
18
19
  SharedTensorSerializationError,
20
+ SharedTensorStaleHandleError,
19
21
  SharedTensorTaskError,
20
22
  )
21
23
  from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
22
24
  from shared_tensor.runtime import get_local_server
23
25
  from shared_tensor.transport import recv_message, send_message
24
- from shared_tensor.async_task import TaskStatus
25
26
  from shared_tensor.utils import (
26
27
  deserialize_payload,
27
28
  resolve_runtime_socket_path,
@@ -29,7 +30,6 @@ from shared_tensor.utils import (
29
30
  validate_payload_for_transport,
30
31
  )
31
32
 
32
-
33
33
  logger = logging.getLogger(__name__)
34
34
 
35
35
 
@@ -41,6 +41,9 @@ class _ClientReleaser(ReleaseHandle):
41
41
  def release(self) -> bool:
42
42
  return self.client.release(self.object_id)
43
43
 
44
+ def get_object_info(self) -> dict[str, Any] | None:
45
+ return self.client.get_object_info(self.object_id)
46
+
44
47
 
45
48
  class SharedTensorClient:
46
49
  """UDS client for endpoint-oriented local RPC execution."""
@@ -76,6 +79,8 @@ class SharedTensorClient:
76
79
  code = 5
77
80
  elif isinstance(exc, SharedTensorConfigurationError):
78
81
  code = 6
82
+ elif isinstance(exc, SharedTensorStaleHandleError):
83
+ code = 8
79
84
  else:
80
85
  code = 7
81
86
  return SharedTensorRemoteError(
@@ -105,6 +110,7 @@ class SharedTensorClient:
105
110
  object_id=cast(str, object_id),
106
111
  value=value,
107
112
  _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
113
+ server_id=self._infer_server_id(),
108
114
  )
109
115
 
110
116
  def _send_request(self, request: dict[str, Any]) -> Any:
@@ -158,6 +164,15 @@ class SharedTensorClient:
158
164
  def _request(self, method: str, params: dict[str, Any] | None = None) -> Any:
159
165
  return self._send_request({"method": method, "params": params or {}})
160
166
 
167
+ def _infer_server_id(self) -> str | None:
168
+ local_server = self._local_server()
169
+ if local_server is not None:
170
+ return cast(str | None, getattr(local_server, "server_id", None))
171
+ try:
172
+ return cast(str | None, self.get_server_info().get("server_id"))
173
+ except (SharedTensorClientError, SharedTensorRemoteError, SharedTensorProtocolError):
174
+ return None
175
+
161
176
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
162
177
  if self.verbose_debug:
163
178
  logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
@@ -245,6 +260,25 @@ class SharedTensorClient:
245
260
  result = self._request("get_object_info", {"object_id": object_id})
246
261
  return cast(dict[str, Any] | None, result.get("object"))
247
262
 
263
+ def ensure_handle_live(self, handle: SharedObjectHandle[Any], *, refresh: bool = True) -> dict[str, Any]:
264
+ info = handle.get_object_info(refresh=refresh)
265
+ if info is None:
266
+ raise SharedTensorStaleHandleError(
267
+ f"Managed object '{handle.object_id}' is no longer registered on the producer",
268
+ object_id=handle.object_id,
269
+ server_id=handle.server_id,
270
+ reason="object_missing",
271
+ )
272
+ observed_server_id = cast(str | None, info.get("server_id"))
273
+ if handle.server_id is not None and observed_server_id is not None and observed_server_id != handle.server_id:
274
+ raise SharedTensorStaleHandleError(
275
+ f"Managed object '{handle.object_id}' belongs to server '{handle.server_id}' but producer now reports '{observed_server_id}'",
276
+ object_id=handle.object_id,
277
+ server_id=handle.server_id,
278
+ reason="server_mismatch",
279
+ )
280
+ return info
281
+
248
282
  def ping(self) -> bool:
249
283
  if self._local_server() is not None:
250
284
  return True
@@ -260,6 +294,31 @@ class SharedTensorClient:
260
294
  return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
261
295
  return cast(dict[str, Any], self._request("get_server_info"))
262
296
 
297
+ def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
298
+ local_server = self._local_server()
299
+ if local_server is not None:
300
+ return self._run_local(
301
+ lambda: bool(local_server.invalidate_call_cache(endpoint, args=tuple(args), kwargs=dict(kwargs)))
302
+ )
303
+ encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
304
+ result = self._request(
305
+ "invalidate_call_cache",
306
+ {
307
+ "endpoint": endpoint,
308
+ "args_bytes": args_payload,
309
+ "kwargs_bytes": kwargs_payload,
310
+ "encoding": encoding,
311
+ },
312
+ )
313
+ return bool(result["invalidated"])
314
+
315
+ def invalidate_endpoint_cache(self, endpoint: str) -> int:
316
+ local_server = self._local_server()
317
+ if local_server is not None:
318
+ return self._run_local(lambda: int(local_server.invalidate_endpoint_cache(endpoint)))
319
+ result = self._request("invalidate_endpoint_cache", {"endpoint": endpoint})
320
+ return int(result["invalidated"])
321
+
263
322
  def list_endpoints(self) -> dict[str, Any]:
264
323
  local_server = self._local_server()
265
324
  if local_server is not None:
@@ -344,4 +403,5 @@ class SharedTensorClient:
344
403
  object_id=cast(str, object_id),
345
404
  value=value,
346
405
  _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
406
+ server_id=cast(str | None, result.get("server_id")),
347
407
  )
@@ -15,6 +15,7 @@ __all__ = [
15
15
  "SharedTensorClientError",
16
16
  "SharedTensorServerError",
17
17
  "SharedTensorProviderError",
18
+ "SharedTensorStaleHandleError",
18
19
  ]
19
20
 
20
21
 
@@ -69,3 +70,20 @@ class SharedTensorServerError(SharedTensorError):
69
70
 
70
71
  class SharedTensorProviderError(SharedTensorError):
71
72
  """Raised for provider registration or invocation problems."""
73
+
74
+
75
+ class SharedTensorStaleHandleError(SharedTensorError):
76
+ """Raised when a managed handle can no longer be trusted."""
77
+
78
+ def __init__(
79
+ self,
80
+ message: str,
81
+ *,
82
+ object_id: str | None = None,
83
+ server_id: str | None = None,
84
+ reason: str | None = None,
85
+ ) -> None:
86
+ super().__init__(message)
87
+ self.object_id = object_id
88
+ self.server_id = server_id
89
+ self.reason = reason
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import time
5
6
  import uuid
6
- from dataclasses import dataclass
7
+ from dataclasses import dataclass, field
7
8
  from threading import RLock
8
9
  from typing import Any, Generic, TypeVar
9
10
 
@@ -17,6 +18,8 @@ class ManagedObjectEntry:
17
18
  endpoint: str
18
19
  cache_key: str | None
19
20
  refcount: int = 1
21
+ created_at: float = 0.0
22
+ last_accessed_at: float = 0.0
20
23
 
21
24
 
22
25
  @dataclass(slots=True)
@@ -42,15 +45,19 @@ class ManagedObjectRegistry:
42
45
  if entry is None:
43
46
  self._cache_index.pop(cache_key, None)
44
47
  return None
48
+ entry.last_accessed_at = time.time()
45
49
  return entry
46
50
 
47
51
  def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
48
52
  with self._lock:
53
+ now = time.time()
49
54
  entry = ManagedObjectEntry(
50
55
  object_id=uuid.uuid4().hex,
51
56
  value=value,
52
57
  endpoint=endpoint,
53
58
  cache_key=cache_key,
59
+ created_at=now,
60
+ last_accessed_at=now,
54
61
  )
55
62
  self._entries[entry.object_id] = entry
56
63
  if cache_key is not None:
@@ -67,6 +74,7 @@ class ManagedObjectRegistry:
67
74
  if entry is None:
68
75
  return None
69
76
  entry.refcount += 1
77
+ entry.last_accessed_at = time.time()
70
78
  return entry
71
79
 
72
80
  def release(self, object_id: str) -> ManagedReleaseResult:
@@ -100,6 +108,31 @@ class ManagedObjectRegistry:
100
108
  "endpoint": entry.endpoint,
101
109
  "cache_key": entry.cache_key,
102
110
  "refcount": entry.refcount,
111
+ "created_at": entry.created_at,
112
+ "last_accessed_at": entry.last_accessed_at,
113
+ }
114
+
115
+ def invalidate_cache_key(self, cache_key: str) -> bool:
116
+ with self._lock:
117
+ object_id = self._cache_index.pop(cache_key, None)
118
+ return object_id is not None
119
+
120
+ def invalidate_endpoint(self, endpoint: str) -> int:
121
+ with self._lock:
122
+ keys = [
123
+ cache_key
124
+ for cache_key, object_id in self._cache_index.items()
125
+ if (entry := self._entries.get(object_id)) is not None and entry.endpoint == endpoint
126
+ ]
127
+ for cache_key in keys:
128
+ self._cache_index.pop(cache_key, None)
129
+ return len(keys)
130
+
131
+ def stats(self) -> dict[str, int]:
132
+ with self._lock:
133
+ return {
134
+ "objects": len(self._entries),
135
+ "cached_objects": len(self._cache_index),
103
136
  }
104
137
 
105
138
  def clear(self) -> None:
@@ -112,6 +145,9 @@ class ReleaseHandle:
112
145
  def release(self) -> bool: # pragma: no cover - protocol surface only
113
146
  raise NotImplementedError
114
147
 
148
+ def get_object_info(self) -> dict[str, Any] | None: # pragma: no cover - protocol surface only
149
+ raise NotImplementedError
150
+
115
151
 
116
152
  @dataclass(slots=True)
117
153
  class SharedObjectHandle(Generic[T]):
@@ -119,6 +155,8 @@ class SharedObjectHandle(Generic[T]):
119
155
  value: T
120
156
  _releaser: ReleaseHandle
121
157
  released: bool = False
158
+ server_id: str | None = None
159
+ _metadata_cache: dict[str, Any] | None = field(default=None, init=False, repr=False)
122
160
 
123
161
  def release(self) -> bool:
124
162
  if self.released:
@@ -126,8 +164,19 @@ class SharedObjectHandle(Generic[T]):
126
164
  released = self._releaser.release()
127
165
  if released:
128
166
  self.released = True
167
+ self._metadata_cache = None
129
168
  return released
130
169
 
170
+ def get_object_info(self, *, refresh: bool = False) -> dict[str, Any] | None:
171
+ if self.released:
172
+ return None
173
+ if self._metadata_cache is None or refresh:
174
+ self._metadata_cache = self._releaser.get_object_info()
175
+ return None if self._metadata_cache is None else dict(self._metadata_cache)
176
+
177
+ def is_stale(self) -> bool:
178
+ return self.get_object_info(refresh=True) is None
179
+
131
180
  def __enter__(self) -> SharedObjectHandle[T]:
132
181
  return self
133
182
 
@@ -112,6 +112,7 @@ class SharedTensorProvider:
112
112
  self._async_client: Any | None = None
113
113
  self._server: Any | None = None
114
114
  self._cache: dict[str, Any] = {}
115
+ self._cache_key_index: dict[str, str] = {}
115
116
  self._endpoints: dict[str, EndpointDefinition] = {}
116
117
  self._registered_functions = self._endpoints
117
118
  self._lock = RLock()
@@ -263,6 +264,35 @@ class SharedTensorProvider:
263
264
  def list_tasks(self, status: str | None = None) -> dict[str, Any]:
264
265
  return self._get_async_client().list_tasks(status=status)
265
266
 
267
+ def invalidate_call_cache(self, endpoint: str, *args: Any, **kwargs: Any) -> bool:
268
+ if self.execution_mode == "server":
269
+ if self._server is not None and hasattr(self._server, "invalidate_call_cache"):
270
+ return bool(self._server.invalidate_call_cache(endpoint, args=args, kwargs=kwargs))
271
+ return False
272
+ if self.execution_mode == "local":
273
+ definition = self.get_endpoint(endpoint)
274
+ cache_key = self._cache_key_for(endpoint, definition, args, kwargs)
275
+ with self._lock:
276
+ removed = self._cache.pop(cache_key, None)
277
+ self._cache_key_index.pop(cache_key, None)
278
+ return removed is not None
279
+ return self._get_client().invalidate_call_cache(endpoint, *args, **kwargs)
280
+
281
+ def invalidate_endpoint_cache(self, endpoint: str) -> int:
282
+ if self.execution_mode == "server":
283
+ if self._server is not None and hasattr(self._server, "invalidate_endpoint_cache"):
284
+ return int(self._server.invalidate_endpoint_cache(endpoint))
285
+ return 0
286
+ if self.execution_mode == "local":
287
+ self.get_endpoint(endpoint)
288
+ with self._lock:
289
+ keys = [cache_key for cache_key, cache_endpoint in self._cache_key_index.items() if cache_endpoint == endpoint]
290
+ for cache_key in keys:
291
+ self._cache.pop(cache_key, None)
292
+ self._cache_key_index.pop(cache_key, None)
293
+ return len(keys)
294
+ return self._get_client().invalidate_endpoint_cache(endpoint)
295
+
266
296
  def invoke_local(
267
297
  self,
268
298
  endpoint: str,
@@ -283,6 +313,7 @@ class SharedTensorProvider:
283
313
  result = definition.func(*args, **resolved_kwargs)
284
314
  with self._lock:
285
315
  self._cache[cache_key] = result
316
+ self._cache_key_index[cache_key] = endpoint
286
317
  return result
287
318
 
288
319
  def get_endpoint(self, endpoint: str) -> EndpointDefinition:
@@ -328,6 +359,8 @@ class SharedTensorProvider:
328
359
  "device_index": self.device_index,
329
360
  "server_socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
330
361
  "server_running": bool(server is not None and getattr(server, "running", True)),
362
+ "endpoint_count": len(self._endpoints),
363
+ "cache_entries": len(self._cache),
331
364
  }
332
365
  server_info = self._get_client().get_server_info()
333
366
  return {
@@ -338,6 +371,7 @@ class SharedTensorProvider:
338
371
  "server_socket_path": server_info.get("socket_path"),
339
372
  "server_running": bool(server_info.get("running")),
340
373
  "server_ready": bool(server_info.get("ready")),
374
+ "endpoint_count": len(self._endpoints),
341
375
  "server_info": server_info,
342
376
  }
343
377
 
@@ -7,6 +7,7 @@ import os
7
7
  import socket
8
8
  import threading
9
9
  import time
10
+ import uuid
10
11
  from concurrent.futures import Future
11
12
  from dataclasses import dataclass, field
12
13
  from typing import Any
@@ -18,6 +19,7 @@ from shared_tensor.errors import (
18
19
  SharedTensorProtocolError,
19
20
  SharedTensorProviderError,
20
21
  SharedTensorSerializationError,
22
+ SharedTensorStaleHandleError,
21
23
  SharedTensorTaskError,
22
24
  )
23
25
  from shared_tensor.managed_object import ManagedObjectRegistry
@@ -110,6 +112,7 @@ class SharedTensorServer:
110
112
  self.startup_timeout = startup_timeout
111
113
  self.listener: socket.socket | None = None
112
114
  self.server_process: Any | None = None
115
+ self.server_id = uuid.uuid4().hex
113
116
  self.server_thread: _ServerThreadState | None = None
114
117
  self._resolved_process_start_method: str | None = None
115
118
  self.running = False
@@ -117,9 +120,13 @@ class SharedTensorServer:
117
120
  self.stats = {
118
121
  "requests_processed": 0,
119
122
  "errors_encountered": 0,
123
+ "cache_hits": 0,
124
+ "cache_misses": 0,
125
+ "task_submissions": 0,
126
+ "cache_invalidations": 0,
120
127
  }
121
128
  self._task_manager: TaskManager | None = None
122
- self._cache: dict[str, dict[str, Any]] = {}
129
+ self._cache: dict[str, str] = {}
123
130
  self._local_cache: dict[str, Any] = {}
124
131
  self._managed_objects = ManagedObjectRegistry()
125
132
  self._inflight: dict[str, _InFlightCall] = {}
@@ -182,6 +189,10 @@ class SharedTensorServer:
182
189
  return self._handle_release_objects(params)
183
190
  if method == "get_object_info":
184
191
  return self._handle_get_object_info(params)
192
+ if method == "invalidate_call_cache":
193
+ return self._handle_invalidate_call_cache(params)
194
+ if method == "invalidate_endpoint_cache":
195
+ return self._handle_invalidate_endpoint_cache(params)
185
196
  raise SharedTensorProtocolError(f"Unknown RPC method '{method}'")
186
197
 
187
198
  def _handle_call(self, params: dict[str, Any]) -> dict[str, Any]:
@@ -224,6 +235,7 @@ class SharedTensorServer:
224
235
  args: tuple[Any, ...],
225
236
  kwargs: dict[str, Any],
226
237
  ) -> Any:
238
+ self.stats["task_submissions"] += 1
227
239
  return self._task_manager_instance().submit(
228
240
  endpoint,
229
241
  self._execute_endpoint_result,
@@ -243,9 +255,11 @@ class SharedTensorServer:
243
255
  if cache_key is not None:
244
256
  cached = self._lookup_cached_result_value(definition, cache_key)
245
257
  if cached is not None:
258
+ self.stats["cache_hits"] += 1
246
259
  if self.verbose_debug:
247
260
  logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
248
261
  return cached
262
+ self.stats["cache_misses"] += 1
249
263
 
250
264
  inflight_key = cache_key if cache_key is not None and definition.singleflight else None
251
265
  if inflight_key is not None:
@@ -317,6 +331,7 @@ class SharedTensorServer:
317
331
  if cache_key is not None:
318
332
  with self._coordination_lock:
319
333
  self._local_cache[cache_key] = value
334
+ self._cache[cache_key] = endpoint
320
335
  return _EndpointResult(value=value)
321
336
 
322
337
  def _materialize_managed_result(
@@ -337,6 +352,9 @@ class SharedTensorServer:
337
352
  if self.verbose_debug:
338
353
  logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
339
354
  entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
355
+ if cache_key is not None:
356
+ with self._coordination_lock:
357
+ self._cache[cache_key] = endpoint
340
358
  return _EndpointResult(value=entry.value, object_id=entry.object_id)
341
359
 
342
360
  def _lookup_cached_result_value(
@@ -353,9 +371,9 @@ class SharedTensorServer:
353
371
  self._managed_objects.add_ref(cached.object_id)
354
372
  return _EndpointResult(value=cached.value, object_id=cached.object_id)
355
373
  with self._coordination_lock:
356
- local_value = self._local_cache.get(cache_key)
357
- if local_value is None:
358
- return None
374
+ if cache_key not in self._local_cache:
375
+ return None
376
+ local_value = self._local_cache[cache_key]
359
377
  return _EndpointResult(value=local_value)
360
378
 
361
379
  def call_local_client(
@@ -413,22 +431,30 @@ class SharedTensorServer:
413
431
  if cache_key is not None:
414
432
  cached = self._managed_objects.get_cached(cache_key)
415
433
  if cached is not None:
434
+ self.stats["cache_hits"] += 1
416
435
  return cached.value
436
+ self.stats["cache_misses"] += 1
417
437
  value = definition.func(*args, **resolved_kwargs)
418
438
  if cache_key is not None:
419
439
  existing = self._managed_objects.get_cached(cache_key)
420
440
  if existing is not None:
441
+ self.stats["cache_hits"] += 1
421
442
  return existing.value
422
443
  self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
444
+ with self._coordination_lock:
445
+ self._cache[cache_key] = endpoint
423
446
  return value
424
447
  if cache_key is not None:
425
448
  with self._coordination_lock:
426
449
  if cache_key in self._local_cache:
450
+ self.stats["cache_hits"] += 1
427
451
  return self._local_cache[cache_key]
452
+ self.stats["cache_misses"] += 1
428
453
  value = definition.func(*args, **resolved_kwargs)
429
454
  if cache_key is not None:
430
455
  with self._coordination_lock:
431
456
  self._local_cache[cache_key] = value
457
+ self._cache[cache_key] = endpoint
432
458
  return value
433
459
 
434
460
  def _cache_key(
@@ -498,7 +524,21 @@ class SharedTensorServer:
498
524
 
499
525
  def _handle_get_object_info(self, params: dict[str, Any]) -> dict[str, Any]:
500
526
  object_id = self._require_object_id(params)
501
- return {"object": self._managed_objects.info(object_id)}
527
+ info = self._managed_objects.info(object_id)
528
+ if info is None:
529
+ return {"object": None}
530
+ return {"object": {**info, "server_id": self.server_id}}
531
+
532
+ def _handle_invalidate_call_cache(self, params: dict[str, Any]) -> dict[str, Any]:
533
+ endpoint, args, kwargs = self._decode_call_params(params)
534
+ removed = self.invalidate_call_cache(endpoint, args=args, kwargs=kwargs)
535
+ return {"invalidated": removed}
536
+
537
+ def _handle_invalidate_endpoint_cache(self, params: dict[str, Any]) -> dict[str, Any]:
538
+ endpoint = params.get("endpoint")
539
+ if not isinstance(endpoint, str) or not endpoint:
540
+ raise SharedTensorProtocolError("Missing required parameter 'endpoint'")
541
+ return {"invalidated": self.invalidate_endpoint_cache(endpoint)}
502
542
 
503
543
  def _decode_call_params(self, params: dict[str, Any]) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
504
544
  endpoint = params.get("endpoint")
@@ -523,14 +563,34 @@ class SharedTensorServer:
523
563
  validate_call_payload_for_transport(kwargs, allow_dict_keys=True)
524
564
  return endpoint, args, kwargs
525
565
 
526
- def _encode_result(self, value: Any, *, object_id: str | None = None) -> dict[str, Any]:
566
+ def _encode_result(
567
+ self,
568
+ value: Any,
569
+ *,
570
+ object_id: str | None = None,
571
+ server_id: str | None = None,
572
+ ) -> dict[str, Any]:
527
573
  if value is None:
528
- return {"encoding": None, "payload_bytes": None, "object_id": object_id}
574
+ return {
575
+ "encoding": None,
576
+ "payload_bytes": None,
577
+ "object_id": object_id,
578
+ "server_id": server_id,
579
+ }
529
580
  encoding, payload = serialize_payload(value)
530
- return {"encoding": encoding, "payload_bytes": payload, "object_id": object_id}
581
+ return {
582
+ "encoding": encoding,
583
+ "payload_bytes": payload,
584
+ "object_id": object_id,
585
+ "server_id": server_id,
586
+ }
531
587
 
532
588
  def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
533
- return self._encode_result(result.value, object_id=result.object_id)
589
+ return self._encode_result(
590
+ result.value,
591
+ object_id=result.object_id,
592
+ server_id=self.server_id if result.object_id is not None else None,
593
+ )
534
594
 
535
595
  def _task_manager_instance(self) -> TaskManager:
536
596
  if self._task_manager is None:
@@ -540,6 +600,44 @@ class SharedTensorServer:
540
600
  )
541
601
  return self._task_manager
542
602
 
603
+ def invalidate_call_cache(
604
+ self,
605
+ endpoint: str,
606
+ *,
607
+ args: tuple[Any, ...] = (),
608
+ kwargs: dict[str, Any] | None = None,
609
+ ) -> bool:
610
+ definition = self.provider.get_endpoint(endpoint)
611
+ resolved_kwargs = kwargs or {}
612
+ cache_key = self._cache_key(endpoint, definition, args, resolved_kwargs)
613
+ if cache_key is None:
614
+ return False
615
+ invalidated_managed = False
616
+ if definition.managed:
617
+ invalidated_managed = self._managed_objects.invalidate_cache_key(cache_key)
618
+ with self._coordination_lock:
619
+ removed = self._local_cache.pop(cache_key, None)
620
+ self._cache.pop(cache_key, None)
621
+ invalidated = invalidated_managed or removed is not None
622
+ if invalidated:
623
+ self.stats["cache_invalidations"] += 1
624
+ return invalidated
625
+
626
+ def invalidate_endpoint_cache(self, endpoint: str) -> int:
627
+ self.provider.get_endpoint(endpoint)
628
+ removed = 0
629
+ with self._coordination_lock:
630
+ keys = [cache_key for cache_key, cache_endpoint in self._cache.items() if cache_endpoint == endpoint]
631
+ for cache_key in keys:
632
+ self._cache.pop(cache_key, None)
633
+ if cache_key in self._local_cache:
634
+ self._local_cache.pop(cache_key, None)
635
+ removed += 1
636
+ removed += self._managed_objects.invalidate_endpoint(endpoint)
637
+ if removed:
638
+ self.stats["cache_invalidations"] += removed
639
+ return removed
640
+
543
641
  @staticmethod
544
642
  def _require_task_id(params: dict[str, Any]) -> str:
545
643
  task_id = params.get("task_id")
@@ -559,6 +657,7 @@ class SharedTensorServer:
559
657
  return {
560
658
  "server": "SharedTensorServer",
561
659
  "version": _server_version(),
660
+ "server_id": self.server_id,
562
661
  "socket_path": self.socket_path,
563
662
  "uptime": uptime,
564
663
  "running": self.running,
@@ -567,7 +666,13 @@ class SharedTensorServer:
567
666
  "ppid": os.getppid(),
568
667
  "device_index": resolve_device_index(self.provider.device_index),
569
668
  "process_start_method": self._resolved_process_start_method,
570
- "stats": dict(self.stats),
669
+ "stats": {
670
+ **dict(self.stats),
671
+ "cache_entries": len(self._local_cache),
672
+ "inflight_calls": len(self._inflight),
673
+ **self._managed_objects.stats(),
674
+ "task_count": 0 if self._task_manager is None else len(self._task_manager.list()),
675
+ },
571
676
  "capabilities": capability_snapshot(),
572
677
  "endpoints": list(self.provider.list_endpoints().keys()),
573
678
  }
@@ -737,4 +842,6 @@ class SharedTensorServer:
737
842
  return 5
738
843
  if isinstance(exc, SharedTensorConfigurationError):
739
844
  return 6
845
+ if isinstance(exc, SharedTensorStaleHandleError):
846
+ return 8
740
847
  return 7
@@ -4,7 +4,6 @@ README.md
4
4
  pyproject.toml
5
5
  shared_tensor/__init__.py
6
6
  shared_tensor/async_client.py
7
- shared_tensor/async_provider.py
8
7
  shared_tensor/async_task.py
9
8
  shared_tensor/client.py
10
9
  shared_tensor/errors.py
@@ -1,97 +0,0 @@
1
- """Deprecated compatibility shim for task-oriented provider usage."""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Callable
6
- from functools import wraps
7
- from typing import Any, cast
8
-
9
- from shared_tensor.provider import SharedTensorProvider
10
-
11
-
12
- class AsyncSharedTensorProvider(SharedTensorProvider):
13
- def register(
14
- self,
15
- func: Callable[..., Any],
16
- *,
17
- cache: bool = True,
18
- cache_format_key: str | None = None,
19
- managed: bool = False,
20
- async_default_wait: bool = True,
21
- execution: str = "task",
22
- concurrency: str = "parallel",
23
- singleflight: bool = True,
24
- wait: bool | None = None,
25
- ) -> Callable[..., Any]:
26
- resolved_wait = async_default_wait if wait is None else wait
27
- registered = super().register(
28
- func,
29
- cache=cache,
30
- cache_format_key=cache_format_key,
31
- managed=managed,
32
- async_default_wait=resolved_wait,
33
- execution=cast(Any, execution),
34
- concurrency=cast(Any, concurrency),
35
- singleflight=singleflight,
36
- )
37
- if self.execution_mode in {"server", "local"}:
38
- return registered
39
-
40
- endpoint_name = func.__name__
41
-
42
- @wraps(func)
43
- def wrapper(*args: Any, **kwargs: Any) -> Any:
44
- if resolved_wait:
45
- return self.call(endpoint_name, *args, **kwargs)
46
- return self.submit(endpoint_name, *args, **kwargs)
47
-
48
- wrapped = cast(Any, wrapper)
49
- wrapped.submit_async = lambda *args, **kwargs: self.submit(endpoint_name, *args, **kwargs)
50
- wrapped.execute_async = lambda *args, wait=resolved_wait, timeout=None, callback=None, **kwargs: self.execute(
51
- endpoint_name,
52
- *args,
53
- wait=wait,
54
- timeout=timeout,
55
- callback=callback,
56
- **kwargs,
57
- )
58
- return cast(Callable[..., Any], wrapped)
59
-
60
- def share(
61
- self,
62
- func: Callable[..., Any] | None = None,
63
- *,
64
- cache: bool = True,
65
- cache_format_key: str | None = None,
66
- managed: bool = False,
67
- execution: str = "task",
68
- concurrency: str = "parallel",
69
- singleflight: bool = True,
70
- wait: bool | None = None,
71
- **_: Any,
72
- ) -> Callable[[Callable[..., Any]], Callable[..., Any]] | Callable[..., Any]:
73
- if func is not None:
74
- return self.register(
75
- func,
76
- cache=cache,
77
- cache_format_key=cache_format_key,
78
- managed=managed,
79
- execution=execution,
80
- concurrency=concurrency,
81
- singleflight=singleflight,
82
- wait=wait,
83
- )
84
-
85
- def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
86
- return self.register(
87
- inner,
88
- cache=cache,
89
- cache_format_key=cache_format_key,
90
- managed=managed,
91
- execution=execution,
92
- concurrency=concurrency,
93
- singleflight=singleflight,
94
- wait=wait,
95
- )
96
-
97
- return decorator
File without changes
File without changes