shared-tensor 0.2.5__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -16,18 +16,20 @@ Classifier: Intended Audience :: Developers
16
16
  Classifier: Intended Audience :: Science/Research
17
17
  Classifier: Operating System :: POSIX :: Linux
18
18
  Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.9
19
20
  Classifier: Programming Language :: Python :: 3.10
20
21
  Classifier: Programming Language :: Python :: 3.11
21
22
  Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Programming Language :: Python :: 3.13
22
24
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
25
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
26
  Classifier: Topic :: System :: Distributed Computing
25
- Requires-Python: >=3.10
27
+ Requires-Python: <3.14,>=3.9
26
28
  Description-Content-Type: text/markdown
27
29
  License-File: LICENSE
28
30
  Requires-Dist: cloudpickle>=3.0.0
29
31
  Requires-Dist: numpy<2
30
- Requires-Dist: torch>=2.2.0
32
+ Requires-Dist: torch<2.8,>=2.1
31
33
  Provides-Extra: dev
32
34
  Requires-Dist: pytest>=6.0; extra == "dev"
33
35
  Requires-Dist: pytest-cov>=2.0; extra == "dev"
@@ -63,6 +65,7 @@ Supported:
63
65
  - sync `call` and task-backed `submit`
64
66
  - managed object handles with explicit release
65
67
  - server-side caching, `cache_format_key`, and singleflight
68
+ - manual two-process deployment as the primary production path
66
69
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
67
70
 
68
71
  Not supported:
@@ -88,46 +91,58 @@ conda activate shared-tensor-dev
88
91
  pip install -e ".[dev,test]"
89
92
  ```
90
93
 
91
- ## Example: Same Code, Two Processes
94
+ ## Example: Manual Two-Process Deployment
95
+
96
+ Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
92
97
 
93
- See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
98
+ See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
99
+
100
+ Server process:
94
101
 
95
102
  ```python
96
- import torch
103
+ from shared_tensor import SharedTensorProvider, SharedTensorServer
97
104
 
98
- from shared_tensor import SharedObjectHandle, SharedTensorProvider
105
+ provider = SharedTensorProvider(execution_mode="server")
99
106
 
100
- provider = SharedTensorProvider()
107
+ @provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
108
+ def load_model(hidden_size: int = 4):
109
+ ...
101
110
 
111
+ server = SharedTensorServer(provider)
112
+ server.start(blocking=True)
113
+ ```
102
114
 
103
- @provider.share(
104
- execution="task",
105
- managed=True,
106
- concurrency="serialized",
107
- cache_format_key="model:{hidden_size}",
108
- )
109
- def load_model(hidden_size: int = 4) -> torch.nn.Module:
110
- return torch.nn.Linear(hidden_size, 2, device="cuda")
115
+ Client process:
111
116
 
117
+ ```python
118
+ import torch
119
+
120
+ from shared_tensor import SharedObjectHandle, SharedTensorClient
112
121
 
122
+ client = SharedTensorClient()
113
123
  x = torch.ones(1, 4, device="cuda")
114
- result = load_model(hidden_size=4)
124
+ result = client.call("load_model", hidden_size=4)
115
125
  if isinstance(result, SharedObjectHandle):
116
126
  with result as handle:
117
127
  y = handle.value(x)
118
- else:
119
- y = result(x)
120
128
  ```
121
129
 
122
- Server process:
130
+ This keeps the contract explicit:
123
131
 
124
- ```bash
125
- SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
132
+ ```text
133
+ server process client process
134
+ ------------------------------ ------------------------------
135
+ owns CUDA allocations issues local UDS RPC requests
136
+ executes endpoint functions reopens CUDA objects via torch IPC
137
+ manages cache and refcounts releases managed handles explicitly
126
138
  ```
127
139
 
128
- Client process with the exact same file:
140
+ ## Example: Same Code, Two Processes
141
+
142
+ See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
129
143
 
130
144
  ```bash
145
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
131
146
  SHARED_TENSOR_ENABLED=1 python demo.py
132
147
  ```
133
148
 
@@ -138,7 +153,7 @@ same code
138
153
 
139
154
  server process client process
140
155
  ------------------------------ ------------------------------
141
- provider auto-starts UDS daemon provider builds client wrappers
156
+ provider auto-starts local thread provider builds client wrappers
142
157
  shared function runs locally shared function becomes RPC call
143
158
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
144
159
  ```
@@ -201,19 +216,19 @@ SharedTensorProvider(enabled=None)
201
216
  Provider runtime controls:
202
217
 
203
218
  ```python
204
- SharedTensorProvider(server_process_start_method="fork")
205
219
  SharedTensorProvider(server_startup_timeout=30.0)
206
220
  provider.get_runtime_info()
207
221
  ```
208
222
 
209
- Use `server_process_start_method="fork"` when you explicitly want POSIX fork behavior.
210
- Leave it as `None` to let the library choose a safer default for the current entrypoint.
223
+ Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
211
224
 
212
225
  `execution_mode="auto"` behaves as follows:
213
226
  - disabled: local mode
214
- - enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
227
+ - enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
215
228
  - enabled + role unset: build client wrappers
216
229
 
230
+ For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
231
+
217
232
  Socket selection is per CUDA device:
218
233
  - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
219
234
  - runtime socket path is `<base_path>-<device_index>.sock`
@@ -13,6 +13,7 @@ Supported:
13
13
  - sync `call` and task-backed `submit`
14
14
  - managed object handles with explicit release
15
15
  - server-side caching, `cache_format_key`, and singleflight
16
+ - manual two-process deployment as the primary production path
16
17
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
17
18
 
18
19
  Not supported:
@@ -38,46 +39,58 @@ conda activate shared-tensor-dev
38
39
  pip install -e ".[dev,test]"
39
40
  ```
40
41
 
41
- ## Example: Same Code, Two Processes
42
+ ## Example: Manual Two-Process Deployment
43
+
44
+ Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
42
45
 
43
- See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
46
+ See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
47
+
48
+ Server process:
44
49
 
