shared-tensor 0.2.7__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
7
7
  License-Expression: Apache-2.0
8
8
  Project-URL: Homepage, https://github.com/world-sim-dev/shared-tensor
9
9
  Project-URL: Repository, https://github.com/world-sim-dev/shared-tensor
10
- Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/wiki
10
+ Project-URL: Documentation, https://github.com/world-sim-dev/shared-tensor/tree/main/docs
11
11
  Project-URL: Bug Reports, https://github.com/world-sim-dev/shared-tensor/issues
12
12
  Project-URL: Changelog, https://github.com/world-sim-dev/shared-tensor/releases
13
13
  Keywords: gpu,memory,sharing,ipc,inter-process-communication,pytorch,cuda,model-serving,inference,torch,torch-ipc
@@ -77,7 +77,7 @@ Not supported:
77
77
 
78
78
  ## Install
79
79
 
80
- Use Python `3.10+` and a CUDA-enabled PyTorch build.
80
+ Use Python `3.9+` and a CUDA-enabled PyTorch build.
81
81
 
82
82
  ```bash
83
83
  pip install shared-tensor
@@ -91,6 +91,16 @@ conda activate shared-tensor-dev
91
91
  pip install -e ".[dev,test]"
92
92
  ```
93
93
 
94
+ ## Docs
95
+
96
+ Read the examples first, then the design notes:
97
+
98
+ - `docs/overview.md`
99
+ - `docs/patterns.md`
100
+ - `docs/architecture.md`
101
+ - `docs/lifecycle.md`
102
+ - `docs/diagrams.md`
103
+
94
104
  ## Example: Manual Two-Process Deployment
95
105
 
96
106
  Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
@@ -25,7 +25,7 @@ Not supported:
25
25
 
26
26
  ## Install
27
27
 
28
- Use Python `3.10+` and a CUDA-enabled PyTorch build.
28
+ Use Python `3.9+` and a CUDA-enabled PyTorch build.
29
29
 
30
30
  ```bash
31
31
  pip install shared-tensor
@@ -39,6 +39,16 @@ conda activate shared-tensor-dev
39
39
  pip install -e ".[dev,test]"
40
40
  ```
41
41
 
42
+ ## Docs
43
+
44
+ Read the examples first, then the design notes:
45
+
46
+ - `docs/overview.md`
47
+ - `docs/patterns.md`
48
+ - `docs/architecture.md`
49
+ - `docs/lifecycle.md`
50
+ - `docs/diagrams.md`
51
+
42
52
  ## Example: Manual Two-Process Deployment
43
53
 
44
54
  Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.7"
