shared-tensor 0.2.6__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Operating System :: POSIX :: Linux
18
18
  Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.9
19
20
  Classifier: Programming Language :: Python :: 3.10
20
21
  Classifier: Programming Language :: Python :: 3.11
21
22
  Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Programming Language :: Python :: 3.13
22
24
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
25
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
26
  Classifier: Topic :: System :: Distributed Computing
25
- Requires-Python: >=3.10
27
+ Requires-Python: <3.14,>=3.9
26
28
  Description-Content-Type: text/markdown
27
29
  License-File: LICENSE
28
30
  Requires-Dist: cloudpickle>=3.0.0
29
31
  Requires-Dist: numpy<2
30
- Requires-Dist: torch>=2.2.0
32
+ Requires-Dist: torch<2.8,>=2.1
31
33
  Provides-Extra: dev
32
34
  Requires-Dist: pytest>=6.0; extra == "dev"
33
35
  Requires-Dist: pytest-cov>=2.0; extra == "dev"
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.6"
7
+ version = "0.2.7"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -33,18 +33,20 @@ classifiers = [
33
33
  "Intended Audience :: Science/Research",
34
34
  "Operating System :: POSIX :: Linux",
35
35
  "Programming Language :: Python :: 3",
36
+ "Programming Language :: Python :: 3.9",
36
37
  "Programming Language :: Python :: 3.10",
37
38
  "Programming Language :: Python :: 3.11",
38
39
  "Programming Language :: Python :: 3.12",
40
+ "Programming Language :: Python :: 3.13",
39
41
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
40
42
  "Topic :: Software Development :: Libraries :: Python Modules",
41
43
  "Topic :: System :: Distributed Computing",
42
44
  ]
43
- requires-python = ">=3.10"
45
+ requires-python = ">=3.9,<3.14"
44
46
  dependencies = [
45
47
  "cloudpickle>=3.0.0",
46
48
  "numpy<2",
47
- "torch>=2.2.0",
49
+ "torch>=2.1,<2.8",
48
50
  ]
49
51
 
50
52
  [project.optional-dependencies]
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
89
91
 
90
92
  [tool.black]
91
93
  line-length = 88
92
- target-version = ['py310', 'py311', 'py312']
94
+ target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
93
95
  include = '\.pyi?$'
