shared-tensor 0.2.6__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
7
7
  License-Expression: Apache-2.0
8
8
  Project-URL: Homepage, https://github.com/world-sim-dev/shared-tensor
9
9
  Project-URL: Repository, https://github.com/world-sim-dev/shared-tensor
10
- Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/wiki
10
+ Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/tree/main/docs
11
11
  Project-URL: Bug Reports, https://github.com/world-sim-dev/shared-tensor/issues
12
12
  Project-URL: Changelog, https://github.com/world-sim-dev/shared-tensor/releases
13
13
  Keywords: gpu,memory,sharing,ipc,inter-process-communication,pytorch,cuda,model-serving,inference,torch,torch-ipc
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Operating System :: POSIX :: Linux
18
18
  Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.9
19
20
  Classifier: Programming Language :: Python :: 3.10
20
21
  Classifier: Programming Language :: Python :: 3.11
21
22
  Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Programming Language :: Python :: 3.13
22
24
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
25
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
26
  Classifier: Topic :: System :: Distributed Computing
25
- Requires-Python: >=3.10
27
+ Requires-Python: <3.14,>=3.9
26
28
  Description-Content-Type: text/markdown
27
29
  License-File: LICENSE
28
30
  Requires-Dist: cloudpickle>=3.0.0
29
31
  Requires-Dist: numpy<2
30
- Requires-Dist: torch>=2.2.0
32
+ Requires-Dist: torch<2.8,>=2.1
31
33
  Provides-Extra: dev
32
34
  Requires-Dist: pytest>=6.0; extra == "dev"
33
35
  Requires-Dist: pytest-cov>=2.0; extra == "dev"
@@ -75,7 +77,7 @@ Not supported:
75
77
 
76
78
  ## Install
77
79
 
78
- Use Python `3.10+` and a CUDA-enabled PyTorch build.
80
+ Use Python `3.9+` and a CUDA-enabled PyTorch build.
79
81
 
80
82
  ```bash
81
83
  pip install shared-tensor
@@ -89,6 +91,16 @@ conda activate shared-tensor-dev
89
91
  pip install -e ".[dev,test]"
90
92
  ```
91
93
 
94
+ ## Docs
95
+
96
+ Read the examples first, then the design notes:
97
+
98
+ - `docs/overview.md`
99
+ - `docs/patterns.md`
100
+ - `docs/architecture.md`
101
+ - `docs/lifecycle.md`
102
+ - `docs/diagrams.md`
103
+
92
104
  ## Example: Manual Two-Process Deployment
93
105
 
94
106
  Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
@@ -25,7 +25,7 @@ Not supported:
25
25
 
26
26
  ## Install
27
27
 
28
- Use Python `3.10+` and a CUDA-enabled PyTorch build.
28
+ Use Python `3.9+` and a CUDA-enabled PyTorch build.
29
29
 
30
30
  ```bash
31
31
  pip install shared-tensor
@@ -39,6 +39,16 @@ conda activate shared-tensor-dev
39
39
  pip install -e ".[dev,test]"
40
40
  ```
41
41
 
42
+ ## Docs
43
+
44
+ Read the examples first, then the design notes:
45
+
46
+ - `docs/overview.md`
47
+ - `docs/patterns.md`
48
+ - `docs/architecture.md`
49
+ - `docs/lifecycle.md`
50
+ - `docs/diagrams.md`
51
+
42
52
  ## Example: Manual Two-Process Deployment
43
53
 
44
54
  Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.6"
7
+ version = "0.2.8"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -33,18 +33,20 @@ classifiers = [
33
33
  "Intended Audience :: Science/Research",
34
34
  "Operating System :: POSIX :: Linux",
35
35
  "Programming Language :: Python :: 3",
36
+ "Programming Language :: Python :: 3.9",
36
37
  "Programming Language :: Python :: 3.10",
37
38
  "Programming Language :: Python :: 3.11",
38
39
  "Programming Language :: Python :: 3.12",
40
+ "Programming Language :: Python :: 3.13",
39
41
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
40
42
  "Topic :: Software Development :: Libraries :: Python Modules",
41
43
  "Topic :: System :: Distributed Computing",
42
44
  ]
43
- requires-python = ">=3.10"
45
+ requires-python = ">=3.9,<3.14"
44
46
  dependencies = [
45
47
  "cloudpickle>=3.0.0",
46
48
  "numpy<2",
47
- "torch>=2.2.0",
49
+ "torch>=2.1,<2.8",
48
50
  ]