7
+ version = "0.2.8"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -75,7 +75,7 @@ docs = [
75
75
  [project.urls]
76
76
  Homepage = "https://github.com/world-sim-dev/shared-tensor"
77
77
  Repository = "https://github.com/world-sim-dev/shared-tensor"
78
- Documentation = "https://github.com/world-sim-dev/shared-tensor/wiki"
78
+ Documentation = "https://github.com/world-sim-dev/shared-tensor/tree/main/docs"
79
79
  "Bug Reports" = "https://github.com/world-sim-dev/shared-tensor/issues"
80
80
  Changelog = "https://github.com/world-sim-dev/shared-tensor/releases"
81
81
 
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.7"
22
+ __version__ = "0.2.8"
@@ -33,8 +33,6 @@ class TaskInfo:
33
33
  created_at: float
34
34
  started_at: float | None = None
35
35
  completed_at: float | None = None
36
- result_encoding: str | None = None
37
- result_payload: bytes | None = None
38
36
  error_type: str | None = None
39
37
  error_message: str | None = None
40
38
  metadata: dict[str, Any] = field(default_factory=dict)
@@ -47,8 +45,6 @@ class TaskInfo:
47
45
  "created_at": self.created_at,
48
46
  "started_at": self.started_at,
49
47
  "completed_at": self.completed_at,
50
- "result_encoding": self.result_encoding,
51
- "result_payload": self.result_payload,
52
48
  "error_type": self.error_type,
53
49
  "error_message": self.error_message,
54
50
  "metadata": dict(self.metadata),
@@ -66,6 +62,8 @@ class TaskInfo:
66
62
  class _TaskEntry:
67
63
  info: TaskInfo
68
64
  future: Future[Any]
65
+ result_encoding: str | None = None
66
+ result_payload: bytes | None = None
69
67
  local_result: Any = None
70
68
 
71
69
 
@@ -87,6 +85,7 @@ class TaskManager:
87
85
  self._last_cleanup = 0.0
88
86
  self._lock = RLock()
89
87
  self._tasks: dict[str, _TaskEntry] = {}
88
+ self._accepting_submissions = True
90
89
 
91
90
  def submit(
92
91
  self,
@@ -98,6 +97,8 @@ class TaskManager:
98
97
  ) -> TaskInfo:
99
98
  self._maybe_cleanup()
100
99
  with self._lock:
100
+ if not self._accepting_submissions:
101
+ raise SharedTensorTaskError("Task manager is shutting down and is not accepting new tasks")
101
102
  self._drop_oldest_finished_tasks_if_needed()
102
103
  if len(self._tasks) >= self._max_tasks:
103
104
  raise SharedTensorTaskError("Task capacity exceeded")
@@ -143,12 +144,11 @@ class TaskManager:
143
144
  self._store_local_result(task_id, result)
144
145
 
145
146
  if result is None:
147
+ self._store_payload(task_id, encoding=None, payload=None, object_id=None)
146
148
  self._transition(
147
149
  task_id,
148
150
  status=TaskStatus.COMPLETED,
149
151
  completed_at=time.time(),
150
- result_encoding=None,
151
- result_payload=None,
152
152
  )
153
153
  return
154
154
 
@@ -168,13 +168,16 @@ class TaskManager:
168
168
  )
169
169
  return
170
170
 
171
+ self._store_payload(
172
+ task_id,
173
+ encoding=payload["encoding"],
174
+ payload=payload["payload_bytes"],
175
+ object_id=payload.get("object_id"),
176
+ )
171
177
  self._transition(
172
178
  task_id,
173
179
  status=TaskStatus.COMPLETED,
174
180
  completed_at=time.time(),
175
- result_encoding=payload["encoding"],
176
- result_payload=payload["payload_bytes"],
177
- metadata={"object_id": payload.get("object_id")},
178
181
  )
179
182
 
180
183
  @staticmethod
@@ -201,6 +204,24 @@ class TaskManager:
201
204
  return
202
205
  entry.local_result = value
203
206
 
207
+ def _store_payload(
208
+ self,
209
+ task_id: str,
210
+ *,
211
+ encoding: str | None,
212
+ payload: bytes | None,
213
+ object_id: str | None,
214
+ ) -> None:
215
+ with self._lock:
216
+ entry = self._tasks.get(task_id)
217
+ if entry is None:
218
+ return
219
+ entry.result_encoding = encoding
220
+ entry.result_payload = payload
221
+ metadata = dict(entry.info.metadata)
222
+ metadata["object_id"] = object_id
223
+ entry.info.metadata = metadata
224
+
204
225
  def get(self, task_id: str) -> TaskInfo:
205
226
  self._maybe_cleanup()
206
227
  with self._lock:
@@ -255,7 +276,14 @@ class TaskManager:
255
276
  return self.result_payload(task_id)
256
277
 
257
278
  def result_payload(self, task_id: str) -> dict[str, str | bytes | None]:
258
- info = self.get(task_id)
279
+ self._maybe_cleanup()
280
+ with self._lock:
281
+ entry = self._tasks.get(task_id)
282
+ if entry is None:
283
+ raise SharedTensorTaskError(f"Task '{task_id}' was not found")
284
+ info = copy.deepcopy(entry.info)
285
+ encoding = entry.result_encoding
286
+ payload = entry.result_payload
259
287
  if info.status == TaskStatus.CANCELLED:
260
288
  raise SharedTensorTaskError(f"Task '{task_id}' was cancelled")
261
289
  if info.status == TaskStatus.FAILED:
@@ -265,8 +293,8 @@ class TaskManager:
265
293
  f"Task '{task_id}' is not complete; current status is '{info.status.value}'"
266
294
  )