94
96
  extend-exclude = '''
95
97
  /(
@@ -115,7 +117,7 @@ use_parentheses = true
115
117
  ensure_newline_before_comments = true
116
118
 
117
119
  [tool.mypy]
118
- python_version = "3.10"
120
+ python_version = "3.9"
119
121
  warn_return_any = true
120
122
  warn_unused_configs = true
121
123
  disallow_untyped_defs = true
@@ -180,7 +182,7 @@ exclude_lines = [
180
182
  ]
181
183
 
182
184
  [tool.ruff]
183
- target-version = "py310"
185
+ target-version = "py39"
184
186
  line-length = 88
185
187
 
186
188
  [tool.ruff.lint]
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.6"
22
+ __version__ = "0.2.7"
@@ -66,6 +66,7 @@ class TaskInfo:
66
66
  class _TaskEntry:
67
67
  info: TaskInfo
68
68
  future: Future[Any]
69
+ local_result: Any = None
69
70
 
70
71
 
71
72
  class TaskManager:
@@ -139,6 +140,8 @@ class TaskManager:
139
140
  )
140
141
  return
141
142
 
143
+ self._store_local_result(task_id, result)
144
+
142
145
  if result is None:
143
146
  self._transition(
144
147
  task_id,
@@ -191,6 +194,13 @@ class TaskManager:
191
194
  for key, value in updates.items():
192
195
  setattr(entry.info, key, value)
193
196
 
197
def _store_local_result(self, task_id: str, value: Any) -> None:
    """Attach the raw result *value* to the entry registered for *task_id*.

    The value is written to the entry's ``local_result`` field while holding
    the manager lock. Nothing happens when no entry exists for *task_id*.
    """
    with self._lock:
        record = self._tasks.get(task_id)
        if record is not None:
            record.local_result = value
203
+
194
204
  def get(self, task_id: str) -> TaskInfo:
195
205
  self._maybe_cleanup()
196
206
  with self._lock:
@@ -207,6 +217,24 @@ class TaskManager:
207
217
  return None
208
218
  return deserialize_payload(encoding, payload_bytes)
209
219
 
220
def result_local(self, task_id: str) -> Any:
    """Return the in-process result object recorded for a completed task.

    Args:
        task_id: Identifier of the task whose ``local_result`` is wanted.

    Returns:
        The value held in the task entry's ``local_result`` field.

    Raises:
        SharedTensorTaskError: If the task is unknown, was cancelled, failed,
            or has not reached the completed state yet.
    """
    self._maybe_cleanup()
    with self._lock:
        entry = self._tasks.get(task_id)
        if entry is None:
            raise SharedTensorTaskError(f"Task '{task_id}' was not found")
        # Only two scalar fields are needed below, so snapshot them directly
        # under the lock instead of deep-copying the whole TaskInfo record.
        status = entry.info.status
        error_message = entry.info.error_message
        value = entry.local_result
    if status == TaskStatus.CANCELLED:
        raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
    if status == TaskStatus.FAILED:
        raise SharedTensorTaskError(error_message or f"Task '{task_id}' failed")
    if status != TaskStatus.COMPLETED:
        raise SharedTensorTaskError(
            f"Task '{task_id}' is not complete; current status is '{status.value}'"
        )
    return value
237
+
210
238
  def wait_result_payload(
211
239
  self,
212
240
  task_id: str,
@@ -242,6 +270,21 @@ class TaskManager:
242
270
  "object_id": info.metadata.get("object_id"),
243
271
  }
244
272
 
273
def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
    """Block until the task's future finishes, then return its local result.

    Args:
        task_id: Identifier of the task to wait for.
        timeout: Maximum number of seconds to wait; ``None`` waits without
            a limit.

    Raises:
        SharedTensorTaskError: If the task is unknown or the future does not
            finish within *timeout*. Errors raised by ``result_local`` are
            propagated unchanged.
    """
    self._maybe_cleanup()
    with self._lock:
        record = self._tasks.get(task_id)
        if record is None:
            raise SharedTensorTaskError(f"Task '{task_id}' was not found")
        pending = record.future
    # Wait outside the lock so other manager operations are not blocked.
    try:
        pending.result(timeout=timeout)
    except FutureTimeoutError as exc:
        message = f"Task '{task_id}' did not complete within {timeout} seconds"
        raise SharedTensorTaskError(message) from exc
    return self.result_local(task_id)
287
+
245
288
  def cancel(self, task_id: str) -> bool:
246
289
  self._maybe_cleanup()
247
290
  with self._lock:
@@ -8,16 +8,25 @@ from dataclasses import dataclass
8
8
  from typing import Any, cast
9
9
 
10
10
  from shared_tensor.errors import (
11
+ SharedTensorCapabilityError,
11
12
  SharedTensorClientError,
13
+ SharedTensorConfigurationError,
14
+ SharedTensorError,
15
+ SharedTensorProviderError,
12
16
  SharedTensorProtocolError,
13
17
  SharedTensorRemoteError,
18
+ SharedTensorSerializationError,
19
+ SharedTensorTaskError,
14
20
  )
15
21
  from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
22
+ from shared_tensor.runtime import get_local_server
16
23
  from shared_tensor.transport import recv_message, send_message
24
+ from shared_tensor.async_task import TaskStatus
17
25
  from shared_tensor.utils import (
18
26
  deserialize_payload,
19
27
  resolve_runtime_socket_path,
20
28
  serialize_call_payloads,
29
+ validate_payload_for_transport,
21
30
  )
22
31
 
23
32
 
@@ -50,6 +59,54 @@ class SharedTensorClient:
50
59
  self.timeout = timeout
51
60
  self.verbose_debug = verbose_debug
52
61
 
62
def _local_server(self) -> Any:
    """Return the server registered in this process for ``self.socket_path``.

    The lookup is delegated to ``get_local_server``. Callers in this class
    treat a ``None`` result as "no in-process server" and fall back to the
    socket transport.
    """
    return get_local_server(self.socket_path)
64
+
65
@staticmethod
def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
    """Wrap an error raised by an in-process server call in a remote error.

    The numeric code is chosen from the class of *exc*; anything that matches
    none of the listed classes gets code 7. The resulting error carries the
    code, ``data=None`` and the class name of *exc* as ``error_type``.
    """
    # Evaluated top to bottom; the first matching class decides the code.
    codes_by_error_type = (
        (SharedTensorProtocolError, 1),
        (SharedTensorProviderError, 2),
        (SharedTensorSerializationError, 3),
        (SharedTensorCapabilityError, 4),
        (SharedTensorTaskError, 5),
        (SharedTensorConfigurationError, 6),
    )
    code = next(
        (
            candidate_code
            for error_type, candidate_code in codes_by_error_type
            if isinstance(exc, error_type)
        ),
        7,
    )
    return SharedTensorRemoteError(
        f"Remote error [{code}]: {exc}",
        code=code,
        data=None,
        error_type=type(exc).__name__,
    )
87
+
88
def _run_local(self, operation: Any) -> Any:
    """Run a zero-argument callable against the in-process server.

    Args:
        operation: Callable invoked with no arguments.

    Returns:
        Whatever *operation* returns.

    Raises:
        SharedTensorRemoteError: When *operation* raises a
            ``SharedTensorError``; the original error is chained as the
            cause. Other exception types propagate unchanged.
    """
    try:
        return operation()
    except SharedTensorError as exc:
        raise self._remote_error_from_local(exc) from exc
93
+
94
def _decode_local_result(self, result: Any) -> Any:
    """Turn an in-process endpoint result into the value handed to the caller.

    Returns ``None`` when *result* or its ``value`` is ``None``. Otherwise the
    value is checked with ``validate_payload_for_transport`` and returned
    as-is, or wrapped in a ``SharedObjectHandle`` when the result carries an
    ``object_id``.
    """
    value = None if result is None else result.value
    if value is None:
        return None
    validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
    object_id = result.object_id
    if object_id is None:
        return value
    handle_id = cast(str, object_id)
    return SharedObjectHandle(
        object_id=handle_id,
        value=value,
        _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
    )
109
+
53
110
  def _send_request(self, request: dict[str, Any]) -> Any:
54
111
  method = request.get("method", "<unknown>")
55
112
  if self.verbose_debug:
@@ -104,6 +161,13 @@ class SharedTensorClient:
104
161
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
105
162
  if self.verbose_debug:
106
163
  logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
164
+ local_server = self._local_server()
165
+ if local_server is not None:
166
+ return self._run_local(
167
+ lambda: self._decode_local_result(
168
+ local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
169
+ )
170
+ )
107
171
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
108
172
  result = self._request(
109
173
  "call",
@@ -119,6 +183,19 @@ class SharedTensorClient:
119
183
  def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
120
184
  if self.verbose_debug:
121
185
  logger.debug("Client submitting task", extra={"endpoint": endpoint})
186
+ local_server = self._local_server()
187
+ if local_server is not None:
188
+ return self._run_local(
189
+ lambda: cast(
190
+ str,
191
+ local_server._submit_endpoint_task(
192
+ endpoint,
193
+ local_server.provider.get_endpoint(endpoint),
194
+ tuple(args),
195
+ dict(kwargs),
196
+ ).task_id,
197
+ )
198
+ )
122
199
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
123
200
  result = self._request(
124
201
  "submit",
@@ -134,18 +211,43 @@ class SharedTensorClient:
134
211
  def release(self, object_id: str) -> bool:
135
212
  if self.verbose_debug:
136
213
  logger.debug("Client releasing managed object", extra={"object_id": object_id})
214
+ local_server = self._local_server()
215
+ if local_server is not None:
216
+ return self._run_local(
217
+ lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
218
+ )
137
219
  result = self._request("release_object", {"object_id": object_id})
138
220
  return bool(result["released"])
139
221
 
140
222
  def release_many(self, object_ids: list[str]) -> dict[str, bool]:
223
+ local_server = self._local_server()
224
+ if local_server is not None:
225
+ return self._run_local(
226
+ lambda: {
227
+ object_id: bool(released)
228
+ for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
229
+ "released"
230
+ ].items()
231
+ }
232
+ )
141
233
  result = self._request("release_objects", {"object_ids": object_ids})
142
234
  return {object_id: bool(released) for object_id, released in result["released"].items()}
143
235
 
144
236
  def get_object_info(self, object_id: str) -> dict[str, Any] | None:
237
+ local_server = self._local_server()
238
+ if local_server is not None:
239
+ return self._run_local(
240
+ lambda: cast(
241
+ dict[str, Any] | None,
242
+ local_server._handle_get_object_info({"object_id": object_id}).get("object"),
243
+ )
244
+ )
145
245
  result = self._request("get_object_info", {"object_id": object_id})
146
246
  return cast(dict[str, Any] | None, result.get("object"))
147
247
 
148
248
  def ping(self) -> bool:
249
+ if self._local_server() is not None:
250
+ return True
149
251
  try:
150
252
  self._request("ping")
151
253
  except (SharedTensorClientError, SharedTensorRemoteError):
@@ -153,29 +255,66 @@ class SharedTensorClient:
153
255
  return True
154
256
 
155
257
  def get_server_info(self) -> dict[str, Any]:
258
+ local_server = self._local_server()
259
+ if local_server is not None:
260
+ return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
156
261
  return cast(dict[str, Any], self._request("get_server_info"))
157
262
 
158
263
  def list_endpoints(self) -> dict[str, Any]:
264
+ local_server = self._local_server()
265
+ if local_server is not None:
266
+ return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
159
267
  return cast(dict[str, Any], self._request("list_endpoints"))
160
268
 
161
269
  def get_task_status(self, task_id: str) -> dict[str, Any]:
270
+ local_server = self._local_server()
271
+ if local_server is not None:
272
+ return self._run_local(
273
+ lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
274
+ )
162
275
  return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
163
276
 
164
277
  def get_task_result(self, task_id: str) -> Any:
278
+ local_server = self._local_server()
279
+ if local_server is not None:
280
+ return self._run_local(
281
+ lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
282
+ )
165
283
  return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
166
284
 
167
285
  def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
168
286
  if self.verbose_debug:
169
287
  logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
288
+ local_server = self._local_server()
289
+ if local_server is not None:
290
+ return self._run_local(
291
+ lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
292
+ )
170
293
  params = {"task_id": task_id}
171
294
  if timeout is not None:
172
295
  params["timeout"] = timeout
173
296
  return cast(dict[str, Any], self._request("wait_task", params))
174
297
 
175
298
  def cancel_task(self, task_id: str) -> bool:
299
+ local_server = self._local_server()
300
+ if local_server is not None:
301
+ return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
176
302
  return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
177
303
 
178
304
  def list_tasks(self, status: str | None = None) -> dict[str, Any]:
305
+ local_server = self._local_server()
306
+ if local_server is not None:
307
+ return self._run_local(
308
+ lambda: cast(
309
+ dict[str, Any],
310
+ {
311
+ listed_task_id: info.to_dict()
312
+ for listed_task_id, info in local_server._task_manager_instance()
313
+ .list(status=None if status is None else TaskStatus(status))
314
+ .items()
315
+ },
316
+ )
317
+ )
179
318
  params = {"status": status} if status else None
180
319
  return cast(dict[str, Any], self._request("list_tasks", params))
181
320
 
@@ -0,0 +1,30 @@
1
"""In-process runtime registry for thread-backed local servers."""

from __future__ import annotations

from threading import RLock
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from shared_tensor.server import SharedTensorServer


# Guards every read and write of _SERVERS.
_LOCK = RLock()
# Socket path -> server instance registered under that path in this process.
_SERVERS: dict[str, SharedTensorServer] = {}


def register_local_server(socket_path: str, server: SharedTensorServer) -> None:
    """Record *server* under *socket_path*, replacing any existing entry."""
    with _LOCK:
        _SERVERS[socket_path] = server


def unregister_local_server(socket_path: str, server: SharedTensorServer) -> None:
    """Remove the entry for *socket_path*, but only if it is exactly *server*.

    The identity check means a server cannot evict a different instance that
    is currently registered under the same path.
    """
    with _LOCK:
        if _SERVERS.get(socket_path) is server:
            _SERVERS.pop(socket_path, None)


def get_local_server(socket_path: str) -> SharedTensorServer | None:
    """Return the server registered under *socket_path*, or ``None``."""
    with _LOCK:
        return _SERVERS.get(socket_path)
@@ -22,6 +22,7 @@ from shared_tensor.errors import (
22
22
  )
23
23
  from shared_tensor.managed_object import ManagedObjectRegistry
24
24
  from shared_tensor.provider import EndpointDefinition, SharedTensorProvider
25
+ from shared_tensor.runtime import register_local_server, unregister_local_server
25
26
  from shared_tensor.transport import recv_message, send_message
26
27
  from shared_tensor.utils import (
27
28
  CONTROL_ENCODING,
@@ -59,6 +60,12 @@ class _ServerThreadState:
59
60
  error: BaseException | None = None
60
61
 
61
62
 
63
# NOTE: ``slots=True`` is deliberately not used here. That dataclass option
# only exists on Python 3.10+, and this release declares support for 3.9.
@dataclass
class _EndpointResult:
    """Un-encoded outcome of an endpoint invocation.

    Carries the endpoint's return value together with the managed-object id,
    so callers can either encode it for transport or use it in-process.
    """

    # Object produced by the endpoint function, or taken from a cache.
    value: Any
    # Id assigned by the managed-object registry; None when the result is
    # not tracked as a managed object.
    object_id: str | None = None
67
+
68
+
62
69
  class SharedTensorServer:
63
70
  def __init__(
64
71
  self,
@@ -197,22 +204,22 @@ class SharedTensorServer:
197
204
  ) -> Any:
198
205
  return self._task_manager_instance().submit(
199
206
  endpoint,
200
- self._execute_endpoint_call,
207
+ self._execute_endpoint_result,
201
208
  (endpoint, definition, args, kwargs),
202
209
  {},
203
- result_encoder=lambda payload: payload,
210
+ result_encoder=self._encode_endpoint_result,
204
211
  )
205
212
 
206
- def _execute_endpoint_call(
213
+ def _execute_endpoint_result(
207
214
  self,
208
215
  endpoint: str,
209
216
  definition: EndpointDefinition,
210
217
  args: tuple[Any, ...],
211
218
  kwargs: dict[str, Any],
212
- ) -> dict[str, Any]:
219
+ ) -> _EndpointResult:
213
220
  cache_key = self._cache_key(endpoint, definition, args, kwargs)
214
221
  if cache_key is not None:
215
- cached = self._lookup_cached_result(definition, cache_key)
222
+ cached = self._lookup_cached_result_value(definition, cache_key)
216
223
  if cached is not None:
217
224
  if self.verbose_debug:
218
225
  logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
@@ -224,20 +231,15 @@ class SharedTensorServer:
224
231
  if self.verbose_debug and owner:
225
232
  logger.debug("Server created singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
226
233
  if not owner:
227
- if self.verbose_debug:
228
- logger.debug("Server joined singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
229
- if definition.managed:
230
- payload = future.result()
231
- object_id = payload.get("object_id")
232
- if object_id is not None:
233
- self._managed_objects.add_ref(object_id)
234
- return payload
235
- return future.result()
234
+ result = future.result()
235
+ if definition.managed and result.object_id is not None:
236
+ self._managed_objects.add_ref(result.object_id)
237
+ return result
236
238
  else:
237
239
  future = None
238
240
 
239
241
  try:
240
- encoded = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
242
+ result = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
241
243
  except Exception as exc:
242
244
  if future is not None:
243
245
  future.set_exception(exc)
@@ -245,9 +247,20 @@ class SharedTensorServer:
245
247
  raise
246
248
 
247
249
  if future is not None:
248
- future.set_result(encoded)
250
+ future.set_result(result)
249
251
  self._release_inflight(inflight_key, future)
250
- return encoded
252
+ return result
253
+
254
def _execute_endpoint_call(
    self,
    endpoint: str,
    definition: EndpointDefinition,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Execute an endpoint and return its result encoded for transport.

    Delegates the execution to ``_execute_endpoint_result`` and the encoding
    to ``_encode_endpoint_result``.
    """
    raw_result = self._execute_endpoint_result(endpoint, definition, args, kwargs)
    return self._encode_endpoint_result(raw_result)
251
264
 
252
265
  def _run_endpoint_under_policy(
253
266
  self,
@@ -256,11 +269,11 @@ class SharedTensorServer:
256
269
  args: tuple[Any, ...],
257
270
  kwargs: dict[str, Any],
258
271
  cache_key: str | None,
259
- ) -> dict[str, Any]:
272
+ ) -> _EndpointResult:
260
273
  if definition.concurrency == "serialized":
261
274
  lock = self._endpoint_lock(endpoint)
262
275
  with lock:
263
- cached = self._lookup_cached_result(definition, cache_key)
276
+ cached = self._lookup_cached_result_value(definition, cache_key)
264
277
  if cached is not None:
265
278
  return cached
266
279
  return self._materialize_endpoint_result(endpoint, definition, args, kwargs, cache_key)
@@ -273,17 +286,15 @@ class SharedTensorServer:
273
286
  args: tuple[Any, ...],
274
287
  kwargs: dict[str, Any],
275
288
  cache_key: str | None,
276
- ) -> dict[str, Any]:
289
+ ) -> _EndpointResult:
277
290
  if definition.managed:
278
291
  return self._materialize_managed_result(endpoint, definition, args, kwargs, cache_key)
279
292
  value = definition.func(*args, **kwargs)
280
293
  if self.verbose_debug:
281
294
  logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
282
- result = self._encode_result(value)
283
295
  if cache_key is not None:
284
- self._cache[cache_key] = result
285
296
  self._local_cache[cache_key] = value
286
- return result
297
+ return _EndpointResult(value=value)
287
298
 
288
299
  def _materialize_managed_result(
289
300
  self,
@@ -292,24 +303,24 @@ class SharedTensorServer:
292
303
  args: tuple[Any, ...],
293
304
  kwargs: dict[str, Any],
294
305
  cache_key: str | None,
295
- ) -> dict[str, Any]:
306
+ ) -> _EndpointResult:
296
307
  if cache_key is not None:
297
308
  cached = self._managed_objects.get_cached(cache_key)
298
309
  if cached is not None:
299
310
  self._managed_objects.add_ref(cached.object_id)
300
- return self._encode_result(cached.value, object_id=cached.object_id)
311
+ return _EndpointResult(value=cached.value, object_id=cached.object_id)
301
312
 
302
313
  result = definition.func(*args, **kwargs)
303
314
  if self.verbose_debug:
304
315
  logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
305
316
  entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
306
- return self._encode_result(entry.value, object_id=entry.object_id)
317
+ return _EndpointResult(value=entry.value, object_id=entry.object_id)
307
318
 
308
- def _lookup_cached_result(
319
+ def _lookup_cached_result_value(
309
320
  self,
310
321
  definition: EndpointDefinition,
311
322
  cache_key: str | None,
312
- ) -> dict[str, Any] | None:
323
+ ) -> _EndpointResult | None:
313
324
  if cache_key is None:
314
325
  return None
315
326
  if definition.managed:
@@ -317,16 +328,52 @@ class SharedTensorServer:
317
328
  if cached is None:
318
329
  return None
319
330
  self._managed_objects.add_ref(cached.object_id)
320
- return self._encode_result(cached.value, object_id=cached.object_id)
321
- cached = self._cache.get(cache_key)
322
- if cached is not None:
323
- return cached
331
+ return _EndpointResult(value=cached.value, object_id=cached.object_id)
324
332
  local_value = self._local_cache.get(cache_key)
325
333
  if local_value is None:
326
334
  return None
327
- encoded = self._encode_result(local_value)
328
- self._cache[cache_key] = encoded
329
- return encoded
335
+ return _EndpointResult(value=local_value)
336
+
337
def call_local_client(
    self,
    endpoint: str,
    *,
    args: tuple[Any, ...] = (),
    kwargs: dict[str, Any] | None = None,
) -> _EndpointResult | None:
    """Invoke *endpoint* for a client living in the same process.

    Endpoints whose definition has ``execution == "task"`` are submitted as a
    task and waited on; all others are executed directly. In both cases the
    un-encoded result is returned.
    """
    definition = self.provider.get_endpoint(endpoint)
    call_kwargs = kwargs or {}
    if definition.execution != "task":
        return self._execute_endpoint_result(endpoint, definition, args, call_kwargs)
    submitted = self._submit_endpoint_task(endpoint, definition, args, call_kwargs)
    return self.wait_task_result_local(submitted.task_id)
350
+
351
def get_task_result_local(self, task_id: str) -> _EndpointResult | None:
    """Return the un-encoded result stored for a finished task.

    Thin wrapper around the task manager's ``result_local``; whatever that
    call returns (including ``None``) is passed through, and any error it
    raises propagates unchanged.
    """
    # The typed local pins the declared return type, since the manager call
    # is typed as returning ``Any``.
    result: _EndpointResult | None = self._task_manager_instance().result_local(task_id)
    return result
356
+
357
def wait_task_result_local(self, task_id: str, timeout: float | None = None) -> _EndpointResult | None:
    """Wait for a task and return its un-encoded result.

    Thin wrapper around the task manager's ``wait_result_local``; *timeout*
    is forwarded as given, the return value (including ``None``) is passed
    through, and any error raised propagates unchanged.
    """
    # The typed local pins the declared return type, since the manager call
    # is typed as returning ``Any``.
    result: _EndpointResult | None = self._task_manager_instance().wait_result_local(
        task_id, timeout=timeout
    )
    return result
362
+
363
def wait_task_local(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
    """Wait for a task and return its info as a dict.

    If waiting raises ``SharedTensorTaskError`` while the task is still
    pending or running, the current task info is returned instead of raising.
    For any other task status the error is re-raised.
    """
    try:
        self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
    except SharedTensorTaskError:
        snapshot = self._task_manager_instance().get(task_id)
        still_active = snapshot.status in {TaskStatus.PENDING, TaskStatus.RUNNING}
        if not still_active:
            raise
        return snapshot.to_dict()
    else:
        return self._task_manager_instance().get(task_id).to_dict()
372
+
373
def encode_local_result(self, result: _EndpointResult | None) -> dict[str, Any]:
    """Encode an in-process endpoint result into the transport dict shape.

    A ``None`` result becomes a dict whose ``encoding``, ``payload_bytes`` and
    ``object_id`` entries are all ``None``; anything else is handed to
    ``_encode_endpoint_result``.
    """
    if result is not None:
        return self._encode_endpoint_result(result)
    empty_payload: dict[str, Any] = {
        "encoding": None,
        "payload_bytes": None,
        "object_id": None,
    }
    return empty_payload
330
377
 
331
378
  def invoke_local(
332
379
  self,
@@ -455,6 +502,9 @@ class SharedTensorServer:
455
502
  encoding, payload = serialize_payload(value)
456
503
  return {"encoding": encoding, "payload_bytes": payload, "object_id": object_id}
457
504
 
505
def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
    """Encode *result* by passing its value and object id to ``_encode_result``."""
    raw_value, managed_id = result.value, result.object_id
    return self._encode_result(raw_value, object_id=managed_id)
507
+
458
508
  def _task_manager_instance(self) -> TaskManager:
459
509
  if self._task_manager is None:
460
510
  self._task_manager = TaskManager(
@@ -560,6 +610,7 @@ class SharedTensorServer:
560
610
  self.listener = listener
561
611
  self.running = True
562
612
  self.started_at = time.time()
613
+ register_local_server(self.socket_path, self)
563
614
  if started_event is not None:
564
615
  started_event.set()
565
616
  while self.running:
@@ -632,6 +683,7 @@ class SharedTensorServer:
632
683
  self._local_cache.clear()
633
684
  self._inflight.clear()
634
685
  self._endpoint_locks.clear()
686
+ unregister_local_server(self.socket_path, self)
635
687
  unlink_socket_path(self.socket_path)
636
688
 
637
689
  def __enter__(self) -> SharedTensorServer:
@@ -10,6 +10,7 @@ shared_tensor/client.py
10
10
  shared_tensor/errors.py
11
11
  shared_tensor/managed_object.py
12
12
  shared_tensor/provider.py
13
+ shared_tensor/runtime.py
13
14
  shared_tensor/server.py
14
15
  shared_tensor/transport.py
15
16
  shared_tensor/utils.py
File without changes
File without changes
File without changes
File without changes