45
50
  ```python
46
- import torch
51
+ from shared_tensor import SharedTensorProvider, SharedTensorServer
47
52
 
48
- from shared_tensor import SharedObjectHandle, SharedTensorProvider
53
+ provider = SharedTensorProvider(execution_mode="server")
49
54
 
50
- provider = SharedTensorProvider()
55
+ @provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
56
+ def load_model(hidden_size: int = 4):
57
+ ...
51
58
 
59
+ server = SharedTensorServer(provider)
60
+ server.start(blocking=True)
61
+ ```
52
62
 
53
- @provider.share(
54
- execution="task",
55
- managed=True,
56
- concurrency="serialized",
57
- cache_format_key="model:{hidden_size}",
58
- )
59
- def load_model(hidden_size: int = 4) -> torch.nn.Module:
60
- return torch.nn.Linear(hidden_size, 2, device="cuda")
63
+ Client process:
61
64
 
65
+ ```python
66
+ import torch
67
+
68
+ from shared_tensor import SharedObjectHandle, SharedTensorClient
62
69
 
70
+ client = SharedTensorClient()
63
71
  x = torch.ones(1, 4, device="cuda")
64
- result = load_model(hidden_size=4)
72
+ result = client.call("load_model", hidden_size=4)
65
73
  if isinstance(result, SharedObjectHandle):
66
74
  with result as handle:
67
75
  y = handle.value(x)
68
- else:
69
- y = result(x)
70
76
  ```
71
77
 
72
- Server process:
78
+ This keeps the contract explicit:
73
79
 
74
- ```bash
75
- SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
80
+ ```text
81
+ server process client process
82
+ ------------------------------ ------------------------------
83
+ owns CUDA allocations issues local UDS RPC requests
84
+ executes endpoint functions reopens CUDA objects via torch IPC
85
+ manages cache and refcounts releases managed handles explicitly
76
86
  ```
77
87
 
78
- Client process with the exact same file:
88
+ ## Example: Same Code, Two Processes
89
+
90
+ See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
79
91
 
80
92
  ```bash
93
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
81
94
  SHARED_TENSOR_ENABLED=1 python demo.py
82
95
  ```
83
96
 
@@ -88,7 +101,7 @@ same code
88
101
 
89
102
  server process client process
90
103
  ------------------------------ ------------------------------
91
- provider auto-starts UDS daemon provider builds client wrappers
104
+ provider auto-starts local thread provider builds client wrappers
92
105
  shared function runs locally shared function becomes RPC call
93
106
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
94
107
  ```
@@ -151,19 +164,19 @@ SharedTensorProvider(enabled=None)
151
164
  Provider runtime controls:
152
165
 
153
166
  ```python
154
- SharedTensorProvider(server_process_start_method="fork")
155
167
  SharedTensorProvider(server_startup_timeout=30.0)
156
168
  provider.get_runtime_info()
157
169
  ```
158
170
 
159
- Use `server_process_start_method="fork"` when you explicitly want POSIX fork behavior.
160
- Leave it as `None` to let the library choose a safer default for the current entrypoint.
171
+ Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
161
172
 
162
173
  `execution_mode="auto"` behaves as follows:
163
174
  - disabled: local mode
164
- - enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
175
+ - enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
165
176
  - enabled + role unset: build client wrappers
166
177
 
178
+ For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
179
+
167
180
  Socket selection is per CUDA device:
168
181
  - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
169
182
  - runtime socket path is `<base_path>-<device_index>.sock`
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.5"
7
+ version = "0.2.7"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -33,18 +33,20 @@ classifiers = [
33
33
  "Intended Audience :: Science/Research",
34
34
  "Operating System :: POSIX :: Linux",
35
35
  "Programming Language :: Python :: 3",
36
+ "Programming Language :: Python :: 3.9",
36
37
  "Programming Language :: Python :: 3.10",
37
38
  "Programming Language :: Python :: 3.11",
38
39
  "Programming Language :: Python :: 3.12",
40
+ "Programming Language :: Python :: 3.13",
39
41
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
40
42
  "Topic :: Software Development :: Libraries :: Python Modules",
41
43
  "Topic :: System :: Distributed Computing",
42
44
  ]
43
- requires-python = ">=3.10"
45
+ requires-python = ">=3.9,<3.14"
44
46
  dependencies = [
45
47
  "cloudpickle>=3.0.0",
46
48
  "numpy<2",
47
- "torch>=2.2.0",
49
+ "torch>=2.1,<2.8",
48
50
  ]
49
51
 
50
52
  [project.optional-dependencies]
@@ -89,7 +91,7 @@ shared_tensor = ["*.so", "*.dll", "*.dylib"]
89
91
 
90
92
  [tool.black]
91
93
  line-length = 88
92
- target-version = ['py310', 'py311', 'py312']
94
+ target-version = ['py39', 'py310', 'py311', 'py312', 'py313']
93
95
  include = '\.pyi?$'
