comfy-env 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,188 @@
+"""
+Tensor utilities for robust IPC handling.
+
+Patterns borrowed from pyisolate (MIT licensed):
+- TensorKeeper: Prevents GC race conditions
+- CUDA IPC re-share detection: Graceful handling of received tensors
+"""
+
+import collections
+import logging
+import threading
+import time
+from typing import Any
+
+logger = logging.getLogger("comfy_env")
+
+
+# ---------------------------------------------------------------------------
+# TensorKeeper - Prevents GC Race Conditions
+# ---------------------------------------------------------------------------
+
+class TensorKeeper:
+    """
+    Keeps strong references to tensors during IPC to prevent premature GC.
+
+    Problem this solves:
+        When a tensor is serialized for IPC, the serialization returns
+        immediately but the receiving process may not have opened the
+        shared memory yet. If the sending process's tensor gets garbage
+        collected, the shared memory file is deleted, causing
+        "No such file or directory" errors on the receiver.
+
+    Solution:
+        Keep strong references to tensors for a configurable window
+        (default 30 seconds) to ensure the receiver has time to open them.
+
+    Usage:
+        keeper = TensorKeeper()
+        keeper.keep(tensor)  # Call before putting on queue
+    """
+
+    def __init__(self, retention_seconds: float = 30.0):
+        """
+        Args:
+            retention_seconds: How long to keep tensor references.
+                30s is safe for slow systems.
+        """
+        self.retention_seconds = retention_seconds
+        self._keeper: collections.deque = collections.deque()
+        self._lock = threading.Lock()
+
+    def keep(self, t: Any) -> None:
+        """Keep a strong reference to tensor for retention_seconds."""
+        # Only keep torch tensors
+        try:
+            import torch
+            if not isinstance(t, torch.Tensor):
+                return
+        except ImportError:
+            return
+
+        now = time.time()
+        with self._lock:
+            self._keeper.append((now, t))
+
+            # Cleanup old entries
+            while self._keeper:
+                timestamp, _ = self._keeper[0]
+                if now - timestamp > self.retention_seconds:
+                    self._keeper.popleft()
+                else:
+                    break
+
+    def keep_recursive(self, obj: Any) -> None:
+        """Recursively keep all tensors in a nested structure."""
+        try:
+            import torch
+            if isinstance(obj, torch.Tensor):
+                self.keep(obj)
+            elif isinstance(obj, (list, tuple)):
+                for item in obj:
+                    self.keep_recursive(item)
+            elif isinstance(obj, dict):
+                for v in obj.values():
+                    self.keep_recursive(v)
+        except ImportError:
+            pass
+
+    def __len__(self) -> int:
+        """Return number of tensors currently being kept."""
+        with self._lock:
+            return len(self._keeper)
+
+
+# Global instance
+_tensor_keeper = TensorKeeper()
+
+
+def keep_tensor(t: Any) -> None:
+    """Keep a tensor reference to prevent GC during IPC."""
+    _tensor_keeper.keep(t)
+
+
+def keep_tensors_recursive(obj: Any) -> None:
+    """Keep all tensor references in a nested structure."""
+    _tensor_keeper.keep_recursive(obj)
+
+
+# ---------------------------------------------------------------------------
+# CUDA IPC Re-share Detection
+# ---------------------------------------------------------------------------
+
+def prepare_tensor_for_ipc(t: Any) -> Any:
+    """
+    Prepare a tensor for IPC, handling the CUDA IPC re-share limitation.
+
+    Problem this solves:
+        Tensors received via CUDA IPC cannot be re-shared. If a node
+        receives a tensor via IPC and tries to return it, you get:
+        "RuntimeError: Attempted to send CUDA tensor received from
+        another process; this is not currently supported."
+
+    Solution:
+        Detect this situation and clone the tensor. Log a warning for
+        large tensors so users can optimize their pipelines.
+
+    Args:
+        t: A tensor (or non-tensor, which is returned as-is)
+
+    Returns:
+        The tensor, possibly cloned if it was received via IPC.
+    """
+    try:
+        import torch
+        if not isinstance(t, torch.Tensor):
+            return t
+
+        if not t.is_cuda:
+            # CPU tensors don't have this limitation
+            return t
+
+        # Test if tensor can be shared
+        import torch.multiprocessing.reductions as reductions
+        try:
+            func, args = reductions.reduce_tensor(t)
+            return t  # Can be shared as-is
+        except RuntimeError as e:
+            if "received from another process" in str(e):
+                # This tensor was received via IPC and can't be re-shared
+                tensor_size_mb = t.numel() * t.element_size() / (1024 * 1024)
+                if tensor_size_mb > 100:
+                    logger.warning(
+                        f"PERFORMANCE: Cloning large CUDA tensor ({tensor_size_mb:.1f}MB) "
+                        "received from another process. Consider modifying the node "
+                        "to avoid returning unmodified input tensors."
+                    )
+                else:
+                    logger.debug(
+                        f"Cloning CUDA tensor ({tensor_size_mb:.2f}MB) received from another process"
+                    )
+                return t.clone()
+            raise
+
+    except ImportError:
+        return t
+
+
+def prepare_for_ipc_recursive(obj: Any) -> Any:
+    """
+    Recursively prepare all tensors in a nested structure for IPC.
+
+    Also keeps tensor references to prevent GC.
+    """
+    try:
+        import torch
+        if isinstance(obj, torch.Tensor):
+            prepared = prepare_tensor_for_ipc(obj)
+            keep_tensor(prepared)
+            return prepared
+        elif isinstance(obj, list):
+            return [prepare_for_ipc_recursive(x) for x in obj]
+        elif isinstance(obj, tuple):
+            return tuple(prepare_for_ipc_recursive(x) for x in obj)
+        elif isinstance(obj, dict):
+            return {k: prepare_for_ipc_recursive(v) for k, v in obj.items()}
+    except ImportError:
+        pass
+    return obj
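
Taken together, the helpers in the file above are meant to run on the sending side just before a result crosses the process boundary: clone any CUDA tensor that cannot be re-shared, then keep strong references so the receiver has time to map the shared memory. A minimal sketch of that pattern, assuming the module is importable as comfy_env.tensor_utils (the queue and payload names are hypothetical):

    from comfy_env.tensor_utils import prepare_for_ipc_recursive  # assumed module path


    def send_result(queue_out, result):
        # Clone CUDA tensors that were themselves received via IPC, and keep
        # strong references (~30 s) so the receiver can open the shared memory.
        safe = prepare_for_ipc_recursive(result)
        queue_out.put(("ok", safe))

keep_tensors_recursive alone is enough when no tensor in the payload originated from another process; prepare_for_ipc_recursive covers both concerns.
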
@@ -0,0 +1,375 @@
+"""
+TorchMPWorker - Same-venv isolation using torch.multiprocessing.
+
+This is the simplest and fastest worker type:
+- Uses torch.multiprocessing.Queue for IPC
+- Zero-copy tensor transfer via CUDA IPC (automatic)
+- Fresh CUDA context in subprocess
+- ~30ms overhead per call
+
+Use this when you need:
+- Memory isolation between nodes
+- Fresh CUDA context (automatic VRAM cleanup on worker death)
+- Same Python environment as host
+
+Example:
+    worker = TorchMPWorker()
+
+    def gpu_work(image):
+        import torch
+        return image * 2
+
+    result = worker.call(gpu_work, image=my_tensor)
+    worker.shutdown()
+"""
+
+import logging
+import traceback
+from queue import Empty as QueueEmpty
+from typing import Any, Callable, Optional
+
+from .base import Worker, WorkerError
+
+logger = logging.getLogger("comfy_env")
+
+
+# Sentinel value for shutdown
+_SHUTDOWN = object()
+
+# Message type for method calls (avoids pickling issues with functions)
+_CALL_METHOD = "call_method"
+
+
+def _worker_loop(queue_in, queue_out):
+    """
+    Worker process main loop.
+
+    Receives work items and executes them:
+    - ("call_method", module_name, class_name, method_name, self_state, kwargs): Call a method on a reconstructed class instance
+    - (func, args, kwargs): Execute a function directly
+    - _SHUTDOWN: Shut down the worker
+
+    Runs until receiving the _SHUTDOWN sentinel.
+    """
+    import importlib
+    import os
+    import sys
+
+    # Set worker mode env var
+    os.environ["COMFYUI_ISOLATION_WORKER"] = "1"
+
+    while True:
+        try:
+            item = queue_in.get()
+
+            # Check for shutdown signal
+            if item is _SHUTDOWN:
+                queue_out.put(("shutdown", None))
+                break
+
+            try:
+                # Handle method call protocol
+                if isinstance(item, tuple) and len(item) == 6 and item[0] == _CALL_METHOD:
+                    _, module_name, class_name, method_name, self_state, kwargs = item
+                    result = _execute_method_call(
+                        module_name, class_name, method_name, self_state, kwargs
+                    )
+                    queue_out.put(("ok", result))
+                else:
+                    # Direct function call (legacy)
+                    func, args, kwargs = item
+                    result = func(*args, **kwargs)
+                    queue_out.put(("ok", result))
+
+            except Exception as e:
+                tb = traceback.format_exc()
+                queue_out.put(("error", (str(e), tb)))
+
+        except Exception as e:
+            # Queue error - try to report, then exit
+            try:
+                queue_out.put(("fatal", str(e)))
+            except:
+                pass
+            break
+
+
+def _execute_method_call(module_name: str, class_name: str, method_name: str,
+                         self_state: dict, kwargs: dict) -> Any:
+    """
+    Execute a method call in the worker process.
+
+    This function imports the class fresh and calls the original (un-decorated) method.
+    """
+    import importlib
+
+    # Import the module
+    module = importlib.import_module(module_name)
+    cls = getattr(module, class_name)
+
+    # Create instance with proper __slots__ handling
+    instance = object.__new__(cls)
+
+    # Handle both __slots__ and __dict__ based classes
+    if hasattr(cls, '__slots__'):
+        # Class uses __slots__ - set attributes individually
+        for slot in cls.__slots__:
+            if slot in self_state:
+                setattr(instance, slot, self_state[slot])
+        # Also check for __dict__ slot (hybrid classes)
+        if '__dict__' in cls.__slots__ or hasattr(instance, '__dict__'):
+            for key, value in self_state.items():
+                if key not in cls.__slots__:
+                    setattr(instance, key, value)
+    else:
+        # Standard class with __dict__
+        instance.__dict__.update(self_state)
+
+    # Get the ORIGINAL method stored by the decorator, not the proxy
+    # This avoids the infinite recursion of proxy -> worker -> proxy
+    original_method = getattr(cls, '_isolated_original_method', None)
+    if original_method is None:
+        # Fallback: class wasn't decorated, use the method directly
+        original_method = getattr(cls, method_name)
+        return original_method(instance, **kwargs)
+
+    # Call the original method (it's an unbound function, pass instance)
+    return original_method(instance, **kwargs)
+
+
+class TorchMPWorker(Worker):
+    """
+    Worker using torch.multiprocessing for same-venv isolation.
+
+    Features:
+    - Zero-copy CUDA tensor transfer (via CUDA IPC handles)
+    - Zero-copy CPU tensor transfer (via shared memory)
+    - Fresh CUDA context (subprocess has independent GPU state)
+    - Automatic cleanup on worker death
+
+    The subprocess uses the 'spawn' start method, ensuring a clean Python
+    interpreter without inherited state from the parent.
+    """
+
+    def __init__(self, name: Optional[str] = None):
+        """
+        Initialize the worker.
+
+        Args:
+            name: Optional name for logging/debugging.
+        """
+        self.name = name or "TorchMPWorker"
+        self._process = None
+        self._queue_in = None
+        self._queue_out = None
+        self._started = False
+        self._shutdown = False
+
+    def _ensure_started(self):
+        """Lazily start the worker process on first call."""
+        if self._shutdown:
+            raise RuntimeError(f"{self.name}: Worker has been shut down")
+
+        if self._started:
+            if not self._process.is_alive():
+                raise RuntimeError(f"{self.name}: Worker process died unexpectedly")
+            return
+
+        # Import torch here to avoid import at module level
+        import torch.multiprocessing as mp
+
+        # Use spawn to get clean subprocess (no inherited CUDA context)
+        ctx = mp.get_context('spawn')
+
+        self._queue_in = ctx.Queue()
+        self._queue_out = ctx.Queue()
+        self._process = ctx.Process(
+            target=_worker_loop,
+            args=(self._queue_in, self._queue_out),
+            daemon=True,
+        )
+        self._process.start()
+        self._started = True
+
+    def call(
+        self,
+        func: Callable,
+        *args,
+        timeout: Optional[float] = None,
+        **kwargs
+    ) -> Any:
+        """
+        Execute a function in the worker process.
+
+        Args:
+            func: Function to execute. Must be picklable (module-level or staticmethod).
+            *args: Positional arguments.
+            timeout: Timeout in seconds (default None = no timeout).
+            **kwargs: Keyword arguments.
+
+        Returns:
+            Return value of func(*args, **kwargs).
+
+        Raises:
+            WorkerError: If func raises an exception.
+            TimeoutError: If execution exceeds timeout.
+            RuntimeError: If worker process dies.
+        """
+        self._ensure_started()
+
+        # Send work item
+        self._queue_in.put((func, args, kwargs))
+
+        return self._get_result(timeout)
+
+    def call_method(
+        self,
+        module_name: str,
+        class_name: str,
+        method_name: str,
+        self_state: dict,
+        kwargs: dict,
+        timeout: Optional[float] = None,
+    ) -> Any:
+        """
+        Execute a class method in the worker process.
+
+        This uses a string-based protocol to avoid pickle issues with decorated methods.
+        The worker imports the module fresh and calls the original (un-decorated) method.
+
+        Args:
+            module_name: Full module path (e.g., 'my_package.nodes.my_node')
+            class_name: Class name (e.g., 'MyNode')
+            method_name: Method name (e.g., 'process')
+            self_state: Instance __dict__ to restore
+            kwargs: Method keyword arguments
+            timeout: Timeout in seconds (default None = no timeout).
+
+        Returns:
+            Return value of the method.
+
+        Raises:
+            WorkerError: If the method raises an exception.
+            TimeoutError: If execution exceeds timeout.
+            RuntimeError: If worker process dies.
+        """
+        self._ensure_started()
+
+        # Send method call request using protocol
+        self._queue_in.put((
+            _CALL_METHOD,
+            module_name,
+            class_name,
+            method_name,
+            self_state,
+            kwargs,
+        ))
+
+        return self._get_result(timeout)
+
+    def _get_result(self, timeout: Optional[float]) -> Any:
+        """Wait for and return result from worker."""
+        try:
+            status, result = self._queue_out.get(timeout=timeout)
+        except QueueEmpty:
+            # Timeout - use graceful escalation
+            self._handle_timeout(timeout)
+            # _handle_timeout always raises, but just in case:
+            raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")
+        except Exception as e:
+            raise RuntimeError(f"{self.name}: Failed to get result: {e}")
+
+        # Handle response
+        if status == "ok":
+            return result
+        elif status == "error":
+            msg, tb = result
+            raise WorkerError(msg, traceback=tb)
+        elif status == "fatal":
+            self._shutdown = True
+            raise RuntimeError(f"{self.name}: Fatal worker error: {result}")
+        else:
+            raise RuntimeError(f"{self.name}: Unknown response status: {status}")
+
+    def shutdown(self) -> None:
+        """Shut down the worker process."""
+        if self._shutdown or not self._started:
+            return
+
+        self._shutdown = True
+
+        try:
+            # Send shutdown signal
+            self._queue_in.put(_SHUTDOWN)
+
+            # Wait for acknowledgment
+            try:
+                self._queue_out.get(timeout=5.0)
+            except:
+                pass
+
+            # Wait for process to exit
+            self._process.join(timeout=5.0)
+
+            if self._process.is_alive():
+                self._process.kill()
+                self._process.join(timeout=1.0)
+
+        except Exception:
+            # Force kill if anything goes wrong
+            if self._process and self._process.is_alive():
+                self._process.kill()
+
+    def _handle_timeout(self, timeout: float) -> None:
+        """
+        Handle timeout with graceful escalation.
+
+        Instead of immediately killing the worker (which can leak GPU memory),
+        try graceful shutdown first, then escalate to SIGTERM, then SIGKILL.
+
+        Inspired by pyisolate's timeout handling pattern.
+        """
+        logger.warning(f"{self.name}: Call timed out after {timeout}s, attempting graceful shutdown")
+
+        # Stage 1: Send shutdown signal, wait 3s for graceful exit
+        try:
+            self._queue_in.put(_SHUTDOWN)
+            self._queue_out.get(timeout=3.0)
+            self._process.join(timeout=2.0)
+            if not self._process.is_alive():
+                self._shutdown = True
+                raise TimeoutError(f"{self.name}: Graceful shutdown after timeout ({timeout}s)")
+        except QueueEmpty:
+            pass
+        except TimeoutError:
+            raise
+        except Exception:
+            pass
+
+        # Stage 2: SIGTERM, wait 5s
+        if self._process.is_alive():
+            logger.warning(f"{self.name}: Graceful shutdown failed, sending SIGTERM")
+            self._process.terminate()
+            self._process.join(timeout=5.0)
+
+        # Stage 3: SIGKILL as last resort
+        if self._process.is_alive():
+            logger.error(f"{self.name}: SIGTERM failed, force killing worker (may leak GPU memory)")
+            self._process.kill()
+            self._process.join(timeout=1.0)
+
+        self._shutdown = True
+        raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")
+
+    def is_alive(self) -> bool:
+        """Check if worker process is running or can be started."""
+        if self._shutdown:
+            return False
+        # Not started yet = can still be started = "alive"
+        if not self._started:
+            return True
+        return self._process.is_alive()
+
+    def __repr__(self):
+        status = "alive" if self.is_alive() else "stopped"
+        return f"<TorchMPWorker name={self.name!r} status={status}>"