49
51
 
50
52
  [project.optional-dependencies]
@@ -73,7 +75,7 @@ docs = [
73
75
  [project.urls]
74
76
  Homepage = "https://github.com/world-sim-dev/shared-tensor"
75
77
  Repository = "https://github.com/world-sim-dev/shared-tensor"
76
- Documentation = "https://github.com/world-sim-dev/shared-tensor/wiki"
78
+ Documentation = "https://github.com/world-sim-dev/shared-tensor/tree/main/docs"
77
79
  "Bug Reports" = "https://github.com/world-sim-dev/shared-tensor/issues"
78
80
  Changelog = "https://github.com/world-sim-dev/shared-tensor/releases"
79
81
 
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
89
91
 
90
92
  [tool.black]
91
93
  line-length = 88
92
- target-version = ['py310', 'py311', 'py312']
94
+ target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
93
95
  include = '\.pyi?$'
94
96
  extend-exclude = '''
95
97
  /(
@@ -115,7 +117,7 @@ use_parentheses = true
115
117
  ensure_newline_before_comments = true
116
118
 
117
119
  [tool.mypy]
118
- python_version = "3.10"
120
+ python_version = "3.9"
119
121
  warn_return_any = true
120
122
  warn_unused_configs = true
121
123
  disallow_untyped_defs = true
@@ -180,7 +182,7 @@ exclude_lines = [
180
182
  ]
181
183
 
182
184
  [tool.ruff]
183
- target-version = "py310"
185
+ target-version = "py39"
184
186
  line-length = 88
185
187
 
186
188
  [tool.ruff.lint]
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.6"
22
+ __version__ = "0.2.8"
@@ -33,8 +33,6 @@ class TaskInfo:
33
33
  created_at: float
34
34
  started_at: float | None = None
35
35
  completed_at: float | None = None
36
- result_encoding: str | None = None
37
- result_payload: bytes | None = None
38
36
  error_type: str | None = None
39
37
  error_message: str | None = None
40
38
  metadata: dict[str, Any] = field(default_factory=dict)
@@ -47,8 +45,6 @@ class TaskInfo:
47
45
  "created_at": self.created_at,
48
46
  "started_at": self.started_at,
49
47
  "completed_at": self.completed_at,
50
- "result_encoding": self.result_encoding,
51
- "result_payload": self.result_payload,
52
48
  "error_type": self.error_type,
53
49
  "error_message": self.error_message,
54
50
  "metadata": dict(self.metadata),
@@ -66,6 +62,9 @@ class TaskInfo:
66
62
  class _TaskEntry:
67
63
  info: TaskInfo
68
64
  future: Future[Any]
65
+ result_encoding: str | None = None
66
+ result_payload: bytes | None = None
67
+ local_result: Any = None
69
68
 
70
69
 
71
70
  class TaskManager:
@@ -86,6 +85,7 @@ class TaskManager:
86
85
  self._last_cleanup = 0.0
87
86
  self._lock = RLock()
88
87
  self._tasks: dict[str, _TaskEntry] = {}
88
+ self._accepting_submissions = True
89
89
 
90
90
  def submit(
91
91
  self,
@@ -97,6 +97,8 @@ class TaskManager:
97
97
  ) -> TaskInfo:
98
98
  self._maybe_cleanup()
99
99
  with self._lock:
100
+ if not self._accepting_submissions:
101
+ raise SharedTensorTaskError("Task manager is shutting down and is not accepting new tasks")
100
102
  self._drop_oldest_finished_tasks_if_needed()
101
103
  if len(self._tasks) >= self._max_tasks:
102
104
  raise SharedTensorTaskError("Task capacity exceeded")
@@ -139,13 +141,14 @@ class TaskManager:
139
141
  )
140
142
  return
141
143
 
144
+ self._store_local_result(task_id, result)
145
+
142
146
  if result is None:
147
+ self._store_payload(task_id, encoding=None, payload=None, object_id=None)
143
148
  self._transition(
144
149
  task_id,
145
150
  status=TaskStatus.COMPLETED,
146
151
  completed_at=time.time(),
147
- result_encoding=None,
148
- result_payload=None,
149
152
  )
150
153
  return
151
154
 
@@ -165,13 +168,16 @@ class TaskManager:
165
168
  )
166
169
  return
167
170
 
171
+ self._store_payload(
172
+ task_id,
173
+ encoding=payload["encoding"],
174
+ payload=payload["payload_bytes"],
175
+ object_id=payload.get("object_id"),
176
+ )
168
177
  self._transition(
169
178
  task_id,
170
179
  status=TaskStatus.COMPLETED,
171
180
  completed_at=time.time(),
172
- result_encoding=payload["encoding"],
173
- result_payload=payload["payload_bytes"],
174
- metadata={"object_id": payload.get("object_id")},
175
181
  )
176
182
 
177
183
  @staticmethod
@@ -191,6 +197,31 @@ class TaskManager:
191
197
  for key, value in updates.items():
192
198
  setattr(entry.info, key, value)
193
199
 
200
+ def _store_local_result(self, task_id: str, value: Any) -> None:
201
+ with self._lock:
202
+ entry = self._tasks.get(task_id)
203
+ if entry is None:
204
+ return
205
+ entry.local_result = value
206
+
207
+ def _store_payload(
208
+ self,
209
+ task_id: str,
210
+ *,
211
+ encoding: str | None,
212
+ payload: bytes | None,
213
+ object_id: str | None,
214
+ ) -> None:
215
+ with self._lock:
216
+ entry = self._tasks.get(task_id)
217
+ if entry is None:
218
+ return
219
+ entry.result_encoding = encoding
220
+ entry.result_payload = payload
221
+ metadata = dict(entry.info.metadata)
222
+ metadata["object_id"] = object_id
223
+ entry.info.metadata = metadata
224
+
194
225
  def get(self, task_id: str) -> TaskInfo:
195
226
  self._maybe_cleanup()
196
227
  with self._lock:
@@ -207,6 +238,24 @@ class TaskManager:
207
238
  return None
208
239
  return deserialize_payload(encoding, payload_bytes)
209
240
 
241
+ def result_local(self, task_id: str) -> Any:
242
+ self._maybe_cleanup()
243
+ with self._lock:
244
+ entry = self._tasks.get(task_id)
245
+ if entry is None:
246
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
247
+ info = copy.deepcopy(entry.info)
248
+ value = entry.local_result
249
+ if info.status == TaskStatus.CANCELLED:
250
+ raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
251
+ if info.status == TaskStatus.FAILED:
252
+ raise SharedTensorTaskError(info.error_message or f"Task '{task_id}' failed")
253
+ if info.status != TaskStatus.COMPLETED:
254
+ raise SharedTensorTaskError(
255
+ f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
256
+ )
257
+ return value
258
+
210
259
  def wait_result_payload(
211
260
  self,
212
261
  task_id: str,
@@ -227,7 +276,14 @@ class TaskManager:
227
276
  return self.result_payload(task_id)
228
277
 
229
278
  def result_payload(self, task_id: str) -> dict[str, str | bytes | None]:
230
- info = self.get(task_id)
279
+ self._maybe_cleanup()
280
+ with self._lock:
281
+ entry = self._tasks.get(task_id)
282
+ if entry is None:
283
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
284
+ info = copy.deepcopy(entry.info)
285
+ encoding = entry.result_encoding
286
+ payload = entry.result_payload
231
287
  if info.status == TaskStatus.CANCELLED:
232
288
  raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
233
289
  if info.status == TaskStatus.FAILED:
@@ -237,11 +293,26 @@ class TaskManager:
237
293
  f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
238
294
  )
239
295
  return {
240
- "encoding": info.result_encoding,
241
- "payload_bytes": info.result_payload,
296
+ "encoding": encoding,
297
+ "payload_bytes": payload,
242
298
  "object_id": info.metadata.get("object_id"),
243
299
  }
244
300
 
301
+ def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
302
+ self._maybe_cleanup()
303
+ with self._lock:
304
+ entry = self._tasks.get(task_id)
305
+ if entry is None:
306
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
307
+ future = entry.future
308
+ try:
309
+ future.result(timeout=timeout)
310
+ except FutureTimeoutError as exc:
311
+ raise SharedTensorTaskError(
312
+ f"Task '{task_id}' did not complete within {timeout} seconds"
313
+ ) from exc
314
+ return self.result_local(task_id)
315
+
245
316
  def cancel(self, task_id: str) -> bool:
246
317
  self._maybe_cleanup()
247
318
  with self._lock:
@@ -266,8 +337,10 @@ class TaskManager:
266
337
  }
267
338
  return items
268
339
 
269
- def shutdown(self, *, wait: bool = True) -> None:
270
- self._executor.shutdown(wait=wait, cancel_futures=True)
340
+ def shutdown(self, *, wait: bool = True, cancel_futures: bool = True) -> None:
341
+ with self._lock:
342
+ self._accepting_submissions = False
343
+ self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
271
344
 
272
345
  def _maybe_cleanup(self) -> None:
273
346
  now = time.time()
@@ -8,16 +8,25 @@ from dataclasses import dataclass
8
8
  from typing import Any, cast
9
9
 
10
10
  from shared_tensor.errors import (
11
+ SharedTensorCapabilityError,
11
12
  SharedTensorClientError,
13
+ SharedTensorConfigurationError,
14
+ SharedTensorError,
15
+ SharedTensorProviderError,
12
16
  SharedTensorProtocolError,
13
17
  SharedTensorRemoteError,
18
+ SharedTensorSerializationError,
19
+ SharedTensorTaskError,
14
20
  )
15
21
  from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
22
+ from shared_tensor.runtime import get_local_server
16
23
  from shared_tensor.transport import recv_message, send_message
24
+ from shared_tensor.async_task import TaskStatus
17
25
  from shared_tensor.utils import (
18
26
  deserialize_payload,
19
27
  resolve_runtime_socket_path,
20
28
  serialize_call_payloads,
29
+ validate_payload_for_transport,
21
30
  )
22
31
 
23
32
 
@@ -50,6 +59,54 @@ class SharedTensorClient:
50
59
  self.timeout = timeout
51
60
  self.verbose_debug = verbose_debug
52
61
 
62
+ def _local_server(self):
63
+ return get_local_server(self.socket_path)
64
+
65
+ @staticmethod
66
+ def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
67
+ if isinstance(exc, SharedTensorProtocolError):
68
+ code = 1
69
+ elif isinstance(exc, SharedTensorProviderError):
70
+ code = 2
71
+ elif isinstance(exc, SharedTensorSerializationError):
72
+ code = 3
73
+ elif isinstance(exc, SharedTensorCapabilityError):
74
+ code = 4
75
+ elif isinstance(exc, SharedTensorTaskError):
76
+ code = 5
77
+ elif isinstance(exc, SharedTensorConfigurationError):
78
+ code = 6
79
+ else:
80
+ code = 7
81
+ return SharedTensorRemoteError(
82
+ f"Remote error [{code}]: {exc}",
83
+ code=code,
84
+ data=None,
85
+ error_type=type(exc).__name__,
86
+ )
87
+
88
+ def _run_local(self, operation):
89
+ try:
90
+ return operation()
91
+ except SharedTensorError as exc:
92
+ raise self._remote_error_from_local(exc) from exc
93
+
94
+ def _decode_local_result(self, result: Any) -> Any:
95
+ if result is None:
96
+ return None
97
+ value = result.value
98
+ if value is None:
99
+ return None
100
+ validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
101
+ object_id = result.object_id
102
+ if object_id is None:
103
+ return value
104
+ return SharedObjectHandle(
105
+ object_id=cast(str, object_id),
106
+ value=value,
107
+ _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
108
+ )
109
+
53
110
  def _send_request(self, request: dict[str, Any]) -> Any:
54
111
  method = request.get("method", "<unknown>")
55
112
  if self.verbose_debug:
@@ -104,6 +161,13 @@ class SharedTensorClient:
104
161
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
105
162
  if self.verbose_debug:
106
163
  logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
164
+ local_server = self._local_server()
165
+ if local_server is not None:
166
+ return self._run_local(
167
+ lambda: self._decode_local_result(
168
+ local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
169
+ )
170
+ )
107
171
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
108
172
  result = self._request(
109
173
  "call",
@@ -119,6 +183,19 @@ class SharedTensorClient:
119
183
  def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
120
184
  if self.verbose_debug:
121
185
  logger.debug("Client submitting task", extra={"endpoint": endpoint})
186
+ local_server = self._local_server()
187
+ if local_server is not None:
188
+ return self._run_local(
189
+ lambda: cast(
190
+ str,
191
+ local_server._submit_endpoint_task(
192
+ endpoint,
193
+ local_server.provider.get_endpoint(endpoint),
194
+ tuple(args),
195
+ dict(kwargs),
196
+ ).task_id,
197
+ )
198
+ )
122
199
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
123
200
  result = self._request(
124
201
  "submit",
@@ -134,18 +211,43 @@ class SharedTensorClient:
134
211
  def release(self, object_id: str) -> bool:
135
212
  if self.verbose_debug:
136
213
  logger.debug("Client releasing managed object", extra={"object_id": object_id})
214
+ local_server = self._local_server()
215
+ if local_server is not None:
216
+ return self._run_local(
217
+ lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
218
+ )
137
219
  result = self._request("release_object", {"object_id": object_id})
138
220
  return bool(result["released"])
139
221
 
140
222
  def release_many(self, object_ids: list[str]) -> dict[str, bool]:
223
+ local_server = self._local_server()
224
+ if local_server is not None:
225
+ return self._run_local(
226
+ lambda: {
227
+ object_id: bool(released)
228
+ for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
229
+ "released"
230
+ ].items()
231
+ }
232
+ )
141
233
  result = self._request("release_objects", {"object_ids": object_ids})
142
234
  return {object_id: bool(released) for object_id, released in result["released"].items()}
143
235
 
144
236
  def get_object_info(self, object_id: str) -> dict[str, Any] | None:
237
+ local_server = self._local_server()
238
+ if local_server is not None:
239
+ return self._run_local(
240
+ lambda: cast(
241
+ dict[str, Any] | None,
242
+ local_server._handle_get_object_info({"object_id": object_id}).get("object"),
243
+ )
244
+ )
145
245
  result = self._request("get_object_info", {"object_id": object_id})
146
246
  return cast(dict[str, Any] | None, result.get("object"))
147
247
 
148
248
  def ping(self) -> bool:
249
+ if self._local_server() is not None:
250
+ return True
149
251
  try:
150
252
  self._request("ping")
151
253
  except (SharedTensorClientError, SharedTensorRemoteError):
@@ -153,29 +255,66 @@ class SharedTensorClient:
153
255
  return True
154
256
 
155
257
  def get_server_info(self) -> dict[str, Any]:
258
+ local_server = self._local_server()
259
+ if local_server is not None:
260
+ return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
156
261
  return cast(dict[str, Any], self._request("get_server_info"))
157
262
 
158
263
  def list_endpoints(self) -> dict[str, Any]:
264
+ local_server = self._local_server()
265
+ if local_server is not None:
266
+ return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
159
267
  return cast(dict[str, Any], self._request("list_endpoints"))
160
268
 
161
269
  def get_task_status(self, task_id: str) -> dict[str, Any]:
270
+ local_server = self._local_server()
271
+ if local_server is not None:
272
+ return self._run_local(
273
+ lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
274
+ )
162
275
  return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
163
276
 
164
277
  def get_task_result(self, task_id: str) -> Any:
278
+ local_server = self._local_server()
279
+ if local_server is not None:
280
+ return self._run_local(
281
+ lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
282
+ )
165
283
  return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
166
284
 
167
285
  def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
168
286
  if self.verbose_debug:
169
287
  logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
288
+ local_server = self._local_server()
289
+ if local_server is not None:
290
+ return self._run_local(
291
+ lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
292
+ )
170
293
  params = {"task_id": task_id}
171
294
  if timeout is not None:
172
295
  params["timeout"] = timeout
173
296
  return cast(dict[str, Any], self._request("wait_task", params))
174
297
 
175
298
  def cancel_task(self, task_id: str) -> bool:
299
+ local_server = self._local_server()
300
+ if local_server is not None:
301
+ return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
176
302
  return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
177
303
 
178
304
  def list_tasks(self, status: str | None = None) -> dict[str, Any]:
305
+ local_server = self._local_server()
306
+ if local_server is not None:
307
+ return self._run_local(
308
+ lambda: cast(
309
+ dict[str, Any],
310
+ {
311
+ listed_task_id: info.to_dict()
312
+ for listed_task_id, info in local_server._task_manager_instance()
313
+ .list(status=None if status is None else TaskStatus(status))
314
+ .items()
315
+ },
316
+ )
317
+ )
179
318
  params = {"status": status} if status else None
180
319
  return cast(dict[str, Any], self._request("list_tasks", params))
181
320