94
96
  extend-exclude = '''
95
97
  /(
@@ -115,7 +117,7 @@ use_parentheses = true
115
117
  ensure_newline_before_comments = true
116
118
 
117
119
  [tool.mypy]
118
- python_version = "3.10"
120
+ python_version = "3.9"
119
121
  warn_return_any = true
120
122
  warn_unused_configs = true
121
123
  disallow_untyped_defs = true
@@ -180,7 +182,7 @@ exclude_lines = [
180
182
  ]
181
183
 
182
184
  [tool.ruff]
183
- target-version = "py310"
185
+ target-version = "py39"
184
186
  line-length = 88
185
187
 
186
188
  [tool.ruff.lint]
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.5"
22
+ __version__ = "0.2.7"
@@ -66,6 +66,7 @@ class TaskInfo:
66
66
  class _TaskEntry:
67
67
  info: TaskInfo
68
68
  future: Future[Any]
69
+ local_result: Any = None
69
70
 
70
71
 
71
72
  class TaskManager:
@@ -139,6 +140,8 @@ class TaskManager:
139
140
  )
140
141
  return
141
142
 
143
+ self._store_local_result(task_id, result)
144
+
142
145
  if result is None:
143
146
  self._transition(
144
147
  task_id,
@@ -191,6 +194,13 @@ class TaskManager:
191
194
  for key, value in updates.items():
192
195
  setattr(entry.info, key, value)
193
196
 
197
+ def _store_local_result(self, task_id: str, value: Any) -> None:
198
+ with self._lock:
199
+ entry = self._tasks.get(task_id)
200
+ if entry is None:
201
+ return
202
+ entry.local_result = value
203
+
194
204
  def get(self, task_id: str) -> TaskInfo:
195
205
  self._maybe_cleanup()
196
206
  with self._lock:
@@ -207,6 +217,24 @@ class TaskManager:
207
217
  return None
208
218
  return deserialize_payload(encoding, payload_bytes)
209
219
 
220
+ def result_local(self, task_id: str) -> Any:
221
+ self._maybe_cleanup()
222
+ with self._lock:
223
+ entry = self._tasks.get(task_id)
224
+ if entry is None:
225
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
226
+ info = copy.deepcopy(entry.info)
227
+ value = entry.local_result
228
+ if info.status == TaskStatus.CANCELLED:
229
+ raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
230
+ if info.status == TaskStatus.FAILED:
231
+ raise SharedTensorTaskError(info.error_message or f"Task '{task_id}' failed")
232
+ if info.status != TaskStatus.COMPLETED:
233
+ raise SharedTensorTaskError(
234
+ f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
235
+ )
236
+ return value
237
+
210
238
  def wait_result_payload(
211
239
  self,
212
240
  task_id: str,
@@ -242,6 +270,21 @@ class TaskManager:
242
270
  "object_id": info.metadata.get("object_id"),
243
271
  }
244
272
 
273
+ def wait_result_local(self, task_id: str, timeout: float | None = None) -> Any:
274
+ self._maybe_cleanup()
275
+ with self._lock:
276
+ entry = self._tasks.get(task_id)
277
+ if entry is None:
278
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
279
+ future = entry.future
280
+ try:
281
+ future.result(timeout=timeout)
282
+ except FutureTimeoutError as exc:
283
+ raise SharedTensorTaskError(
284
+ f"Task '{task_id}' did not complete within {timeout} seconds"
285
+ ) from exc
286
+ return self.result_local(task_id)
287
+
245
288
  def cancel(self, task_id: str) -> bool:
246
289
  self._maybe_cleanup()
247
290
  with self._lock:
@@ -8,16 +8,25 @@ from dataclasses import dataclass
8
8
  from typing import Any, cast
9
9
 
10
10
  from shared_tensor.errors import (
11
+ SharedTensorCapabilityError,
11
12
  SharedTensorClientError,
13
+ SharedTensorConfigurationError,
14
+ SharedTensorError,
15
+ SharedTensorProviderError,
12
16
  SharedTensorProtocolError,
13
17
  SharedTensorRemoteError,
18
+ SharedTensorSerializationError,
19
+ SharedTensorTaskError,
14
20
  )
15
21
  from shared_tensor.managed_object import ReleaseHandle, SharedObjectHandle
22
+ from shared_tensor.runtime import get_local_server
16
23
  from shared_tensor.transport import recv_message, send_message
24
+ from shared_tensor.async_task import TaskStatus
17
25
  from shared_tensor.utils import (
18
26
  deserialize_payload,
19
27
  resolve_runtime_socket_path,
20
28
  serialize_call_payloads,
29
+ validate_payload_for_transport,
21
30
  )
22
31
 
23
32
 
@@ -50,6 +59,54 @@ class SharedTensorClient:
50
59
  self.timeout = timeout
51
60
  self.verbose_debug = verbose_debug
52
61
 
62
+ def _local_server(self):
63
+ return get_local_server(self.socket_path)
64
+
65
+ @staticmethod
66
+ def _remote_error_from_local(exc: SharedTensorError) -> SharedTensorRemoteError:
67
+ if isinstance(exc, SharedTensorProtocolError):
68
+ code = 1
69
+ elif isinstance(exc, SharedTensorProviderError):
70
+ code = 2
71
+ elif isinstance(exc, SharedTensorSerializationError):
72
+ code = 3
73
+ elif isinstance(exc, SharedTensorCapabilityError):
74
+ code = 4
75
+ elif isinstance(exc, SharedTensorTaskError):
76
+ code = 5
77
+ elif isinstance(exc, SharedTensorConfigurationError):
78
+ code = 6
79
+ else:
80
+ code = 7
81
+ return SharedTensorRemoteError(
82
+ f"Remote error [{code}]: {exc}",
83
+ code=code,
84
+ data=None,
85
+ error_type=type(exc).__name__,
86
+ )
87
+
88
+ def _run_local(self, operation):
89
+ try:
90
+ return operation()
91
+ except SharedTensorError as exc:
92
+ raise self._remote_error_from_local(exc) from exc
93
+
94
+ def _decode_local_result(self, result: Any) -> Any:
95
+ if result is None:
96
+ return None
97
+ value = result.value
98
+ if value is None:
99
+ return None
100
+ validate_payload_for_transport(value, allow_dict_keys=isinstance(value, dict))
101
+ object_id = result.object_id
102
+ if object_id is None:
103
+ return value
104
+ return SharedObjectHandle(
105
+ object_id=cast(str, object_id),
106
+ value=value,
107
+ _releaser=_ClientReleaser(client=self, object_id=cast(str, object_id)),
108
+ )
109
+
53
110
  def _send_request(self, request: dict[str, Any]) -> Any:
54
111
  method = request.get("method", "<unknown>")
55
112
  if self.verbose_debug:
@@ -104,6 +161,13 @@ class SharedTensorClient:
104
161
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
105
162
  if self.verbose_debug:
106
163
  logger.debug("Client calling endpoint", extra={"endpoint": endpoint})
164
+ local_server = self._local_server()
165
+ if local_server is not None:
166
+ return self._run_local(
167
+ lambda: self._decode_local_result(
168
+ local_server.call_local_client(endpoint, args=tuple(args), kwargs=dict(kwargs))
169
+ )
170
+ )
107
171
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
108
172
  result = self._request(
109
173
  "call",
@@ -119,6 +183,19 @@ class SharedTensorClient:
119
183
  def submit(self, endpoint: str, *args: Any, **kwargs: Any) -> str:
120
184
  if self.verbose_debug:
121
185
  logger.debug("Client submitting task", extra={"endpoint": endpoint})
186
+ local_server = self._local_server()
187
+ if local_server is not None:
188
+ return self._run_local(
189
+ lambda: cast(
190
+ str,
191
+ local_server._submit_endpoint_task(
192
+ endpoint,
193
+ local_server.provider.get_endpoint(endpoint),
194
+ tuple(args),
195
+ dict(kwargs),
196
+ ).task_id,
197
+ )
198
+ )
122
199
  encoding, args_payload, kwargs_payload = serialize_call_payloads(tuple(args), dict(kwargs))
123
200
  result = self._request(
124
201
  "submit",
@@ -134,18 +211,43 @@ class SharedTensorClient:
134
211
  def release(self, object_id: str) -> bool:
135
212
  if self.verbose_debug:
136
213
  logger.debug("Client releasing managed object", extra={"object_id": object_id})
214
+ local_server = self._local_server()
215
+ if local_server is not None:
216
+ return self._run_local(
217
+ lambda: bool(local_server._handle_release_object({"object_id": object_id})["released"])
218
+ )
137
219
  result = self._request("release_object", {"object_id": object_id})
138
220
  return bool(result["released"])
139
221
 
140
222
  def release_many(self, object_ids: list[str]) -> dict[str, bool]:
223
+ local_server = self._local_server()
224
+ if local_server is not None:
225
+ return self._run_local(
226
+ lambda: {
227
+ object_id: bool(released)
228
+ for object_id, released in local_server._handle_release_objects({"object_ids": object_ids})[
229
+ "released"
230
+ ].items()
231
+ }
232
+ )
141
233
  result = self._request("release_objects", {"object_ids": object_ids})
142
234
  return {object_id: bool(released) for object_id, released in result["released"].items()}
143
235
 
144
236
  def get_object_info(self, object_id: str) -> dict[str, Any] | None:
237
+ local_server = self._local_server()
238
+ if local_server is not None:
239
+ return self._run_local(
240
+ lambda: cast(
241
+ dict[str, Any] | None,
242
+ local_server._handle_get_object_info({"object_id": object_id}).get("object"),
243
+ )
244
+ )
145
245
  result = self._request("get_object_info", {"object_id": object_id})
146
246
  return cast(dict[str, Any] | None, result.get("object"))
147
247
 
148
248
  def ping(self) -> bool:
249
+ if self._local_server() is not None:
250
+ return True
149
251
  try:
150
252
  self._request("ping")
151
253
  except (SharedTensorClientError, SharedTensorRemoteError):
@@ -153,29 +255,66 @@ class SharedTensorClient:
153
255
  return True
154
256
 
155
257
  def get_server_info(self) -> dict[str, Any]:
258
+ local_server = self._local_server()
259
+ if local_server is not None:
260
+ return self._run_local(lambda: cast(dict[str, Any], local_server._get_server_info()))
156
261
  return cast(dict[str, Any], self._request("get_server_info"))
157
262
 
158
263
  def list_endpoints(self) -> dict[str, Any]:
264
+ local_server = self._local_server()
265
+ if local_server is not None:
266
+ return self._run_local(lambda: cast(dict[str, Any], local_server.provider.list_endpoints()))
159
267
  return cast(dict[str, Any], self._request("list_endpoints"))
160
268
 
161
269
  def get_task_status(self, task_id: str) -> dict[str, Any]:
270
+ local_server = self._local_server()
271
+ if local_server is not None:
272
+ return self._run_local(
273
+ lambda: cast(dict[str, Any], local_server._task_manager_instance().get(task_id).to_dict())
274
+ )
162
275
  return cast(dict[str, Any], self._request("get_task", {"task_id": task_id}))
163
276
 
164
277
  def get_task_result(self, task_id: str) -> Any:
278
+ local_server = self._local_server()
279
+ if local_server is not None:
280
+ return self._run_local(
281
+ lambda: self._decode_local_result(local_server.get_task_result_local(task_id))
282
+ )
165
283
  return self._decode_rpc_payload(self._request("get_task_result", {"task_id": task_id}))
166
284
 
167
285
  def wait_task(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
168
286
  if self.verbose_debug:
169
287
  logger.debug("Client waiting for task", extra={"task_id": task_id, "timeout": timeout})
288
+ local_server = self._local_server()
289
+ if local_server is not None:
290
+ return self._run_local(
291
+ lambda: cast(dict[str, Any], local_server.wait_task_local(task_id, timeout=timeout))
292
+ )
170
293
  params = {"task_id": task_id}
171
294
  if timeout is not None:
172
295
  params["timeout"] = timeout
173
296
  return cast(dict[str, Any], self._request("wait_task", params))
174
297
 
175
298
  def cancel_task(self, task_id: str) -> bool:
299
+ local_server = self._local_server()
300
+ if local_server is not None:
301
+ return self._run_local(lambda: bool(local_server._task_manager_instance().cancel(task_id)))
176
302
  return bool(self._request("cancel_task", {"task_id": task_id})["cancelled"])
177
303
 
178
304
  def list_tasks(self, status: str | None = None) -> dict[str, Any]:
305
+ local_server = self._local_server()
306
+ if local_server is not None:
307
+ return self._run_local(
308
+ lambda: cast(
309
+ dict[str, Any],
310
+ {
311
+ listed_task_id: info.to_dict()
312
+ for listed_task_id, info in local_server._task_manager_instance()
313
+ .list(status=None if status is None else TaskStatus(status))
314
+ .items()
315
+ },
316
+ )
317
+ )
179
318
  params = {"status": status} if status else None
180
319
  return cast(dict[str, Any], self._request("list_tasks", params))
181
320
 
@@ -92,7 +92,6 @@ class SharedTensorProvider:
92
92
  device_index: int | None = None,
93
93
  timeout: float = 30.0,
94
94
  execution_mode: str = "auto",
95
- server_process_start_method: str | None = None,
96
95
  server_startup_timeout: float = 30.0,
97
96
  verbose_debug: bool = False,
98
97
  ) -> None:
@@ -106,7 +105,6 @@ class SharedTensorProvider:
106
105
  self.timeout = timeout
107
106
  self.execution_mode = resolved_mode
108
107
  self.auto_mode = auto_mode
109
- self.server_process_start_method = server_process_start_method
110
108
  self.server_startup_timeout = server_startup_timeout
111
109
  self.verbose_debug = verbose_debug
112
110
  self._client: Any | None = None
@@ -165,9 +163,6 @@ class SharedTensorProvider:
165
163
  if self._should_autostart_server():
166
164
  self._restart_autostart_server()
167
165
 
168
- if self.execution_mode == "server":
169
- return func
170
-
171
166
  @wraps(func)
172
167
  def wrapper(*args: Any, **kwargs: Any) -> Any:
173
168
  return self.call(endpoint_name, *args, **kwargs)
@@ -215,7 +210,11 @@ class SharedTensorProvider:
215
210
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
216
211
  if self.verbose_debug:
217
212
  logger.debug("Provider dispatching call", extra={"endpoint": endpoint, "mode": self.execution_mode})
218
- if self.execution_mode in {"server", "local"}:
213
+ if self.execution_mode == "server":
214
+ if self._server is not None and hasattr(self._server, "invoke_local"):
215
+ return self._server.invoke_local(endpoint, args=args, kwargs=kwargs)
216
+ return self.invoke_local(endpoint, args=args, kwargs=kwargs)
217
+ if self.execution_mode == "local":
219
218
  return self.invoke_local(endpoint, args=args, kwargs=kwargs)
220
219
  return self._get_client().call(endpoint, *args, **kwargs)
221
220
 
@@ -370,7 +369,6 @@ class SharedTensorProvider:
370
369
  "Provider restarting autostart server",
371
370
  extra={
372
371
  "socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
373
- "process_start_method": self.server_process_start_method,
374
372
  },
375
373
  )
376
374
  if self._server is not None:
@@ -378,7 +376,6 @@ class SharedTensorProvider:
378
376
  self._server = SharedTensorServer(
379
377
  self,
380
378
  socket_path=resolve_runtime_socket_path(self.base_path, self.device_index),
381
- process_start_method=self.server_process_start_method,
382
379
  startup_timeout=self.server_startup_timeout,
383
380
  verbose_debug=self.verbose_debug,
384
381
  )
@@ -0,0 +1,30 @@
1
+ """In-process runtime registry for thread-backed local servers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from threading import RLock
6
+ from typing import TYPE_CHECKING
7
+
8
+ if TYPE_CHECKING:
9
+ from shared_tensor.server import SharedTensorServer
10
+
11
+
12
+ _LOCK = RLock()
13
+ _SERVERS: dict[str, "SharedTensorServer"] = {}
14
+
15
+
16
+ def register_local_server(socket_path: str, server: "SharedTensorServer") -> None:
17
+ with _LOCK:
18
+ _SERVERS[socket_path] = server
19
+
20
+
21
+ def unregister_local_server(socket_path: str, server: "SharedTensorServer") -> None:
22
+ with _LOCK:
23
+ current = _SERVERS.get(socket_path)
24
+ if current is server:
25
+ _SERVERS.pop(socket_path, None)
26
+
27
+
28
+ def get_local_server(socket_path: str) -> "SharedTensorServer | None":
29
+ with _LOCK:
30
+ return _SERVERS.get(socket_path)
@@ -2,16 +2,13 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import cloudpickle
6
5
  import logging
7
- import multiprocessing as mp
8
6
  import os
9
- import sys
10
7
  import socket
11
8
  import threading
12
9
  import time
13
10
  from concurrent.futures import Future
14
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, field
15
12
  from typing import Any
16
13
 
17
14
  from shared_tensor.async_task import TaskManager, TaskStatus
@@ -25,6 +22,7 @@ from shared_tensor.errors import (
25
22
  )
26
23
  from shared_tensor.managed_object import ManagedObjectRegistry
27
24
  from shared_tensor.provider import EndpointDefinition, SharedTensorProvider
25
+ from shared_tensor.runtime import register_local_server, unregister_local_server
28
26
  from shared_tensor.transport import recv_message, send_message
29
27
  from shared_tensor.utils import (
30
28
  CONTROL_ENCODING,
@@ -41,11 +39,33 @@ from shared_tensor.utils import (
41
39
  logger = logging.getLogger(__name__)
42
40
 
43
41
 
42
+ def _server_version() -> str:
43
+ try:
44
+ from shared_tensor import __version__
45
+ except ImportError:
46
+ return "unknown"
47
+ return __version__
48
+
49
+
44
50
  @dataclass(slots=True)
45
51
  class _InFlightCall:
46
52
  future: Future[dict[str, Any]]
47
53
 
48
54
 
55
+ @dataclass(slots=True)
56
+ class _ServerThreadState:
57
+ thread: threading.Thread
58
+ ready: threading.Event = field(default_factory=threading.Event)
59
+ stopped: threading.Event = field(default_factory=threading.Event)
60
+ error: BaseException | None = None
61
+
62
+
63
+ @dataclass(slots=True)
64
+ class _EndpointResult:
65
+ value: Any
66
+ object_id: str | None = None
67
+
68
+
49
69
  class SharedTensorServer:
50
70
  def __init__(
51
71
  self,
@@ -72,6 +92,7 @@ class SharedTensorServer:
72
92
  self.startup_timeout = startup_timeout
73
93
  self.listener: socket.socket | None = None
74
94
  self.server_process: Any | None = None
95
+ self.server_thread: _ServerThreadState | None = None
75
96
  self._resolved_process_start_method: str | None = None
76
97
  self.running = False
77
98
  self.started_at: float | None = None
@@ -81,10 +102,13 @@ class SharedTensorServer:
81
102
  }
82
103
  self._task_manager: TaskManager | None = None
83
104
  self._cache: dict[str, dict[str, Any]] = {}
105
+ self._local_cache: dict[str, Any] = {}
84
106
  self._managed_objects = ManagedObjectRegistry()
85
107
  self._inflight: dict[str, _InFlightCall] = {}
86
108
  self._endpoint_locks: dict[str, threading.Lock] = {}
87
109
  self._coordination_lock = threading.RLock()
110
+ if getattr(self.provider, "_server", None) is None:
111
+ self.provider._server = self
88
112
 
89
113
  def process_request(self, request: dict[str, Any]) -> dict[str, Any]:
90
114
  if self.verbose_debug:
@@ -180,22 +204,22 @@ class SharedTensorServer:
180
204
  ) -> Any:
181
205
  return self._task_manager_instance().submit(
182
206
  endpoint,
183
- self._execute_endpoint_call,
207
+ self._execute_endpoint_result,
184
208
  (endpoint, definition, args, kwargs),
185
209
  {},
186
- result_encoder=lambda payload: payload,
210
+ result_encoder=self._encode_endpoint_result,
187
211
  )
188
212
 
189
- def _execute_endpoint_call(
213
+ def _execute_endpoint_result(
190
214
  self,
191
215
  endpoint: str,
192
216
  definition: EndpointDefinition,
193
217
  args: tuple[Any, ...],
194
218
  kwargs: dict[str, Any],
195
- ) -> dict[str, Any]:
219
+ ) -> _EndpointResult:
196
220
  cache_key = self._cache_key(endpoint, definition, args, kwargs)
197
221
  if cache_key is not None:
198
- cached = self._lookup_cached_result(definition, cache_key)
222
+ cached = self._lookup_cached_result_value(definition, cache_key)
199
223
  if cached is not None:
200
224
  if self.verbose_debug:
201
225
  logger.debug("Server cache hit", extra={"endpoint": endpoint, "cache_key": cache_key})
@@ -207,20 +231,15 @@ class SharedTensorServer:
207
231
  if self.verbose_debug and owner:
208
232
  logger.debug("Server created singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
209
233
  if not owner:
210
- if self.verbose_debug:
211
- logger.debug("Server joined singleflight entry", extra={"endpoint": endpoint, "cache_key": inflight_key})
212
- if definition.managed:
213
- payload = future.result()
214
- object_id = payload.get("object_id")
215
- if object_id is not None:
216
- self._managed_objects.add_ref(object_id)
217
- return payload
218
- return future.result()
234
+ result = future.result()
235
+ if definition.managed and result.object_id is not None:
236
+ self._managed_objects.add_ref(result.object_id)
237
+ return result
219
238
  else:
220
239
  future = None
221
240
 
222
241
  try:
223
- encoded = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
242
+ result = self._run_endpoint_under_policy(endpoint, definition, args, kwargs, cache_key)
224
243
  except Exception as exc:
225
244
  if future is not None:
226
245
  future.set_exception(exc)
@@ -228,9 +247,20 @@ class SharedTensorServer:
228
247
  raise
229
248
 
230
249
  if future is not None:
231
- future.set_result(encoded)
250
+ future.set_result(result)
232
251
  self._release_inflight(inflight_key, future)
233
- return encoded
252
+ return result
253
+
254
+ def _execute_endpoint_call(
255
+ self,
256
+ endpoint: str,
257
+ definition: EndpointDefinition,
258
+ args: tuple[Any, ...],
259
+ kwargs: dict[str, Any],
260
+ ) -> dict[str, Any]:
261
+ return self._encode_endpoint_result(
262
+ self._execute_endpoint_result(endpoint, definition, args, kwargs)
263
+ )
234
264
 
235
265
  def _run_endpoint_under_policy(
236
266
  self,
@@ -239,11 +269,11 @@ class SharedTensorServer:
239
269
  args: tuple[Any, ...],
240
270
  kwargs: dict[str, Any],
241
271
  cache_key: str | None,
242
- ) -> dict[str, Any]:
272
+ ) -> _EndpointResult:
243
273
  if definition.concurrency == "serialized":
244
274
  lock = self._endpoint_lock(endpoint)
245
275
  with lock:
246
- cached = self._lookup_cached_result(definition, cache_key)
276
+ cached = self._lookup_cached_result_value(definition, cache_key)
247
277
  if cached is not None:
248
278
  return cached
249
279
  return self._materialize_endpoint_result(endpoint, definition, args, kwargs, cache_key)
@@ -256,16 +286,15 @@ class SharedTensorServer:
256
286
  args: tuple[Any, ...],
257
287
  kwargs: dict[str, Any],
258
288
  cache_key: str | None,
259
- ) -> dict[str, Any]:
289
+ ) -> _EndpointResult:
260
290
  if definition.managed:
261
291
  return self._materialize_managed_result(endpoint, definition, args, kwargs, cache_key)
262
292
  value = definition.func(*args, **kwargs)
263
293
  if self.verbose_debug:
264
294
  logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
265
- result = self._encode_result(value)
266
295
  if cache_key is not None:
267
- self._cache[cache_key] = result
268
- return result
296
+ self._local_cache[cache_key] = value
297
+ return _EndpointResult(value=value)
269
298
 
270
299
  def _materialize_managed_result(
271
300
  self,
@@ -274,24 +303,24 @@ class SharedTensorServer:
274
303
  args: tuple[Any, ...],
275
304
  kwargs: dict[str, Any],
276
305
  cache_key: str | None,
277
- ) -> dict[str, Any]:
306
+ ) -> _EndpointResult:
278
307
  if cache_key is not None:
279
308
  cached = self._managed_objects.get_cached(cache_key)
280
309
  if cached is not None:
281
310
  self._managed_objects.add_ref(cached.object_id)
282
- return self._encode_result(cached.value, object_id=cached.object_id)
311
+ return _EndpointResult(value=cached.value, object_id=cached.object_id)
283
312
 
284
313
  result = definition.func(*args, **kwargs)
285
314
  if self.verbose_debug:
286
315
  logger.debug("Server created managed object", extra={"endpoint": endpoint, "cache_key": cache_key})
287
316
  entry = self._managed_objects.register(endpoint=endpoint, value=result, cache_key=cache_key)
288
- return self._encode_result(entry.value, object_id=entry.object_id)
317
+ return _EndpointResult(value=entry.value, object_id=entry.object_id)
289
318
 
290
- def _lookup_cached_result(
319
+ def _lookup_cached_result_value(
291
320
  self,
292
321
  definition: EndpointDefinition,
293
322
  cache_key: str | None,
294
- ) -> dict[str, Any] | None:
323
+ ) -> _EndpointResult | None:
295
324
  if cache_key is None:
296
325
  return None
297
326
  if definition.managed:
@@ -299,8 +328,81 @@ class SharedTensorServer:
299
328
  if cached is None:
300
329
  return None
301
330
  self._managed_objects.add_ref(cached.object_id)
302
- return self._encode_result(cached.value, object_id=cached.object_id)
303
- return self._cache.get(cache_key)
331
+ return _EndpointResult(value=cached.value, object_id=cached.object_id)
332
+ local_value = self._local_cache.get(cache_key)
333
+ if local_value is None:
334
+ return None
335
+ return _EndpointResult(value=local_value)
336
+
337
+ def call_local_client(
338
+ self,
339
+ endpoint: str,
340
+ *,
341
+ args: tuple[Any, ...] = (),
342
+ kwargs: dict[str, Any] | None = None,
343
+ ) -> _EndpointResult | None:
344
+ definition = self.provider.get_endpoint(endpoint)
345
+ resolved_kwargs = kwargs or {}
346
+ if definition.execution == "task":
347
+ task_info = self._submit_endpoint_task(endpoint, definition, args, resolved_kwargs)
348
+ return self.wait_task_result_local(task_info.task_id)
349
+ return self._execute_endpoint_result(endpoint, definition, args, resolved_kwargs)
350
+
351
+ def get_task_result_local(self, task_id: str) -> _EndpointResult | None:
352
+ result = self._task_manager_instance().result_local(task_id)
353
+ if result is None:
354
+ return None
355
+ return result
356
+
357
+ def wait_task_result_local(self, task_id: str, timeout: float | None = None) -> _EndpointResult | None:
358
+ result = self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
359
+ if result is None:
360
+ return None
361
+ return result
362
+
363
+ def wait_task_local(self, task_id: str, timeout: float | None = None) -> dict[str, Any]:
364
+ try:
365
+ self._task_manager_instance().wait_result_local(task_id, timeout=timeout)
366
+ except SharedTensorTaskError:
367
+ info = self._task_manager_instance().get(task_id)
368
+ if info.status in {TaskStatus.PENDING, TaskStatus.RUNNING}:
369
+ return info.to_dict()
370
+ raise
371
+ return self._task_manager_instance().get(task_id).to_dict()
372
+
373
+ def encode_local_result(self, result: _EndpointResult | None) -> dict[str, Any]:
374
+ if result is None:
375
+ return {"encoding": None, "payload_bytes": None, "object_id": None}
376
+ return self._encode_endpoint_result(result)
377
+
378
+ def invoke_local(
379
+ self,
380
+ endpoint: str,
381
+ *,
382
+ args: tuple[Any, ...] = (),
383
+ kwargs: dict[str, Any] | None = None,
384
+ ) -> Any:
385
+ definition = self.provider.get_endpoint(endpoint)
386
+ resolved_kwargs = kwargs or {}
387
+ cache_key = self._cache_key(endpoint, definition, args, resolved_kwargs)
388
+ if definition.managed:
389
+ if cache_key is not None:
390
+ cached = self._managed_objects.get_cached(cache_key)
391
+ if cached is not None:
392
+ return cached.value
393
+ value = definition.func(*args, **resolved_kwargs)
394
+ if cache_key is not None:
395
+ existing = self._managed_objects.get_cached(cache_key)
396
+ if existing is not None:
397
+ return existing.value
398
+ self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
399
+ return value
400
+ if cache_key is not None and cache_key in self._local_cache:
401
+ return self._local_cache[cache_key]
402
+ value = definition.func(*args, **resolved_kwargs)
403
+ if cache_key is not None:
404
+ self._local_cache[cache_key] = value
405
+ return value
304
406
 
305
407
  def _cache_key(
306
408
  self,
@@ -400,6 +502,9 @@ class SharedTensorServer:
400
502
  encoding, payload = serialize_payload(value)
401
503
  return {"encoding": encoding, "payload_bytes": payload, "object_id": object_id}
402
504
 
505
+ def _encode_endpoint_result(self, result: _EndpointResult) -> dict[str, Any]:
506
+ return self._encode_result(result.value, object_id=result.object_id)
507
+
403
508
  def _task_manager_instance(self) -> TaskManager:
404
509
  if self._task_manager is None:
405
510
  self._task_manager = TaskManager(
@@ -426,7 +531,7 @@ class SharedTensorServer:
426
531
  uptime = 0.0 if self.started_at is None else time.time() - self.started_at
427
532
  return {
428
533
  "server": "SharedTensorServer",
429
- "version": "0.2.4",
534
+ "version": _server_version(),
430
535
  "socket_path": self.socket_path,
431
536
  "uptime": uptime,
432
537
  "running": self.running,
@@ -448,101 +553,66 @@ class SharedTensorServer:
448
553
  "data": None,
449
554
  }
450
555
 
451
- def _resolve_process_start_method(self) -> str:
452
- if self.process_start_method is not None:
453
- allowed = set(mp.get_all_start_methods())
454
- if self.process_start_method not in allowed:
455
- raise SharedTensorConfigurationError(
456
- f"Unsupported process_start_method '{self.process_start_method}'"
457
- )
458
- return self.process_start_method
459
- if os.name != "posix":
460
- return "spawn"
461
- try:
462
- import torch
463
- except ImportError:
464
- torch = None
465
- if torch is not None and torch.cuda.is_available() and torch.cuda.is_initialized():
466
- return "spawn"
467
- if not hasattr(sys.modules.get("__main__"), "__file__"):
468
- return "fork"
469
- return "spawn"
470
-
471
556
  def start(self, blocking: bool = True) -> None:
472
557
  if self.verbose_debug:
473
558
  logger.info("Server starting", extra={"socket_path": self.socket_path, "blocking": blocking})
474
- if self.running:
559
+ if self.running or self.server_thread is not None:
475
560
  raise SharedTensorConfigurationError("Server is already running")
476
561
  if blocking:
477
562
  self._resolved_process_start_method = None
478
563
  self._serve_forever()
479
564
  return
480
- if os.name != "posix":
565
+ if self.process_start_method is not None:
481
566
  raise SharedTensorConfigurationError(
482
- "Non-blocking shared_tensor servers require POSIX multiprocessing support"
567
+ "process_start_method is not supported for thread-backed non-blocking servers"
483
568
  )
484
- start_method = self._resolve_process_start_method()
485
- payload = cloudpickle.dumps(self.provider)
486
- process = mp.get_context(start_method).Process(
487
- target=self._serve_forever_from_payload,
488
- args=(
489
- payload,
490
- self.socket_path,
491
- self.max_request_bytes,
492
- self.max_workers,
493
- self.result_ttl,
494
- self.verbose_debug,
495
- start_method,
496
- ),
497
- name=f"shared-tensor-daemon:{self.socket_path}",
569
+ thread = threading.Thread(
570
+ target=self._serve_forever_in_thread,
571
+ name=f"shared-tensor-server:{self.socket_path}",
572
+ daemon=True,
498
573
  )
499
- process.start()
500
- if self.verbose_debug:
501
- logger.info(
502
- "Server spawned background process",
503
- extra={"socket_path": self.socket_path, "pid": process.pid, "start_method": start_method},
504
- )
505
- self.server_process = process
506
- self._resolved_process_start_method = start_method
507
- self.running = True
508
- self.started_at = time.time()
574
+ state = _ServerThreadState(thread=thread)
575
+ self.server_thread = state
576
+ self._resolved_process_start_method = "thread"
577
+ thread.start()
578
+ if not state.ready.wait(timeout=self.startup_timeout):
579
+ self.stop()
580
+ raise TimeoutError(f"Timed out waiting for server socket {self.socket_path}")
581
+ if state.error is not None:
582
+ error = state.error
583
+ self.stop()
584
+ raise SharedTensorConfigurationError(
585
+ f"Failed to start background server thread for {self.socket_path}: {error}"
586
+ ) from error
509
587
 
510
- @staticmethod
511
- def _serve_forever_from_payload(
512
- payload: bytes,
513
- socket_path: str,
514
- max_request_bytes: int,
515
- max_workers: int,
516
- result_ttl: float,
517
- verbose_debug: bool,
518
- process_start_method: str | None,
519
- ) -> None:
520
- SharedTensorServer._configure_cuda_runtime()
521
- provider = cloudpickle.loads(payload)
522
- server = SharedTensorServer(
523
- provider,
524
- socket_path=socket_path,
525
- max_request_bytes=max_request_bytes,
526
- max_workers=max_workers,
527
- result_ttl=result_ttl,
528
- process_start_method=process_start_method,
529
- verbose_debug=verbose_debug,
530
- )
531
- server._resolved_process_start_method = process_start_method
532
- server._serve_forever()
588
+ def _serve_forever_in_thread(self) -> None:
589
+ state = self.server_thread
590
+ if state is None:
591
+ return
592
+ try:
593
+ self._serve_forever(started_event=state.ready)
594
+ except BaseException as exc: # noqa: BLE001
595
+ state.error = exc
596
+ state.ready.set()
597
+ raise
598
+ finally:
599
+ state.stopped.set()
533
600
 
534
- def _serve_forever(self) -> None:
601
+ def _serve_forever(self, *, started_event: threading.Event | None = None) -> None:
535
602
  self._configure_cuda_runtime()
536
603
  unlink_socket_path(self.socket_path)
537
604
  listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
538
- listener.bind(self.socket_path)
539
- listener.listen()
540
- if self.verbose_debug:
541
- logger.info("Server listening", extra={"socket_path": self.socket_path})
542
- self.listener = listener
543
- self.running = True
544
- self.started_at = time.time()
545
605
  try:
606
+ listener.bind(self.socket_path)
607
+ listener.listen()
608
+ if self.verbose_debug:
609
+ logger.info("Server listening", extra={"socket_path": self.socket_path})
610
+ self.listener = listener
611
+ self.running = True
612
+ self.started_at = time.time()
613
+ register_local_server(self.socket_path, self)
614
+ if started_event is not None:
615
+ started_event.set()
546
616
  while self.running:
547
617
  try:
548
618
  conn, _ = listener.accept()
@@ -553,6 +623,8 @@ class SharedTensorServer:
553
623
  thread = threading.Thread(target=self._handle_connection, args=(conn,), daemon=True)
554
624
  thread.start()
555
625
  finally:
626
+ if started_event is not None and not started_event.is_set():
627
+ started_event.set()
556
628
  self._shutdown_local_resources()
557
629
 
558
630
  def _handle_connection(self, conn: socket.socket) -> None:
@@ -586,24 +658,20 @@ class SharedTensorServer:
586
658
  def stop(self) -> None:
587
659
  if self.verbose_debug:
588
660
  logger.info("Server stopping", extra={"socket_path": self.socket_path})
589
- if not self.running:
590
- unlink_socket_path(self.socket_path)
591
- return
592
661
  self.running = False
593
- if self.server_process is not None:
594
- self.server_process.terminate()
595
- self.server_process.join(timeout=5)
596
- if self.server_process.is_alive():
597
- self.server_process.kill()
598
- self.server_process.join(timeout=5)
599
- self.server_process = None
600
- unlink_socket_path(self.socket_path)
601
- return
602
662
  if self.listener is not None:
603
663
  self.listener.close()
604
- self._shutdown_local_resources()
664
+ state = self.server_thread
665
+ if state is not None and state.thread.is_alive() and threading.current_thread() is not state.thread:
666
+ state.stopped.wait(timeout=5)
667
+ state.thread.join(timeout=5)
668
+ self.server_thread = None
669
+ self.server_process = None
670
+ if self.listener is None:
671
+ unlink_socket_path(self.socket_path)
605
672
 
606
673
  def _shutdown_local_resources(self) -> None:
674
+ self.running = False
607
675
  if self.listener is not None:
608
676
  self.listener.close()
609
677
  self.listener = None
@@ -612,8 +680,10 @@ class SharedTensorServer:
612
680
  self._task_manager = None
613
681
  self._managed_objects.clear()
614
682
  self._cache.clear()
683
+ self._local_cache.clear()
615
684
  self._inflight.clear()
616
685
  self._endpoint_locks.clear()
686
+ unregister_local_server(self.socket_path, self)
617
687
  unlink_socket_path(self.socket_path)
618
688
 
619
689
  def __enter__(self) -> SharedTensorServer:
@@ -10,6 +10,7 @@ shared_tensor/client.py
10
10
  shared_tensor/errors.py
11
11
  shared_tensor/managed_object.py
12
12
  shared_tensor/provider.py
13
+ shared_tensor/runtime.py
13
14
  shared_tensor/server.py
14
15
  shared_tensor/transport.py
15
16
  shared_tensor/utils.py
File without changes
File without changes
File without changes