comfy_env-0.0.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfy_env/__init__.py +161 -0
- comfy_env/cli.py +388 -0
- comfy_env/decorator.py +422 -0
- comfy_env/env/__init__.py +30 -0
- comfy_env/env/config.py +144 -0
- comfy_env/env/config_file.py +592 -0
- comfy_env/env/detection.py +176 -0
- comfy_env/env/manager.py +657 -0
- comfy_env/env/platform/__init__.py +21 -0
- comfy_env/env/platform/base.py +96 -0
- comfy_env/env/platform/darwin.py +53 -0
- comfy_env/env/platform/linux.py +68 -0
- comfy_env/env/platform/windows.py +377 -0
- comfy_env/env/security.py +267 -0
- comfy_env/errors.py +325 -0
- comfy_env/install.py +539 -0
- comfy_env/ipc/__init__.py +55 -0
- comfy_env/ipc/bridge.py +512 -0
- comfy_env/ipc/protocol.py +265 -0
- comfy_env/ipc/tensor.py +371 -0
- comfy_env/ipc/torch_bridge.py +401 -0
- comfy_env/ipc/transport.py +318 -0
- comfy_env/ipc/worker.py +221 -0
- comfy_env/registry.py +252 -0
- comfy_env/resolver.py +399 -0
- comfy_env/runner.py +273 -0
- comfy_env/stubs/__init__.py +1 -0
- comfy_env/stubs/folder_paths.py +57 -0
- comfy_env/workers/__init__.py +49 -0
- comfy_env/workers/base.py +82 -0
- comfy_env/workers/pool.py +241 -0
- comfy_env/workers/tensor_utils.py +188 -0
- comfy_env/workers/torch_mp.py +375 -0
- comfy_env/workers/venv.py +903 -0
- comfy_env-0.0.8.dist-info/METADATA +228 -0
- comfy_env-0.0.8.dist-info/RECORD +39 -0
- comfy_env-0.0.8.dist-info/WHEEL +4 -0
- comfy_env-0.0.8.dist-info/entry_points.txt +2 -0
- comfy_env-0.0.8.dist-info/licenses/LICENSE +21 -0
comfy_env/workers/tensor_utils.py
@@ -0,0 +1,188 @@
"""
Tensor utilities for robust IPC handling.

Patterns borrowed from pyisolate (MIT licensed):
- TensorKeeper: Prevents GC race conditions
- CUDA IPC re-share detection: Graceful handling of received tensors
"""

import collections
import logging
import threading
import time
from typing import Any

logger = logging.getLogger("comfy_env")


# ---------------------------------------------------------------------------
# TensorKeeper - Prevents GC Race Conditions
# ---------------------------------------------------------------------------

class TensorKeeper:
    """
    Keeps strong references to tensors during IPC to prevent premature GC.

    Problem this solves:
        When a tensor is serialized for IPC, the serialization returns
        immediately but the receiving process may not have opened the
        shared memory yet. If the sending process's tensor gets garbage
        collected, the shared memory file is deleted, causing
        "No such file or directory" errors on the receiver.

    Solution:
        Keep strong references to tensors for a configurable window
        (default 30 seconds) to ensure the receiver has time to open them.

    Usage:
        keeper = TensorKeeper()
        keeper.keep(tensor)  # Call before putting on queue
    """

    def __init__(self, retention_seconds: float = 30.0):
        """
        Args:
            retention_seconds: How long to keep tensor references.
                30s is safe for slow systems.
        """
        self.retention_seconds = retention_seconds
        self._keeper: collections.deque = collections.deque()
        self._lock = threading.Lock()

    def keep(self, t: Any) -> None:
        """Keep a strong reference to tensor for retention_seconds."""
        # Only keep torch tensors
        try:
            import torch
            if not isinstance(t, torch.Tensor):
                return
        except ImportError:
            return

        now = time.time()
        with self._lock:
            self._keeper.append((now, t))

            # Cleanup old entries
            while self._keeper:
                timestamp, _ = self._keeper[0]
                if now - timestamp > self.retention_seconds:
                    self._keeper.popleft()
                else:
                    break

    def keep_recursive(self, obj: Any) -> None:
        """Recursively keep all tensors in a nested structure."""
        try:
            import torch
            if isinstance(obj, torch.Tensor):
                self.keep(obj)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    self.keep_recursive(item)
            elif isinstance(obj, dict):
                for v in obj.values():
                    self.keep_recursive(v)
        except ImportError:
            pass

    def __len__(self) -> int:
        """Return number of tensors currently being kept."""
        with self._lock:
            return len(self._keeper)


# Global instance
_tensor_keeper = TensorKeeper()


def keep_tensor(t: Any) -> None:
    """Keep a tensor reference to prevent GC during IPC."""
    _tensor_keeper.keep(t)


def keep_tensors_recursive(obj: Any) -> None:
    """Keep all tensor references in a nested structure."""
    _tensor_keeper.keep_recursive(obj)

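# Illustrative sketch (editorial, not in the published file): how a sender
# might use the module-level helpers above before handing tensors to an IPC
# queue. The names `outputs` and `queue_out` are hypothetical.
#
#     outputs = {"image": image_tensor, "mask": mask_tensor}
#     keep_tensors_recursive(outputs)  # hold strong refs (~30s) until the receiver opens them
#     queue_out.put(outputs)
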
# ---------------------------------------------------------------------------
# CUDA IPC Re-share Detection
# ---------------------------------------------------------------------------

def prepare_tensor_for_ipc(t: Any) -> Any:
    """
    Prepare a tensor for IPC, handling CUDA IPC re-share limitation.

    Problem this solves:
        Tensors received via CUDA IPC cannot be re-shared. If a node
        receives a tensor via IPC and tries to return it, you get:
        "RuntimeError: Attempted to send CUDA tensor received from
        another process; this is not currently supported."

    Solution:
        Detect this situation and clone the tensor. Log a warning for
        large tensors so users can optimize their pipelines.

    Args:
        t: A tensor (or non-tensor, which is returned as-is)

    Returns:
        The tensor, possibly cloned if it was received via IPC.
    """
    try:
        import torch
        if not isinstance(t, torch.Tensor):
            return t

        if not t.is_cuda:
            # CPU tensors don't have this limitation
            return t

        # Test if tensor can be shared
        import torch.multiprocessing.reductions as reductions
        try:
            func, args = reductions.reduce_tensor(t)
            return t  # Can be shared as-is
        except RuntimeError as e:
            if "received from another process" in str(e):
                # This tensor was received via IPC and can't be re-shared
                tensor_size_mb = t.numel() * t.element_size() / (1024 * 1024)
                if tensor_size_mb > 100:
                    logger.warning(
                        f"PERFORMANCE: Cloning large CUDA tensor ({tensor_size_mb:.1f}MB) "
                        "received from another process. Consider modifying the node "
                        "to avoid returning unmodified input tensors."
                    )
                else:
                    logger.debug(
                        f"Cloning CUDA tensor ({tensor_size_mb:.2f}MB) received from another process"
                    )
                return t.clone()
            raise

    except ImportError:
        return t


def prepare_for_ipc_recursive(obj: Any) -> Any:
    """
    Recursively prepare all tensors in a nested structure for IPC.

    Also keeps tensor references to prevent GC.
    """
    try:
        import torch
        if isinstance(obj, torch.Tensor):
            prepared = prepare_tensor_for_ipc(obj)
            keep_tensor(prepared)
            return prepared
        elif isinstance(obj, list):
            return [prepare_for_ipc_recursive(x) for x in obj]
        elif isinstance(obj, tuple):
            return tuple(prepare_for_ipc_recursive(x) for x in obj)
        elif isinstance(obj, dict):
            return {k: prepare_for_ipc_recursive(v) for k, v in obj.items()}
    except ImportError:
        pass
    return obj
comfy_env/workers/torch_mp.py
@@ -0,0 +1,375 @@
"""
TorchMPWorker - Same-venv isolation using torch.multiprocessing.

This is the simplest and fastest worker type:
- Uses torch.multiprocessing.Queue for IPC
- Zero-copy tensor transfer via CUDA IPC (automatic)
- Fresh CUDA context in subprocess
- ~30ms overhead per call

Use this when you need:
- Memory isolation between nodes
- Fresh CUDA context (automatic VRAM cleanup on worker death)
- Same Python environment as host

Example:
    worker = TorchMPWorker()

    def gpu_work(image):
        import torch
        return image * 2

    result = worker.call(gpu_work, image=my_tensor)
    worker.shutdown()
"""

import logging
import traceback
from queue import Empty as QueueEmpty
from typing import Any, Callable, Optional

from .base import Worker, WorkerError

logger = logging.getLogger("comfy_env")


# Sentinel value for shutdown
_SHUTDOWN = object()

# Message type for method calls (avoids pickling issues with functions)
_CALL_METHOD = "call_method"


def _worker_loop(queue_in, queue_out):
    """
    Worker process main loop.

    Receives work items and executes them:
    - ("call_method", module_name, class_name, method_name, self_state, kwargs): Call a method on a class
    - (func, args, kwargs): Execute a function directly
    - _SHUTDOWN: Shutdown the worker

    Runs until receiving _SHUTDOWN sentinel.
    """
    import importlib
    import os
    import sys

    # Set worker mode env var
    os.environ["COMFYUI_ISOLATION_WORKER"] = "1"

    while True:
        try:
            item = queue_in.get()

            # Check for shutdown signal
            if item is _SHUTDOWN:
                queue_out.put(("shutdown", None))
                break

            try:
                # Handle method call protocol
                if isinstance(item, tuple) and len(item) == 6 and item[0] == _CALL_METHOD:
                    _, module_name, class_name, method_name, self_state, kwargs = item
                    result = _execute_method_call(
                        module_name, class_name, method_name, self_state, kwargs
                    )
                    queue_out.put(("ok", result))
                else:
                    # Direct function call (legacy)
                    func, args, kwargs = item
                    result = func(*args, **kwargs)
                    queue_out.put(("ok", result))

            except Exception as e:
                tb = traceback.format_exc()
                queue_out.put(("error", (str(e), tb)))

        except Exception as e:
            # Queue error - try to report, then exit
            try:
                queue_out.put(("fatal", str(e)))
            except:
                pass
            break


def _execute_method_call(module_name: str, class_name: str, method_name: str,
                         self_state: dict, kwargs: dict) -> Any:
    """
    Execute a method call in the worker process.

    This function imports the class fresh and calls the original (un-decorated) method.
    """
    import importlib

    # Import the module
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)

    # Create instance with proper __slots__ handling
    instance = object.__new__(cls)

    # Handle both __slots__ and __dict__ based classes
    if hasattr(cls, '__slots__'):
        # Class uses __slots__ - set attributes individually
        for slot in cls.__slots__:
            if slot in self_state:
                setattr(instance, slot, self_state[slot])
        # Also check for __dict__ slot (hybrid classes)
        if '__dict__' in cls.__slots__ or hasattr(instance, '__dict__'):
            for key, value in self_state.items():
                if key not in cls.__slots__:
                    setattr(instance, key, value)
    else:
        # Standard class with __dict__
        instance.__dict__.update(self_state)

    # Get the ORIGINAL method stored by the decorator, not the proxy
    # This avoids the infinite recursion of proxy -> worker -> proxy
    original_method = getattr(cls, '_isolated_original_method', None)
    if original_method is None:
        # Fallback: class wasn't decorated, use the method directly
        original_method = getattr(cls, method_name)
        return original_method(instance, **kwargs)

    # Call the original method (it's an unbound function, pass instance)
    return original_method(instance, **kwargs)

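# Illustrative sketch (editorial, not in the published file): the request and
# response shapes exchanged over the two queues by _worker_loop above, assuming
# a hypothetical node class my_pkg.nodes.MyNode with a process() method.
#
#     queue_in.put((_CALL_METHOD, "my_pkg.nodes", "MyNode", "process",
#                   {"strength": 0.5},          # instance state to restore
#                   {"image": image_tensor}))   # method kwargs
#     status, payload = queue_out.get()
#     # status is one of "ok", "error", "fatal", or "shutdown"
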
class TorchMPWorker(Worker):
    """
    Worker using torch.multiprocessing for same-venv isolation.

    Features:
    - Zero-copy CUDA tensor transfer (via CUDA IPC handles)
    - Zero-copy CPU tensor transfer (via shared memory)
    - Fresh CUDA context (subprocess has independent GPU state)
    - Automatic cleanup on worker death

    The subprocess uses 'spawn' start method, ensuring a clean Python
    interpreter without inherited state from the parent.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Initialize the worker.

        Args:
            name: Optional name for logging/debugging.
        """
        self.name = name or "TorchMPWorker"
        self._process = None
        self._queue_in = None
        self._queue_out = None
        self._started = False
        self._shutdown = False

    def _ensure_started(self):
        """Lazily start the worker process on first call."""
        if self._shutdown:
            raise RuntimeError(f"{self.name}: Worker has been shut down")

        if self._started:
            if not self._process.is_alive():
                raise RuntimeError(f"{self.name}: Worker process died unexpectedly")
            return

        # Import torch here to avoid import at module level
        import torch.multiprocessing as mp

        # Use spawn to get clean subprocess (no inherited CUDA context)
        ctx = mp.get_context('spawn')

        self._queue_in = ctx.Queue()
        self._queue_out = ctx.Queue()
        self._process = ctx.Process(
            target=_worker_loop,
            args=(self._queue_in, self._queue_out),
            daemon=True,
        )
        self._process.start()
        self._started = True

    def call(
        self,
        func: Callable,
        *args,
        timeout: Optional[float] = None,
        **kwargs
    ) -> Any:
        """
        Execute a function in the worker process.

        Args:
            func: Function to execute. Must be picklable (module-level or staticmethod).
            *args: Positional arguments.
            timeout: Timeout in seconds (None = no timeout, default).
            **kwargs: Keyword arguments.

        Returns:
            Return value of func(*args, **kwargs).

        Raises:
            WorkerError: If func raises an exception.
            TimeoutError: If execution exceeds timeout.
            RuntimeError: If worker process dies.
        """
        self._ensure_started()

        # Send work item
        self._queue_in.put((func, args, kwargs))

        return self._get_result(timeout)

    def call_method(
        self,
        module_name: str,
        class_name: str,
        method_name: str,
        self_state: dict,
        kwargs: dict,
        timeout: Optional[float] = None,
    ) -> Any:
        """
        Execute a class method in the worker process.

        This uses a string-based protocol to avoid pickle issues with decorated methods.
        The worker imports the module fresh and calls the original (un-decorated) method.

        Args:
            module_name: Full module path (e.g., 'my_package.nodes.my_node')
            class_name: Class name (e.g., 'MyNode')
            method_name: Method name (e.g., 'process')
            self_state: Instance __dict__ to restore
            kwargs: Method keyword arguments
            timeout: Timeout in seconds (None = no timeout, default).

        Returns:
            Return value of method.

        Raises:
            WorkerError: If method raises an exception.
            TimeoutError: If execution exceeds timeout.
            RuntimeError: If worker process dies.
        """
        self._ensure_started()

        # Send method call request using protocol
        self._queue_in.put((
            _CALL_METHOD,
            module_name,
            class_name,
            method_name,
            self_state,
            kwargs,
        ))

        return self._get_result(timeout)

    def _get_result(self, timeout: Optional[float]) -> Any:
        """Wait for and return result from worker."""
        try:
            status, result = self._queue_out.get(timeout=timeout)
        except QueueEmpty:
            # Timeout - use graceful escalation
            self._handle_timeout(timeout)
            # _handle_timeout always raises, but just in case:
            raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")
        except Exception as e:
            raise RuntimeError(f"{self.name}: Failed to get result: {e}")

        # Handle response
        if status == "ok":
            return result
        elif status == "error":
            msg, tb = result
            raise WorkerError(msg, traceback=tb)
        elif status == "fatal":
            self._shutdown = True
            raise RuntimeError(f"{self.name}: Fatal worker error: {result}")
        else:
            raise RuntimeError(f"{self.name}: Unknown response status: {status}")

    def shutdown(self) -> None:
        """Shut down the worker process."""
        if self._shutdown or not self._started:
            return

        self._shutdown = True

        try:
            # Send shutdown signal
            self._queue_in.put(_SHUTDOWN)

            # Wait for acknowledgment
            try:
                self._queue_out.get(timeout=5.0)
            except:
                pass

            # Wait for process to exit
            self._process.join(timeout=5.0)

            if self._process.is_alive():
                self._process.kill()
                self._process.join(timeout=1.0)

        except Exception:
            # Force kill if anything goes wrong
            if self._process and self._process.is_alive():
                self._process.kill()

    def _handle_timeout(self, timeout: float) -> None:
        """
        Handle timeout with graceful escalation.

        Instead of immediately killing the worker (which can leak GPU memory),
        try graceful shutdown first, then escalate to SIGTERM, then SIGKILL.

        Inspired by pyisolate's timeout handling pattern.
        """
        logger.warning(f"{self.name}: Call timed out after {timeout}s, attempting graceful shutdown")

        # Stage 1: Send shutdown signal, wait 3s for graceful exit
        try:
            self._queue_in.put(_SHUTDOWN)
            self._queue_out.get(timeout=3.0)
            self._process.join(timeout=2.0)
            if not self._process.is_alive():
                self._shutdown = True
                raise TimeoutError(f"{self.name}: Graceful shutdown after timeout ({timeout}s)")
        except QueueEmpty:
            pass
        except TimeoutError:
            raise
        except Exception:
            pass

        # Stage 2: SIGTERM, wait 5s
        if self._process.is_alive():
            logger.warning(f"{self.name}: Graceful shutdown failed, sending SIGTERM")
            self._process.terminate()
            self._process.join(timeout=5.0)

        # Stage 3: SIGKILL as last resort
        if self._process.is_alive():
            logger.error(f"{self.name}: SIGTERM failed, force killing worker (may leak GPU memory)")
            self._process.kill()
            self._process.join(timeout=1.0)

        self._shutdown = True
        raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")

    def is_alive(self) -> bool:
        """Check if worker process is running or can be started."""
        if self._shutdown:
            return False
        # Not started yet = can still be started = "alive"
        if not self._started:
            return True
        return self._process.is_alive()

    def __repr__(self):
        status = "alive" if self.is_alive() else "stopped"
        return f"<TorchMPWorker name={self.name!r} status={status}>"
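
A short usage sketch combining both call paths, based on the module docstring and method signatures above; `my_module.gpu_work` and the `MyNode` names are hypothetical:

    from comfy_env.workers.torch_mp import TorchMPWorker

    worker = TorchMPWorker(name="demo")

    # Direct call: func must be picklable, i.e. defined at module level
    result = worker.call(my_module.gpu_work, image=my_tensor, timeout=120.0)

    # Method-call protocol: the worker re-imports the class and restores its state
    result = worker.call_method(
        module_name="my_pkg.nodes.my_node",
        class_name="MyNode",
        method_name="process",
        self_state={"strength": 0.5},
        kwargs={"image": my_tensor},
        timeout=120.0,
    )

    worker.shutdown()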