267
295
  return {
268
- "encoding": info.result_encoding,
269
- "payload_bytes": info.result_payload,
296
+ "encoding": encoding,
297
+ "payload_bytes": payload,
270
298
  "object_id": info.metadata.get("object_id"),
271
299
  }
272
300
 
@@ -309,8 +337,10 @@ class TaskManager:
309
337
  }
310
338
  return items
311
339
 
312
- def shutdown(self, *, wait: bool = True) -> None:
313
- self._executor.shutdown(wait=wait, cancel_futures=True)
340
+ def shutdown(self, *, wait: bool = True, cancel_futures: bool = True) -> None:
341
+ with self._lock:
342
+ self._accepting_submissions = False
343
+ self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
314
344
 
315
345
  def _maybe_cleanup(self) -> None:
316
346
  now = time.time()
@@ -0,0 +1,135 @@
1
+ """Managed remote object handles and registry state."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from threading import RLock
8
+ from typing import Any, Generic, TypeVar
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ @dataclass(slots=True)
14
+ class ManagedObjectEntry:
15
+ object_id: str
16
+ value: Any
17
+ endpoint: str
18
+ cache_key: str | None
19
+ refcount: int = 1
20
+
21
+
22
+ @dataclass(slots=True)
23
+ class ManagedReleaseResult:
24
+ released: bool
25
+ destroyed: bool
26
+ refcount: int
27
+ cache_key: str | None
28
+
29
+
30
+ class ManagedObjectRegistry:
31
+ def __init__(self) -> None:
32
+ self._entries: dict[str, ManagedObjectEntry] = {}
33
+ self._cache_index: dict[str, str] = {}
34
+ self._lock = RLock()
35
+
36
+ def get_cached(self, cache_key: str) -> ManagedObjectEntry | None:
37
+ with self._lock:
38
+ object_id = self._cache_index.get(cache_key)
39
+ if object_id is None:
40
+ return None
41
+ entry = self._entries.get(object_id)
42
+ if entry is None:
43
+ self._cache_index.pop(cache_key, None)
44
+ return None
45
+ return entry
46
+
47
+ def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
48
+ with self._lock:
49
+ entry = ManagedObjectEntry(
50
+ object_id=uuid.uuid4().hex,
51
+ value=value,
52
+ endpoint=endpoint,
53
+ cache_key=cache_key,
54
+ )
55
+ self._entries[entry.object_id] = entry
56
+ if cache_key is not None:
57
+ self._cache_index[cache_key] = entry.object_id
58
+ return entry
59
+
60
+ def get(self, object_id: str) -> ManagedObjectEntry | None:
61
+ with self._lock:
62
+ return self._entries.get(object_id)
63
+
64
+ def add_ref(self, object_id: str) -> ManagedObjectEntry | None:
65
+ with self._lock:
66
+ entry = self._entries.get(object_id)
67
+ if entry is None:
68
+ return None
69
+ entry.refcount += 1
70
+ return entry
71
+
72
+ def release(self, object_id: str) -> ManagedReleaseResult:
73
+ with self._lock:
74
+ entry = self._entries.get(object_id)
75
+ if entry is None:
76
+ return ManagedReleaseResult(released=False, destroyed=False, refcount=0, cache_key=None)
77
+
78
+ entry.refcount -= 1
79
+ destroyed = entry.refcount <= 0
80
+ cache_key = entry.cache_key
81
+ refcount = max(entry.refcount, 0)
82
+ if destroyed:
83
+ self._entries.pop(object_id, None)
84
+ if cache_key is not None and self._cache_index.get(cache_key) == object_id:
85
+ self._cache_index.pop(cache_key, None)
86
+ return ManagedReleaseResult(
87
+ released=True,
88
+ destroyed=destroyed,
89
+ refcount=refcount,
90
+ cache_key=cache_key,
91
+ )
92
+
93
+ def info(self, object_id: str) -> dict[str, Any] | None:
94
+ with self._lock:
95
+ entry = self._entries.get(object_id)
96
+ if entry is None:
97
+ return None
98
+ return {
99
+ "object_id": entry.object_id,
100
+ "endpoint": entry.endpoint,
101
+ "cache_key": entry.cache_key,
102
+ "refcount": entry.refcount,
103
+ }
104
+
105
+ def clear(self) -> None:
106
+ with self._lock:
107
+ self._entries.clear()
108
+ self._cache_index.clear()
109
+
110
+
111
+ class ReleaseHandle:
112
+ def release(self) -> bool: # pragma: no cover - protocol surface only
113
+ raise NotImplementedError
114
+
115
+
116
+ @dataclass(slots=True)
117
+ class SharedObjectHandle(Generic[T]):
118
+ object_id: str
119
+ value: T
120
+ _releaser: ReleaseHandle
121
+ released: bool = False
122
+
123
+ def release(self) -> bool:
124
+ if self.released:
125
+ return False
126
+ released = self._releaser.release()
127
+ if released:
128
+ self.released = True
129
+ return released
130
+
131
+ def __enter__(self) -> SharedObjectHandle[T]:
132
+ return self
133
+
134
+ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
135
+ self.release()
@@ -8,6 +8,7 @@ import os
8
8
  from collections.abc import Callable
9
9
  from dataclasses import dataclass
10
10
  from functools import wraps
11
+ from threading import RLock
11
12
  from typing import Any, Literal
12
13
 
13
14
  from shared_tensor.errors import (
@@ -113,7 +114,9 @@ class SharedTensorProvider:
113
114
  self._cache: dict[str, Any] = {}
114
115
  self._endpoints: dict[str, EndpointDefinition] = {}
115
116
  self._registered_functions = self._endpoints
116
- atexit.register(self.close)
117
+ self._lock = RLock()
118
+ self._atexit_registered = False
119
+ self._register_atexit_once()
117
120
 
118
121
  def register(
119
122
  self,
@@ -129,25 +132,26 @@ class SharedTensorProvider:
129
132
  ) -> Callable[..., Any]:
130
133
  _validate_endpoint_options(execution=execution, concurrency=concurrency)
131
134
  endpoint_name = func.__name__
132
- if endpoint_name in self._endpoints:
133
- raise SharedTensorProviderError(f"Endpoint '{endpoint_name}' is already registered")
135
+ with self._lock:
136
+ if endpoint_name in self._endpoints:
137
+ raise SharedTensorProviderError(f"Endpoint '{endpoint_name}' is already registered")
134
138
 
135
- resolved_cache_format_key = (
136
- func.__qualname__ if cache_format_key is None else cache_format_key
137
- )
139
+ resolved_cache_format_key = (
140
+ func.__qualname__ if cache_format_key is None else cache_format_key
141
+ )
138
142
 
139
- definition = EndpointDefinition(
140
- name=endpoint_name,
141
- func=func,
142
- cache=cache,
143
- cache_format_key=resolved_cache_format_key,
144
- managed=managed,
145
- async_default_wait=async_default_wait,
146
- execution=execution,
147
- concurrency=concurrency,
148
- singleflight=singleflight,
149
- )
150
- self._endpoints[endpoint_name] = definition
143
+ definition = EndpointDefinition(
144
+ name=endpoint_name,
145
+ func=func,
146
+ cache=cache,
147
+ cache_format_key=resolved_cache_format_key,
148
+ managed=managed,
149
+ async_default_wait=async_default_wait,
150
+ execution=execution,
151
+ concurrency=concurrency,
152
+ singleflight=singleflight,
153
+ )
154
+ self._endpoints[endpoint_name] = definition
151
155
  if self.verbose_debug:
152
156
  logger.debug(
153
157
  "Provider registered endpoint",
@@ -161,7 +165,7 @@ class SharedTensorProvider:
161
165
  )
162
166
 
163
167
  if self._should_autostart_server():
164
- self._restart_autostart_server()
168
+ self._ensure_autostart_server()
165
169
 
166
170
  @wraps(func)
167
171
  def wrapper(*args: Any, **kwargs: Any) -> Any:
@@ -272,11 +276,13 @@ class SharedTensorProvider:
272
276
  return definition.func(*args, **resolved_kwargs)
273
277
 
274
278
  cache_key = self._cache_key_for(endpoint, definition, args, resolved_kwargs)
275
- if cache_key in self._cache:
276
- return self._cache[cache_key]
279
+ with self._lock:
280
+ if cache_key in self._cache:
281
+ return self._cache[cache_key]
277
282
 
278
283
  result = definition.func(*args, **resolved_kwargs)
279
- self._cache[cache_key] = result
284
+ with self._lock:
285
+ self._cache[cache_key] = result
280
286
  return result
281
287
 
282
288
  def get_endpoint(self, endpoint: str) -> EndpointDefinition:
@@ -309,18 +315,19 @@ class SharedTensorProvider:
309
315
  self._async_client.close()
310
316
  self._async_client = None
311
317
  if self._server is not None:
312
- self._server.stop()
318
+ self._server.stop(wait_for_tasks=True)
313
319
  self._server = None
314
320
 
315
321
  def get_runtime_info(self) -> dict[str, Any]:
316
322
  if self.execution_mode in {"server", "local"}:
323
+ server = self._server
317
324
  return {
318
325
  "execution_mode": self.execution_mode,
319
326
  "auto_mode": self.auto_mode,
320
327
  "base_path": self.base_path,
321
328
  "device_index": self.device_index,
322
329
  "server_socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
323
- "server_running": self._server is not None,
330
+ "server_running": bool(server is not None and getattr(server, "running", True)),
324
331
  }
325
332
  server_info = self._get_client().get_server_info()
326
333
  return {
@@ -361,18 +368,18 @@ class SharedTensorProvider:
361
368
  def _should_autostart_server(self) -> bool:
362
369
  return self.auto_mode and self.execution_mode == "server"
363
370
 
364
- def _restart_autostart_server(self) -> None:
371
+ def _ensure_autostart_server(self) -> None:
365
372
  from shared_tensor.server import SharedTensorServer
366
373
 
374
+ if self._server is not None:
375
+ return
367
376
  if self.verbose_debug:
368
377
  logger.debug(
369
- "Provider restarting autostart server",
378
+ "Provider starting autostart server",
370
379
  extra={
371
380
  "socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
372
381
  },
373
382
  )
374
- if self._server is not None:
375
- self._server.stop()
376
383
  self._server = SharedTensorServer(
377
384
  self,
378
385
  socket_path=resolve_runtime_socket_path(self.base_path, self.device_index),
@@ -395,3 +402,9 @@ class SharedTensorProvider:
395
402
  func=definition.func,
396
403
  cache_format_key=definition.cache_format_key,
397
404
  )
405
+
406
+ def _register_atexit_once(self) -> None:
407
+ if self._atexit_registered:
408
+ return
409
+ atexit.register(self.close)
410
+ self._atexit_registered = True
@@ -5,6 +5,8 @@ from __future__ import annotations
5
5
  from threading import RLock
6
6
  from typing import TYPE_CHECKING
7
7
 
8
+ from shared_tensor.errors import SharedTensorConfigurationError
9
+
8
10
  if TYPE_CHECKING:
9
11
  from shared_tensor.server import SharedTensorServer
10
12
 
@@ -15,6 +17,11 @@ _SERVERS: dict[str, "SharedTensorServer"] = {}
15
17
 
16
18
  def register_local_server(socket_path: str, server: "SharedTensorServer") -> None:
17
19
  with _LOCK:
20
+ current = _SERVERS.get(socket_path)
21
+ if current is not None and current is not server:
22
+ raise SharedTensorConfigurationError(
23
+ f"Local runtime socket '{socket_path}' is already registered by another server"
24
+ )
18
25
  _SERVERS[socket_path] = server
19
26
 
20
27
 
@@ -39,6 +39,24 @@ from shared_tensor.utils import (
39
39
  logger = logging.getLogger(__name__)
40
40
 
41
41
 
42
+ class _ConnectionExecutor:
43
+ def __init__(self, *, max_workers: int) -> None:
44
+ self._semaphore = threading.BoundedSemaphore(max_workers)
45
+
46
+ def submit(self, func, *args, **kwargs) -> threading.Thread:
47
+ self._semaphore.acquire()
48
+
49
+ def runner() -> None:
50
+ try:
51
+ func(*args, **kwargs)
52
+ finally:
53
+ self._semaphore.release()
54
+
55
+ thread = threading.Thread(target=runner, daemon=True)
56
+ thread.start()
57
+ return thread
58
+
59
+
42
60
  def _server_version() -> str:
43
61
  try:
44
62
  from shared_tensor import __version__
@@ -49,7 +67,7 @@ def _server_version() -> str:
49
67
 
50
68
  @dataclass(slots=True)
51
69
  class _InFlightCall:
52
- future: Future[dict[str, Any]]
70
+ future: Future
53
71
 
54
72
 
55
73
  @dataclass(slots=True)
@@ -107,6 +125,8 @@ class SharedTensorServer:
107
125
  self._inflight: dict[str, _InFlightCall] = {}
108
126
  self._endpoint_locks: dict[str, threading.Lock] = {}
109
127
  self._coordination_lock = threading.RLock()
128
+ self._connection_executor = _ConnectionExecutor(max_workers=max_workers)
129
+ self._accepting_requests = True
110
130
  if getattr(self.provider, "_server", None) is None:
111
131
  self.provider._server = self
112
132
 
@@ -114,6 +134,8 @@ class SharedTensorServer:
114
134
  if self.verbose_debug:
115
135
  logger.debug("Server processing request", extra={"method": request.get("method")})
116
136
  try:
137
+ if not self._accepting_requests:
138
+ raise SharedTensorConfigurationError("Server is stopping and not accepting new requests")
117
139
  method = request.get("method")
118
140
  if not isinstance(method, str) or not method:
119
141
  raise SharedTensorProtocolError("Missing required field 'method'")
@@ -293,7 +315,8 @@ class SharedTensorServer:
293
315
  if self.verbose_debug:
294
316
  logger.debug("Server executed direct endpoint", extra={"endpoint": endpoint})
295
317
  if cache_key is not None:
296
- self._local_cache[cache_key] = value
318
+ with self._coordination_lock:
319
+ self._local_cache[cache_key] = value
297
320
  return _EndpointResult(value=value)
298
321
 
299
322
  def _materialize_managed_result(
@@ -329,7 +352,8 @@ class SharedTensorServer:
329
352
  return None
330
353
  self._managed_objects.add_ref(cached.object_id)
331
354
  return _EndpointResult(value=cached.value, object_id=cached.object_id)
332
- local_value = self._local_cache.get(cache_key)
355
+ with self._coordination_lock:
356
+ local_value = self._local_cache.get(cache_key)
333
357
  if local_value is None:
334
358
  return None
335
359
  return _EndpointResult(value=local_value)
@@ -397,11 +421,14 @@ class SharedTensorServer:
397
421
  return existing.value
398
422
  self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
399
423
  return value
400
- if cache_key is not None and cache_key in self._local_cache:
401
- return self._local_cache[cache_key]
424
+ if cache_key is not None:
425
+ with self._coordination_lock:
426
+ if cache_key in self._local_cache:
427
+ return self._local_cache[cache_key]
402
428
  value = definition.func(*args, **resolved_kwargs)
403
429
  if cache_key is not None:
404
- self._local_cache[cache_key] = value
430
+ with self._coordination_lock:
431
+ self._local_cache[cache_key] = value
405
432
  return value
406
433
 
407
434
  def _cache_key(
@@ -421,16 +448,16 @@ class SharedTensorServer:
421
448
  cache_format_key=definition.cache_format_key,
422
449
  )
423
450
 
424
- def _acquire_inflight(self, inflight_key: str) -> tuple[Future[dict[str, Any]], bool]:
451
+ def _acquire_inflight(self, inflight_key: str) -> tuple[Future, bool]:
425
452
  with self._coordination_lock:
426
453
  inflight = self._inflight.get(inflight_key)
427
454
  if inflight is not None:
428
455
  return inflight.future, False
429
- future: Future[dict[str, Any]] = Future()
456
+ future = Future()
430
457
  self._inflight[inflight_key] = _InFlightCall(future=future)
431
458
  return future, True
432
459
 
433
- def _release_inflight(self, inflight_key: str | None, future: Future[dict[str, Any]]) -> None:
460
+ def _release_inflight(self, inflight_key: str | None, future: Future) -> None:
434
461
  if inflight_key is None:
435
462
  return
436
463
  with self._coordination_lock:
@@ -535,7 +562,7 @@ class SharedTensorServer:
535
562
  "socket_path": self.socket_path,
536
563
  "uptime": uptime,
537
564
  "running": self.running,
538
- "ready": self.running and self.listener is not None,
565
+ "ready": self.running and self.listener is not None and self._accepting_requests,
539
566
  "pid": os.getpid(),
540
567
  "ppid": os.getppid(),
541
568
  "device_index": resolve_device_index(self.provider.device_index),
@@ -558,6 +585,7 @@ class SharedTensorServer:
558
585
  logger.info("Server starting", extra={"socket_path": self.socket_path, "blocking": blocking})
559
586
  if self.running or self.server_thread is not None:
560
587
  raise SharedTensorConfigurationError("Server is already running")
588
+ self._accepting_requests = True
561
589
  if blocking:
562
590
  self._resolved_process_start_method = None
563
591
  self._serve_forever()
@@ -576,11 +604,11 @@ class SharedTensorServer:
576
604
  self._resolved_process_start_method = "thread"
577
605
  thread.start()
578
606
  if not state.ready.wait(timeout=self.startup_timeout):
579
- self.stop()
607
+ self.stop(wait_for_tasks=False)
580
608
  raise TimeoutError(f"Timed out waiting for server socket {self.socket_path}")
581
609
  if state.error is not None:
582
610
  error = state.error
583
- self.stop()
611
+ self.stop(wait_for_tasks=False)
584
612
  raise SharedTensorConfigurationError(
585
613
  f"Failed to start background server thread for {self.socket_path}: {error}"
586
614
  ) from error
@@ -620,12 +648,11 @@ class SharedTensorServer:
620
648
  if self.running:
621
649
  raise
622
650
  break
623
- thread = threading.Thread(target=self._handle_connection, args=(conn,), daemon=True)
624
- thread.start()
651
+ self._connection_executor.submit(self._handle_connection, conn)
625
652
  finally:
626
653
  if started_event is not None and not started_event.is_set():
627
654
  started_event.set()
628
- self._shutdown_local_resources()
655
+ self._shutdown_local_resources(wait_for_tasks=True)
629
656
 
630
657
  def _handle_connection(self, conn: socket.socket) -> None:
631
658
  with conn:
@@ -655,9 +682,10 @@ class SharedTensorServer:
655
682
  if 0 <= local_rank < torch.cuda.device_count():
656
683
  torch.cuda.set_device(local_rank)
657
684
 
658
- def stop(self) -> None:
685
+ def stop(self, *, wait_for_tasks: bool = True) -> None:
659
686
  if self.verbose_debug:
660
687
  logger.info("Server stopping", extra={"socket_path": self.socket_path})
688
+ self._accepting_requests = False
661
689
  self.running = False
662
690
  if self.listener is not None:
663
691
  self.listener.close()
@@ -668,21 +696,23 @@ class SharedTensorServer:
668
696
  self.server_thread = None
669
697
  self.server_process = None
670
698
  if self.listener is None:
671
- unlink_socket_path(self.socket_path)
699
+ self._shutdown_local_resources(wait_for_tasks=wait_for_tasks)
672
700
 
673
- def _shutdown_local_resources(self) -> None:
701
+ def _shutdown_local_resources(self, *, wait_for_tasks: bool) -> None:
702
+ self._accepting_requests = False
674
703
  self.running = False
675
704
  if self.listener is not None:
676
705
  self.listener.close()
677
706
  self.listener = None
678
707
  if self._task_manager is not None:
679
- self._task_manager.shutdown(wait=False)
708
+ self._task_manager.shutdown(wait=wait_for_tasks, cancel_futures=not wait_for_tasks)
680
709
  self._task_manager = None
681
710
  self._managed_objects.clear()
682
- self._cache.clear()
683
- self._local_cache.clear()
684
- self._inflight.clear()
685
- self._endpoint_locks.clear()
711
+ with self._coordination_lock:
712
+ self._cache.clear()
713
+ self._local_cache.clear()
714
+ self._inflight.clear()
715
+ self._endpoint_locks.clear()
686
716
  unregister_local_server(self.socket_path, self)
687
717
  unlink_socket_path(self.socket_path)
688
718
 
@@ -1,126 +0,0 @@
1
- """Managed remote object handles and registry state."""
2
-
3
- from __future__ import annotations
4
-
5
- import uuid
6
- from dataclasses import dataclass
7
- from typing import Any, Generic, TypeVar
8
-
9
- T = TypeVar("T")
10
-
11
-
12
- @dataclass(slots=True)
13
- class ManagedObjectEntry:
14
- object_id: str
15
- value: Any
16
- endpoint: str
17
- cache_key: str | None
18
- refcount: int = 1
19
-
20
-
21
- @dataclass(slots=True)
22
- class ManagedReleaseResult:
23
- released: bool
24
- destroyed: bool
25
- refcount: int
26
- cache_key: str | None
27
-
28
-
29
- class ManagedObjectRegistry:
30
- def __init__(self) -> None:
31
- self._entries: dict[str, ManagedObjectEntry] = {}
32
- self._cache_index: dict[str, str] = {}
33
-
34
- def get_cached(self, cache_key: str) -> ManagedObjectEntry | None:
35
- object_id = self._cache_index.get(cache_key)
36
- if object_id is None:
37
- return None
38
- entry = self._entries.get(object_id)
39
- if entry is None:
40
- self._cache_index.pop(cache_key, None)
41
- return None
42
- return entry
43
-
44
- def register(self, *, endpoint: str, value: Any, cache_key: str | None) -> ManagedObjectEntry:
45
- entry = ManagedObjectEntry(
46
- object_id=uuid.uuid4().hex,
47
- value=value,
48
- endpoint=endpoint,
49
- cache_key=cache_key,
50
- )
51
- self._entries[entry.object_id] = entry
52
- if cache_key is not None:
53
- self._cache_index[cache_key] = entry.object_id
54
- return entry
55
-
56
- def get(self, object_id: str) -> ManagedObjectEntry | None:
57
- return self._entries.get(object_id)
58
-
59
- def add_ref(self, object_id: str) -> ManagedObjectEntry | None:
60
- entry = self._entries.get(object_id)
61
- if entry is None:
62
- return None
63
- entry.refcount += 1
64
- return entry
65
-
66
- def release(self, object_id: str) -> ManagedReleaseResult:
67
- entry = self._entries.get(object_id)
68
- if entry is None:
69
- return ManagedReleaseResult(released=False, destroyed=False, refcount=0, cache_key=None)
70
-
71
- entry.refcount -= 1
72
- destroyed = entry.refcount <= 0
73
- cache_key = entry.cache_key
74
- refcount = max(entry.refcount, 0)
75
- if destroyed:
76
- self._entries.pop(object_id, None)
77
- if cache_key is not None and self._cache_index.get(cache_key) == object_id:
78
- self._cache_index.pop(cache_key, None)
79
- return ManagedReleaseResult(
80
- released=True,
81
- destroyed=destroyed,
82
- refcount=refcount,
83
- cache_key=cache_key,
84
- )
85
-
86
- def info(self, object_id: str) -> dict[str, Any] | None:
87
- entry = self._entries.get(object_id)
88
- if entry is None:
89
- return None
90
- return {
91
- "object_id": entry.object_id,
92
- "endpoint": entry.endpoint,
93
- "cache_key": entry.cache_key,
94
- "refcount": entry.refcount,
95
- }
96
-
97
- def clear(self) -> None:
98
- self._entries.clear()
99
- self._cache_index.clear()
100
-
101
-
102
- class ReleaseHandle:
103
- def release(self) -> bool: # pragma: no cover - protocol surface only
104
- raise NotImplementedError
105
-
106
-
107
- @dataclass(slots=True)
108
- class SharedObjectHandle(Generic[T]):
109
- object_id: str
110
- value: T
111
- _releaser: ReleaseHandle
112
- released: bool = False
113
-
114
- def release(self) -> bool:
115
- if self.released:
116
- return False
117
- released = self._releaser.release()
118
- if released:
119
- self.released = True
120
- return released
121
-
122
- def __enter__(self) -> SharedObjectHandle[T]:
123
- return self
124
-
125
- def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
126
- self.release()
File without changes
File without changes
File without changes