shared-tensor 0.2.5__tar.gz → 0.2.6__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only and reflects the packages exactly as published.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: shared-tensor
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation
5
5
  Author-email: Athena Team <contact@world-sim-dev.org>
6
6
  Maintainer-email: Athena Team <contact@world-sim-dev.org>
@@ -63,6 +63,7 @@ Supported:
63
63
  - sync `call` and task-backed `submit`
64
64
  - managed object handles with explicit release
65
65
  - server-side caching, `cache_format_key`, and singleflight
66
+ - manual two-process deployment as the primary production path
66
67
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
67
68
 
68
69
  Not supported:
@@ -88,46 +89,58 @@ conda activate shared-tensor-dev
88
89
  pip install -e ".[dev,test]"
89
90
  ```
90
91
 
91
- ## Example: Same Code, Two Processes
92
+ ## Example: Manual Two-Process Deployment
93
+
94
+ Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
92
95
 
93
- See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
96
+ See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
97
+
98
+ Server process:
94
99
 
95
100
  ```python
96
- import torch
101
+ from shared_tensor import SharedTensorProvider, SharedTensorServer
97
102
 
98
- from shared_tensor import SharedObjectHandle, SharedTensorProvider
103
+ provider = SharedTensorProvider(execution_mode="server")
99
104
 
100
- provider = SharedTensorProvider()
105
+ @provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
106
+ def load_model(hidden_size: int = 4):
107
+ ...
101
108
 
109
+ server = SharedTensorServer(provider)
110
+ server.start(blocking=True)
111
+ ```
102
112
 
103
- @provider.share(
104
- execution="task",
105
- managed=True,
106
- concurrency="serialized",
107
- cache_format_key="model:{hidden_size}",
108
- )
109
- def load_model(hidden_size: int = 4) -> torch.nn.Module:
110
- return torch.nn.Linear(hidden_size, 2, device="cuda")
113
+ Client process:
111
114
 
115
+ ```python
116
+ import torch
117
+
118
+ from shared_tensor import SharedObjectHandle, SharedTensorClient
112
119
 
120
+ client = SharedTensorClient()
113
121
  x = torch.ones(1, 4, device="cuda")
114
- result = load_model(hidden_size=4)
122
+ result = client.call("load_model", hidden_size=4)
115
123
  if isinstance(result, SharedObjectHandle):
116
124
  with result as handle:
117
125
  y = handle.value(x)
118
- else:
119
- y = result(x)
120
126
  ```
121
127
 
122
- Server process:
128
+ This keeps the contract explicit:
123
129
 
124
- ```bash
125
- SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
130
+ ```text
131
+ server process client process
132
+ ------------------------------ ------------------------------
133
+ owns CUDA allocations issues local UDS RPC requests
134
+ executes endpoint functions reopens CUDA objects via torch IPC
135
+ manages cache and refcounts releases managed handles explicitly
126
136
  ```
127
137
 
128
- Client process with the exact same file:
138
+ ## Example: Same Code, Two Processes
139
+
140
+ See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
129
141
 
130
142
  ```bash
143
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
131
144
  SHARED_TENSOR_ENABLED=1 python demo.py
132
145
  ```
133
146
 
@@ -138,7 +151,7 @@ same code
138
151
 
139
152
  server process client process
140
153
  ------------------------------ ------------------------------
141
- provider auto-starts UDS daemon provider builds client wrappers
154
+ provider auto-starts local thread provider builds client wrappers
142
155
  shared function runs locally shared function becomes RPC call
143
156
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
144
157
  ```
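Both deployment styles above end with one server process and one or more client processes on the same host. As a local convenience, a launcher along the following lines can drive the manual two-process example; the script names and the device-0 socket path are assumptions for illustration (see the socket-selection rules further down), and a production deployment would normally hand both processes to a supervisor instead.

```python
# Minimal local launcher for the manual two-process deployment sketched above.
# Assumptions (not part of the package): the server and client snippets live in
# model_server.py and model_client.py, CUDA device 0 is used, and the default
# socket path /tmp/shared-tensor-0.sock applies.
import pathlib
import subprocess
import time

SOCKET = pathlib.Path("/tmp/shared-tensor-0.sock")

server = subprocess.Popen(["python", "model_server.py"])
try:
    deadline = time.monotonic() + 30.0
    while not SOCKET.exists():
        # Wait until the server process has bound its Unix domain socket.
        if time.monotonic() > deadline:
            raise TimeoutError(f"server socket {SOCKET} never appeared")
        time.sleep(0.1)
    subprocess.run(["python", "model_client.py"], check=True)
finally:
    server.terminate()
    server.wait(timeout=10)
```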
@@ -201,19 +214,19 @@ SharedTensorProvider(enabled=None)
201
214
  Provider runtime controls:
202
215
 
203
216
  ```python
204
- SharedTensorProvider(server_process_start_method="fork")
205
217
  SharedTensorProvider(server_startup_timeout=30.0)
206
218
  provider.get_runtime_info()
207
219
  ```
208
220
 
209
- Use `server_process_start_method="fork"` when you explicitly want POSIX fork behavior.
210
- Leave it as `None` to let the library choose a safer default for the current entrypoint.
221
+ Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
211
222
 
212
223
  `execution_mode="auto"` behaves as follows:
213
224
  - disabled: local mode
214
- - enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
225
+ - enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
215
226
  - enabled + role unset: build client wrappers
216
227
 
228
+ For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
229
+
217
230
  Socket selection is per CUDA device:
218
231
  - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
219
232
  - runtime socket path is `<base_path>-<device_index>.sock`
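The socket naming rule above is simple enough to restate in a few lines. This is an illustration of the documented convention, not the library's internal helper:

```python
# Restatement of the documented socket-selection rule: <base_path>-<device_index>.sock,
# with the base path taken from SHARED_TENSOR_BASE_PATH or /tmp/shared-tensor.
import os

def socket_path_for_device(device_index: int, base_path: str | None = None) -> str:
    """Return the runtime socket path for one CUDA device."""
    if base_path is None:
        base_path = os.environ.get("SHARED_TENSOR_BASE_PATH", "/tmp/shared-tensor")
    return f"{base_path}-{device_index}.sock"

assert socket_path_for_device(0, "/tmp/shared-tensor") == "/tmp/shared-tensor-0.sock"
assert socket_path_for_device(1, "/tmp/shared-tensor") == "/tmp/shared-tensor-1.sock"
```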
@@ -13,6 +13,7 @@ Supported:
13
13
  - sync `call` and task-backed `submit`
14
14
  - managed object handles with explicit release
15
15
  - server-side caching, `cache_format_key`, and singleflight
16
+ - manual two-process deployment as the primary production path
16
17
  - zero-branch auto mode gated by `SHARED_TENSOR_ENABLED=1`
17
18
 
18
19
  Not supported:
@@ -38,46 +39,58 @@ conda activate shared-tensor-dev
38
39
  pip install -e ".[dev,test]"
39
40
  ```
40
41
 
41
- ## Example: Same Code, Two Processes
42
+ ## Example: Manual Two-Process Deployment
43
+
44
+ Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.
42
45
 
43
- See [examples/zero_branch_env.py](./examples/zero_branch_env.py).
46
+ See [examples/model_service.py](./examples/model_service.py) for endpoint definitions.
47
+
48
+ Server process:
44
49
 
45
50
  ```python
46
- import torch
51
+ from shared_tensor import SharedTensorProvider, SharedTensorServer
47
52
 
48
- from shared_tensor import SharedObjectHandle, SharedTensorProvider
53
+ provider = SharedTensorProvider(execution_mode="server")
49
54
 
50
- provider = SharedTensorProvider()
55
+ @provider.share(execution="task", managed=True, concurrency="serialized", cache_format_key="model:{hidden_size}")
56
+ def load_model(hidden_size: int = 4):
57
+ ...
51
58
 
59
+ server = SharedTensorServer(provider)
60
+ server.start(blocking=True)
61
+ ```
52
62
 
53
- @provider.share(
54
- execution="task",
55
- managed=True,
56
- concurrency="serialized",
57
- cache_format_key="model:{hidden_size}",
58
- )
59
- def load_model(hidden_size: int = 4) -> torch.nn.Module:
60
- return torch.nn.Linear(hidden_size, 2, device="cuda")
63
+ Client process:
61
64
 
65
+ ```python
66
+ import torch
67
+
68
+ from shared_tensor import SharedObjectHandle, SharedTensorClient
62
69
 
70
+ client = SharedTensorClient()
63
71
  x = torch.ones(1, 4, device="cuda")
64
- result = load_model(hidden_size=4)
72
+ result = client.call("load_model", hidden_size=4)
65
73
  if isinstance(result, SharedObjectHandle):
66
74
  with result as handle:
67
75
  y = handle.value(x)
68
- else:
69
- y = result(x)
70
76
  ```
71
77
 
72
- Server process:
78
+ This keeps the contract explicit:
73
79
 
74
- ```bash
75
- SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
80
+ ```text
81
+ server process client process
82
+ ------------------------------ ------------------------------
83
+ owns CUDA allocations issues local UDS RPC requests
84
+ executes endpoint functions reopens CUDA objects via torch IPC
85
+ manages cache and refcounts releases managed handles explicitly
76
86
  ```
77
87
 
78
- Client process with the exact same file:
88
+ ## Example: Same Code, Two Processes
89
+
90
+ See [examples/zero_branch_env.py](./examples/zero_branch_env.py). This is a convenience mode for environments that want one file and environment-controlled behavior.
79
91
 
80
92
  ```bash
93
+ SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
81
94
  SHARED_TENSOR_ENABLED=1 python demo.py
82
95
  ```
83
96
 
@@ -88,7 +101,7 @@ same code
88
101
 
89
102
  server process client process
90
103
  ------------------------------ ------------------------------
91
- provider auto-starts UDS daemon provider builds client wrappers
104
+ provider auto-starts local thread provider builds client wrappers
92
105
  shared function runs locally shared function becomes RPC call
93
106
  CUDA object stays on same GPU CUDA object is reopened via torch IPC
94
107
  ```
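In either mode, the client side can only issue calls once the server socket is accepting connections. A hypothetical readiness probe, assuming the default device-0 socket path described in the configuration section below:

```python
# Hypothetical helper a client process might run before its first call; not library code.
import socket
import time

def wait_for_server(path: str = "/tmp/shared-tensor-0.sock", timeout: float = 30.0) -> None:
    """Poll the server's Unix domain socket until it accepts a connection."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as probe:
                probe.connect(path)
            return
        except OSError:
            if time.monotonic() > deadline:
                raise TimeoutError(f"no shared-tensor server listening on {path}")
            time.sleep(0.1)
```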
@@ -151,19 +164,19 @@ SharedTensorProvider(enabled=None)
151
164
  Provider runtime controls:
152
165
 
153
166
  ```python
154
- SharedTensorProvider(server_process_start_method="fork")
155
167
  SharedTensorProvider(server_startup_timeout=30.0)
156
168
  provider.get_runtime_info()
157
169
  ```
158
170
 
159
- Use `server_process_start_method="fork"` when you explicitly want POSIX fork behavior.
160
- Leave it as `None` to let the library choose a safer default for the current entrypoint.
171
+ Non-blocking provider autostart runs the UDS server in a background thread inside the current process.
161
172
 
162
173
  `execution_mode="auto"` behaves as follows:
163
174
  - disabled: local mode
164
- - enabled + `SHARED_TENSOR_ROLE=server`: auto-start local server and execute endpoints locally
175
+ - enabled + `SHARED_TENSOR_ROLE=server`: auto-start a local background server thread and execute endpoints locally
165
176
  - enabled + role unset: build client wrappers
166
177
 
178
+ For production deployment, prefer explicit `SharedTensorServer(...).start(blocking=True)` in a dedicated server process.
179
+
167
180
  Socket selection is per CUDA device:
168
181
  - base path comes from `SHARED_TENSOR_BASE_PATH` or `/tmp/shared-tensor`
169
182
  - runtime socket path is `<base_path>-<device_index>.sock`
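As a mental model for the auto-mode bullets above, the decision reduces to a small function of the two environment variables; this is a restatement of the documented behavior, not library code:

```python
# Mental model of the execution_mode="auto" decision described above.
def resolve_auto_mode(env: dict[str, str]) -> str:
    enabled = env.get("SHARED_TENSOR_ENABLED") == "1"
    role = env.get("SHARED_TENSOR_ROLE")
    if not enabled:
        return "local"    # disabled: shared functions run in-process
    if role == "server":
        return "server"   # autostart a background server thread, execute endpoints locally
    return "client"       # build client wrappers over the UDS socket

assert resolve_auto_mode({}) == "local"
assert resolve_auto_mode({"SHARED_TENSOR_ENABLED": "1", "SHARED_TENSOR_ROLE": "server"}) == "server"
assert resolve_auto_mode({"SHARED_TENSOR_ENABLED": "1"}) == "client"
```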
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "shared-tensor"
7
- version = "0.2.5"
7
+ version = "0.2.6"
8
8
  description = "Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "TaskStatus",
20
20
  ]
21
21
 
22
- __version__ = "0.2.5"
22
+ __version__ = "0.2.6"
@@ -92,7 +92,6 @@ class SharedTensorProvider:
92
92
  device_index: int | None = None,
93
93
  timeout: float = 30.0,
94
94
  execution_mode: str = "auto",
95
- server_process_start_method: str | None = None,
96
95
  server_startup_timeout: float = 30.0,
97
96
  verbose_debug: bool = False,
98
97
  ) -> None:
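One practical consequence of this signature change: `server_process_start_method` is no longer a `SharedTensorProvider` keyword, so 0.2.5-style construction fails fast. An inferred migration sketch:

```python
# Inferred migration note for this signature change; the keyword was removed in 0.2.6.
from shared_tensor import SharedTensorProvider

try:
    provider = SharedTensorProvider(server_process_start_method="fork")  # accepted on 0.2.5
except TypeError:
    provider = SharedTensorProvider()  # 0.2.6: drop the argument; autostart is thread-backed
```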
@@ -106,7 +105,6 @@ class SharedTensorProvider:
106
105
  self.timeout = timeout
107
106
  self.execution_mode = resolved_mode
108
107
  self.auto_mode = auto_mode
109
- self.server_process_start_method = server_process_start_method
110
108
  self.server_startup_timeout = server_startup_timeout
111
109
  self.verbose_debug = verbose_debug
112
110
  self._client: Any | None = None
@@ -165,9 +163,6 @@ class SharedTensorProvider:
165
163
  if self._should_autostart_server():
166
164
  self._restart_autostart_server()
167
165
 
168
- if self.execution_mode == "server":
169
- return func
170
-
171
166
  @wraps(func)
172
167
  def wrapper(*args: Any, **kwargs: Any) -> Any:
173
168
  return self.call(endpoint_name, *args, **kwargs)
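With the early `return func` branch gone, a server-mode provider no longer hands back the raw function; every mode receives the same wrapper that defers to `provider.call`. A simplified stand-in (not library code) for that shape:

```python
# Toy provider illustrating the post-0.2.6 decorator behavior: all modes get a wrapper
# that routes through call(), so caching and managed-object bookkeeping apply uniformly.
from functools import wraps
from typing import Any, Callable

class MiniProvider:
    def __init__(self) -> None:
        self.endpoints: dict[str, Callable[..., Any]] = {}

    def share(self, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            self.endpoints[name] = func

            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # 0.2.5 returned func unchanged in server mode; this always dispatches.
                return self.call(name, *args, **kwargs)

            return wrapper

        return decorator

    def call(self, name: str, *args: Any, **kwargs: Any) -> Any:
        return self.endpoints[name](*args, **kwargs)

provider = MiniProvider()

@provider.share("double")
def double(x: int) -> int:
    return 2 * x

assert double(3) == 6  # goes through provider.call, not the raw function
```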
@@ -215,7 +210,11 @@ class SharedTensorProvider:
215
210
  def call(self, endpoint: str, *args: Any, **kwargs: Any) -> Any:
216
211
  if self.verbose_debug:
217
212
  logger.debug("Provider dispatching call", extra={"endpoint": endpoint, "mode": self.execution_mode})
218
- if self.execution_mode in {"server", "local"}:
213
+ if self.execution_mode == "server":
214
+ if self._server is not None and hasattr(self._server, "invoke_local"):
215
+ return self._server.invoke_local(endpoint, args=args, kwargs=kwargs)
216
+ return self.invoke_local(endpoint, args=args, kwargs=kwargs)
217
+ if self.execution_mode == "local":
219
218
  return self.invoke_local(endpoint, args=args, kwargs=kwargs)
220
219
  return self._get_client().call(endpoint, *args, **kwargs)
221
220
 
@@ -370,7 +369,6 @@ class SharedTensorProvider:
370
369
  "Provider restarting autostart server",
371
370
  extra={
372
371
  "socket_path": resolve_runtime_socket_path(self.base_path, self.device_index),
373
- "process_start_method": self.server_process_start_method,
374
372
  },
375
373
  )
376
374
  if self._server is not None:
@@ -378,7 +376,6 @@ class SharedTensorProvider:
378
376
  self._server = SharedTensorServer(
379
377
  self,
380
378
  socket_path=resolve_runtime_socket_path(self.base_path, self.device_index),
381
- process_start_method=self.server_process_start_method,
382
379
  startup_timeout=self.server_startup_timeout,
383
380
  verbose_debug=self.verbose_debug,
384
381
  )
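Taken together, the provider-side changes above mean a server-mode process serves even its own decorated calls through the attached server's `invoke_local`, so results are cached and managed objects are reused in-process. A sketch under the assumption that an explicitly server-mode provider does not autostart a listener (as the README's manual example implies), using a CPU-only toy endpoint:

```python
# Sketch of in-process dispatch for a server-mode provider after this change.
# Assumes shared_tensor 0.2.6; the endpoint is a toy invented for illustration.
from shared_tensor import SharedTensorProvider, SharedTensorServer

provider = SharedTensorProvider(execution_mode="server")

@provider.share(execution="task", managed=True, concurrency="serialized",
                cache_format_key="stats:{n}")
def compute_stats(n: int = 1) -> dict:
    return {"n": n, "square": n * n}

# Constructing the server attaches it to the provider (it is not started here),
# so provider.call routes through the server's invoke_local and its registries.
server = SharedTensorServer(provider)

first = compute_stats(n=3)
second = compute_stats(n=3)
assert first is second  # second call is served from the server-side cache
```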
@@ -2,16 +2,13 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import cloudpickle
6
5
  import logging
7
- import multiprocessing as mp
8
6
  import os
9
- import sys
10
7
  import socket
11
8
  import threading
12
9
  import time
13
10
  from concurrent.futures import Future
14
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, field
15
12
  from typing import Any
16
13
 
17
14
  from shared_tensor.async_task import TaskManager, TaskStatus
@@ -41,11 +38,27 @@ from shared_tensor.utils import (
41
38
  logger = logging.getLogger(__name__)
42
39
 
43
40
 
41
+ def _server_version() -> str:
42
+ try:
43
+ from shared_tensor import __version__
44
+ except ImportError:
45
+ return "unknown"
46
+ return __version__
47
+
48
+
44
49
  @dataclass(slots=True)
45
50
  class _InFlightCall:
46
51
  future: Future[dict[str, Any]]
47
52
 
48
53
 
54
+ @dataclass(slots=True)
55
+ class _ServerThreadState:
56
+ thread: threading.Thread
57
+ ready: threading.Event = field(default_factory=threading.Event)
58
+ stopped: threading.Event = field(default_factory=threading.Event)
59
+ error: BaseException | None = None
60
+
61
+
49
62
  class SharedTensorServer:
50
63
  def __init__(
51
64
  self,
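`_ServerThreadState` bundles the thread with a ready event, a stopped event, and a captured startup error, which is what lets the non-blocking `start()` further down wait for the listener and re-raise failures. A generic illustration of that pattern with a stand-in worker, not the library's own logic:

```python
# Ready/stopped/error coordination between a starter and a background worker thread.
import threading
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class WorkerState:
    ready: threading.Event = field(default_factory=threading.Event)
    stopped: threading.Event = field(default_factory=threading.Event)
    error: BaseException | None = None
    thread: threading.Thread | None = None

def start_worker(setup: Callable[[], None]) -> WorkerState:
    state = WorkerState()

    def run() -> None:
        try:
            setup()            # e.g. bind a listening socket
            state.ready.set()  # tell the starter that setup succeeded
            # ... the accept/serve loop would run here ...
        except BaseException as exc:
            state.error = exc  # surface the failure to the starter
            state.ready.set()
        finally:
            state.stopped.set()

    state.thread = threading.Thread(target=run, daemon=True)
    state.thread.start()
    if not state.ready.wait(timeout=5.0):
        raise TimeoutError("worker did not become ready in time")
    if state.error is not None:
        raise RuntimeError("worker failed during setup") from state.error
    return state

state = start_worker(lambda: None)  # trivial setup; ready is set almost immediately
state.stopped.wait(timeout=5.0)
```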
@@ -72,6 +85,7 @@ class SharedTensorServer:
72
85
  self.startup_timeout = startup_timeout
73
86
  self.listener: socket.socket | None = None
74
87
  self.server_process: Any | None = None
88
+ self.server_thread: _ServerThreadState | None = None
75
89
  self._resolved_process_start_method: str | None = None
76
90
  self.running = False
77
91
  self.started_at: float | None = None
@@ -81,10 +95,13 @@ class SharedTensorServer:
81
95
  }
82
96
  self._task_manager: TaskManager | None = None
83
97
  self._cache: dict[str, dict[str, Any]] = {}
98
+ self._local_cache: dict[str, Any] = {}
84
99
  self._managed_objects = ManagedObjectRegistry()
85
100
  self._inflight: dict[str, _InFlightCall] = {}
86
101
  self._endpoint_locks: dict[str, threading.Lock] = {}
87
102
  self._coordination_lock = threading.RLock()
103
+ if getattr(self.provider, "_server", None) is None:
104
+ self.provider._server = self
88
105
 
89
106
  def process_request(self, request: dict[str, Any]) -> dict[str, Any]:
90
107
  if self.verbose_debug:
@@ -265,6 +282,7 @@ class SharedTensorServer:
265
282
  result = self._encode_result(value)
266
283
  if cache_key is not None:
267
284
  self._cache[cache_key] = result
285
+ self._local_cache[cache_key] = value
268
286
  return result
269
287
 
270
288
  def _materialize_managed_result(
@@ -300,7 +318,44 @@ class SharedTensorServer:
300
318
  return None
301
319
  self._managed_objects.add_ref(cached.object_id)
302
320
  return self._encode_result(cached.value, object_id=cached.object_id)
303
- return self._cache.get(cache_key)
321
+ cached = self._cache.get(cache_key)
322
+ if cached is not None:
323
+ return cached
324
+ local_value = self._local_cache.get(cache_key)
325
+ if local_value is None:
326
+ return None
327
+ encoded = self._encode_result(local_value)
328
+ self._cache[cache_key] = encoded
329
+ return encoded
330
+
331
+ def invoke_local(
332
+ self,
333
+ endpoint: str,
334
+ *,
335
+ args: tuple[Any, ...] = (),
336
+ kwargs: dict[str, Any] | None = None,
337
+ ) -> Any:
338
+ definition = self.provider.get_endpoint(endpoint)
339
+ resolved_kwargs = kwargs or {}
340
+ cache_key = self._cache_key(endpoint, definition, args, resolved_kwargs)
341
+ if definition.managed:
342
+ if cache_key is not None:
343
+ cached = self._managed_objects.get_cached(cache_key)
344
+ if cached is not None:
345
+ return cached.value
346
+ value = definition.func(*args, **resolved_kwargs)
347
+ if cache_key is not None:
348
+ existing = self._managed_objects.get_cached(cache_key)
349
+ if existing is not None:
350
+ return existing.value
351
+ self._managed_objects.register(endpoint=endpoint, value=value, cache_key=cache_key)
352
+ return value
353
+ if cache_key is not None and cache_key in self._local_cache:
354
+ return self._local_cache[cache_key]
355
+ value = definition.func(*args, **resolved_kwargs)
356
+ if cache_key is not None:
357
+ self._local_cache[cache_key] = value
358
+ return value
304
359
 
305
360
  def _cache_key(
306
361
  self,
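`invoke_local` is the in-process execution path a server-mode provider now uses: managed endpoints go through the managed-object registry, unmanaged ones through a plain local cache keyed by `cache_format_key`. One possible use is warming the cache before entering the blocking accept loop; this reuses the README's server example and is a usage idea, not a documented requirement:

```python
# Sketch: pre-warming the server-side cache before serving clients.
from shared_tensor import SharedTensorProvider, SharedTensorServer

provider = SharedTensorProvider(execution_mode="server")

@provider.share(execution="task", managed=True, concurrency="serialized",
                cache_format_key="model:{hidden_size}")
def load_model(hidden_size: int = 4):
    ...

server = SharedTensorServer(provider)

# Executes the endpoint in this process and records the result in the managed-object
# registry, so the first client call for hidden_size=4 hits a warm cache.
server.invoke_local("load_model", kwargs={"hidden_size": 4})

server.start(blocking=True)
```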
@@ -426,7 +481,7 @@ class SharedTensorServer:
426
481
  uptime = 0.0 if self.started_at is None else time.time() - self.started_at
427
482
  return {
428
483
  "server": "SharedTensorServer",
429
- "version": "0.2.4",
484
+ "version": _server_version(),
430
485
  "socket_path": self.socket_path,
431
486
  "uptime": uptime,
432
487
  "running": self.running,
@@ -448,101 +503,65 @@ class SharedTensorServer:
448
503
  "data": None,
449
504
  }
450
505
 
451
- def _resolve_process_start_method(self) -> str:
452
- if self.process_start_method is not None:
453
- allowed = set(mp.get_all_start_methods())
454
- if self.process_start_method not in allowed:
455
- raise SharedTensorConfigurationError(
456
- f"Unsupported process_start_method '{self.process_start_method}'"
457
- )
458
- return self.process_start_method
459
- if os.name != "posix":
460
- return "spawn"
461
- try:
462
- import torch
463
- except ImportError:
464
- torch = None
465
- if torch is not None and torch.cuda.is_available() and torch.cuda.is_initialized():
466
- return "spawn"
467
- if not hasattr(sys.modules.get("__main__"), "__file__"):
468
- return "fork"
469
- return "spawn"
470
-
471
506
  def start(self, blocking: bool = True) -> None:
472
507
  if self.verbose_debug:
473
508
  logger.info("Server starting", extra={"socket_path": self.socket_path, "blocking": blocking})
474
- if self.running:
509
+ if self.running or self.server_thread is not None:
475
510
  raise SharedTensorConfigurationError("Server is already running")
476
511
  if blocking:
477
512
  self._resolved_process_start_method = None
478
513
  self._serve_forever()
479
514
  return
480
- if os.name != "posix":
515
+ if self.process_start_method is not None:
481
516
  raise SharedTensorConfigurationError(
482
- "Non-blocking shared_tensor servers require POSIX multiprocessing support"
517
+ "process_start_method is not supported for thread-backed non-blocking servers"
483
518
  )
484
- start_method = self._resolve_process_start_method()
485
- payload = cloudpickle.dumps(self.provider)
486
- process = mp.get_context(start_method).Process(
487
- target=self._serve_forever_from_payload,
488
- args=(
489
- payload,
490
- self.socket_path,
491
- self.max_request_bytes,
492
- self.max_workers,
493
- self.result_ttl,
494
- self.verbose_debug,
495
- start_method,
496
- ),
497
- name=f"shared-tensor-daemon:{self.socket_path}",
519
+ thread = threading.Thread(
520
+ target=self._serve_forever_in_thread,
521
+ name=f"shared-tensor-server:{self.socket_path}",
522
+ daemon=True,
498
523
  )
499
- process.start()
500
- if self.verbose_debug:
501
- logger.info(
502
- "Server spawned background process",
503
- extra={"socket_path": self.socket_path, "pid": process.pid, "start_method": start_method},
504
- )
505
- self.server_process = process
506
- self._resolved_process_start_method = start_method
507
- self.running = True
508
- self.started_at = time.time()
524
+ state = _ServerThreadState(thread=thread)
525
+ self.server_thread = state
526
+ self._resolved_process_start_method = "thread"
527
+ thread.start()
528
+ if not state.ready.wait(timeout=self.startup_timeout):
529
+ self.stop()
530
+ raise TimeoutError(f"Timed out waiting for server socket {self.socket_path}")
531
+ if state.error is not None:
532
+ error = state.error
533
+ self.stop()
534
+ raise SharedTensorConfigurationError(
535
+ f"Failed to start background server thread for {self.socket_path}: {error}"
536
+ ) from error
509
537
 
510
- @staticmethod
511
- def _serve_forever_from_payload(
512
- payload: bytes,
513
- socket_path: str,
514
- max_request_bytes: int,
515
- max_workers: int,
516
- result_ttl: float,
517
- verbose_debug: bool,
518
- process_start_method: str | None,
519
- ) -> None:
520
- SharedTensorServer._configure_cuda_runtime()
521
- provider = cloudpickle.loads(payload)
522
- server = SharedTensorServer(
523
- provider,
524
- socket_path=socket_path,
525
- max_request_bytes=max_request_bytes,
526
- max_workers=max_workers,
527
- result_ttl=result_ttl,
528
- process_start_method=process_start_method,
529
- verbose_debug=verbose_debug,
530
- )
531
- server._resolved_process_start_method = process_start_method
532
- server._serve_forever()
538
+ def _serve_forever_in_thread(self) -> None:
539
+ state = self.server_thread
540
+ if state is None:
541
+ return
542
+ try:
543
+ self._serve_forever(started_event=state.ready)
544
+ except BaseException as exc: # noqa: BLE001
545
+ state.error = exc
546
+ state.ready.set()
547
+ raise
548
+ finally:
549
+ state.stopped.set()
533
550
 
534
- def _serve_forever(self) -> None:
551
+ def _serve_forever(self, *, started_event: threading.Event | None = None) -> None:
535
552
  self._configure_cuda_runtime()
536
553
  unlink_socket_path(self.socket_path)
537
554
  listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
538
- listener.bind(self.socket_path)
539
- listener.listen()
540
- if self.verbose_debug:
541
- logger.info("Server listening", extra={"socket_path": self.socket_path})
542
- self.listener = listener
543
- self.running = True
544
- self.started_at = time.time()
545
555
  try:
556
+ listener.bind(self.socket_path)
557
+ listener.listen()
558
+ if self.verbose_debug:
559
+ logger.info("Server listening", extra={"socket_path": self.socket_path})
560
+ self.listener = listener
561
+ self.running = True
562
+ self.started_at = time.time()
563
+ if started_event is not None:
564
+ started_event.set()
546
565
  while self.running:
547
566
  try:
548
567
  conn, _ = listener.accept()
@@ -553,6 +572,8 @@ class SharedTensorServer:
553
572
  thread = threading.Thread(target=self._handle_connection, args=(conn,), daemon=True)
554
573
  thread.start()
555
574
  finally:
575
+ if started_event is not None and not started_event.is_set():
576
+ started_event.set()
556
577
  self._shutdown_local_resources()
557
578
 
558
579
  def _handle_connection(self, conn: socket.socket) -> None:
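Non-blocking `start()` now runs the accept loop in a daemon thread of the same process, waits up to `startup_timeout` for the ready event, and re-raises any startup failure; `process_start_method` is rejected in this mode. A single-process lifecycle sketch, with an invented socket path and assuming a host where the package's CUDA runtime setup succeeds:

```python
# Sketch: thread-backed non-blocking start and stop in one process, e.g. a smoke test.
import os

from shared_tensor import SharedTensorProvider, SharedTensorServer

provider = SharedTensorProvider(execution_mode="server")

@provider.share(execution="task", managed=True, concurrency="serialized",
                cache_format_key="model:{hidden_size}")
def load_model(hidden_size: int = 4):
    ...

server = SharedTensorServer(provider, socket_path="/tmp/shared-tensor-smoke.sock")
server.start(blocking=False)  # returns once the listener is bound, or raises on failure
try:
    assert os.path.exists("/tmp/shared-tensor-smoke.sock")
    # A client process could connect here; this sketch only verifies the lifecycle.
finally:
    server.stop()  # closes the listener and waits for the server thread to shut down
```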
@@ -586,24 +607,20 @@ class SharedTensorServer:
586
607
  def stop(self) -> None:
587
608
  if self.verbose_debug:
588
609
  logger.info("Server stopping", extra={"socket_path": self.socket_path})
589
- if not self.running:
590
- unlink_socket_path(self.socket_path)
591
- return
592
610
  self.running = False
593
- if self.server_process is not None:
594
- self.server_process.terminate()
595
- self.server_process.join(timeout=5)
596
- if self.server_process.is_alive():
597
- self.server_process.kill()
598
- self.server_process.join(timeout=5)
599
- self.server_process = None
600
- unlink_socket_path(self.socket_path)
601
- return
602
611
  if self.listener is not None:
603
612
  self.listener.close()
604
- self._shutdown_local_resources()
613
+ state = self.server_thread
614
+ if state is not None and state.thread.is_alive() and threading.current_thread() is not state.thread:
615
+ state.stopped.wait(timeout=5)
616
+ state.thread.join(timeout=5)
617
+ self.server_thread = None
618
+ self.server_process = None
619
+ if self.listener is None:
620
+ unlink_socket_path(self.socket_path)
605
621
 
606
622
  def _shutdown_local_resources(self) -> None:
623
+ self.running = False
607
624
  if self.listener is not None:
608
625
  self.listener.close()
609
626
  self.listener = None
@@ -612,6 +629,7 @@ class SharedTensorServer:
612
629
  self._task_manager = None
613
630
  self._managed_objects.clear()
614
631
  self._cache.clear()
632
+ self._local_cache.clear()
615
633
  self._inflight.clear()
616
634
  self._endpoint_locks.clear()
617
635
  unlink_socket_path(self.socket_path)