comfy-env 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,903 @@
1
+ """
2
+ VenvWorker - Cross-venv isolation using subprocess + shared memory.
3
+
4
+ This worker supports calling functions in a different Python environment:
5
+ - Uses subprocess.Popen to run in different venv
6
+ - Transfers tensors via torch.save/load through /dev/shm (RAM-backed)
7
+ - One memcpy per tensor per direction
8
+ - ~100-500ms overhead per call (subprocess spawn + tensor I/O)
9
+
10
+ Use this when you need:
11
+ - Different PyTorch version
12
+ - Incompatible native library dependencies
13
+ - Different Python version
14
+
15
+ Example:
16
+ worker = VenvWorker(
17
+ python="/path/to/other/venv/bin/python",
18
+ working_dir="/path/to/code",
19
+ )
20
+
21
+ # Call a function by module path
22
+ result = worker.call_module(
23
+ module="my_module",
24
+ func="my_function",
25
+ image=my_tensor,
26
+ )
27
+ """
28
+
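As a standalone illustration of the transfer mechanism described in the docstring above (not part of this module; the file names and the doubling function are made up), the round trip looks roughly like this:

import os
import subprocess
import sys
import tempfile
from pathlib import Path

import torch

# RAM-backed path on Linux, plain temp dir elsewhere
shm = Path("/dev/shm") if os.path.isdir("/dev/shm") else Path(tempfile.gettempdir())
inputs_path = shm / "demo_in.pt"
outputs_path = shm / "demo_out.pt"

# Host side: one memcpy of the tensor into the shared file
torch.save({"image": torch.rand(3, 64, 64)}, inputs_path)

# Child side: would normally be the other venv's python; sys.executable is a stand-in
child = (
    "import sys, torch\n"
    "inp = torch.load(sys.argv[1], weights_only=False)\n"
    "torch.save(inp['image'] * 2, sys.argv[2])\n"
)
subprocess.run([sys.executable, "-c", child, str(inputs_path), str(outputs_path)], check=True)

# Host side: one memcpy back, then clean up the shared files
result = torch.load(outputs_path, weights_only=False)
inputs_path.unlink()
outputs_path.unlink()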
29
+ import json
30
+ import os
31
+ import shutil
32
+ import subprocess
33
+ import sys
34
+ import tempfile
35
+ import threading
36
+ import time
37
+ import uuid
38
+ from pathlib import Path
39
+ from typing import Any, Callable, Dict, List, Optional, Union
40
+
41
+ from .base import Worker, WorkerError
42
+
43
+
44
+ def _serialize_for_ipc(obj, visited=None):
45
+ """
46
+ Convert objects with broken __module__ paths to dicts for IPC.
47
+
48
+ ComfyUI sets non-importable __module__ values (file paths) on custom node classes,
49
+ which breaks pickle deserialization in the worker. This converts such
50
+ objects to a serializable dict format.
51
+ """
52
+ if visited is None:
53
+ visited = {} # Maps id -> serialized result
54
+
55
+ obj_id = id(obj)
56
+ if obj_id in visited:
57
+ return visited[obj_id] # Return cached serialized result
58
+
59
+ # Check if this is a custom object with broken module path
60
+ if (hasattr(obj, '__dict__') and
61
+ hasattr(obj, '__class__') and
62
+ not isinstance(obj, (dict, list, tuple, type)) and
63
+ obj.__class__.__name__ not in ('Tensor', 'ndarray', 'module')):
64
+
65
+ cls = obj.__class__
66
+ module = getattr(cls, '__module__', '')
67
+
68
+ # Check if module looks like a file path or is problematic for pickling
69
+ # This catches: file paths, custom_nodes imports, and modules starting with /
70
+ is_problematic = (
71
+ '/' in module or
72
+ '\\' in module or
73
+ module.startswith('/') or
74
+ 'custom_nodes' in module or
75
+ module == '' or
76
+ module == '__main__'
77
+ )
78
+ if is_problematic:
79
+ # Convert to serializable dict and cache it
80
+ result = {
81
+ '__isolated_object__': True,
82
+ '__class_name__': cls.__name__,
83
+ '__attrs__': {k: _serialize_for_ipc(v, visited) for k, v in obj.__dict__.items()},
84
+ }
85
+ visited[obj_id] = result
86
+ return result
87
+
88
+ # Recurse into containers
89
+ if isinstance(obj, dict):
90
+ result = {k: _serialize_for_ipc(v, visited) for k, v in obj.items()}
91
+ visited[obj_id] = result
92
+ return result
93
+ elif isinstance(obj, list):
94
+ result = [_serialize_for_ipc(v, visited) for v in obj]
95
+ visited[obj_id] = result
96
+ return result
97
+ elif isinstance(obj, tuple):
98
+ result = tuple(_serialize_for_ipc(v, visited) for v in obj)
99
+ visited[obj_id] = result
100
+ return result
101
+
102
+ # Primitives and other objects - cache and return as-is
103
+ visited[obj_id] = obj
104
+ return obj
105
+
106
+
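As a quick sketch of the dict format this helper produces, consider a made-up object whose class reports a file path as its __module__ (FakeNode and its attributes are purely illustrative):

class FakeNode:
    pass

# Mimic the path-like __module__ that ComfyUI gives custom node classes
FakeNode.__module__ = "/opt/ComfyUI/custom_nodes/fake_node.py"

node = FakeNode()
node.strength = 0.5
node.labels = ["a", "b"]

print(_serialize_for_ipc(node))
# {'__isolated_object__': True,
#  '__class_name__': 'FakeNode',
#  '__attrs__': {'strength': 0.5, 'labels': ['a', 'b']}}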
107
+ # Worker script template - minimal, runs in target venv
108
+ _WORKER_SCRIPT = '''
109
+ import sys
110
+ import json
111
+ import traceback
112
+ from types import SimpleNamespace
113
+
114
+ def _deserialize_isolated_objects(obj):
115
+ """Reconstruct objects serialized with __isolated_object__ marker."""
116
+ if isinstance(obj, dict):
117
+ if obj.get("__isolated_object__"):
118
+ # Reconstruct as SimpleNamespace (supports .attr access)
119
+ attrs = {k: _deserialize_isolated_objects(v) for k, v in obj.get("__attrs__", {}).items()}
120
+ ns = SimpleNamespace(**attrs)
121
+ ns.__class_name__ = obj.get("__class_name__", "Unknown")
122
+ return ns
123
+ return {k: _deserialize_isolated_objects(v) for k, v in obj.items()}
124
+ elif isinstance(obj, list):
125
+ return [_deserialize_isolated_objects(v) for v in obj]
126
+ elif isinstance(obj, tuple):
127
+ return tuple(_deserialize_isolated_objects(v) for v in obj)
128
+ return obj
129
+
130
+ def main():
131
+ # Read request from file
132
+ request_path = sys.argv[1]
133
+ response_path = sys.argv[2]
134
+
135
+ with open(request_path, 'r') as f:
136
+ request = json.load(f)
137
+
138
+ try:
139
+ # Setup paths
140
+ for p in request.get("sys_path", []):
141
+ if p not in sys.path:
142
+ sys.path.insert(0, p)
143
+
144
+ # Import torch for tensor I/O
145
+ import torch
146
+
147
+ # Load inputs
148
+ inputs_path = request.get("inputs_path")
149
+ if inputs_path:
150
+ inputs = torch.load(inputs_path, weights_only=False)
151
+ inputs = _deserialize_isolated_objects(inputs)
152
+ else:
153
+ inputs = {}
154
+
155
+ # Import and call function
156
+ module_name = request["module"]
157
+ func_name = request["func"]
158
+
159
+ module = __import__(module_name, fromlist=[func_name])
160
+ func = getattr(module, func_name)
161
+
162
+ result = func(**inputs)
163
+
164
+ # Save outputs
165
+ outputs_path = request.get("outputs_path")
166
+ if outputs_path:
167
+ torch.save(result, outputs_path)
168
+
169
+ response = {"status": "ok"}
170
+
171
+ except Exception as e:
172
+ response = {
173
+ "status": "error",
174
+ "error": str(e),
175
+ "traceback": traceback.format_exc(),
176
+ }
177
+
178
+ with open(response_path, 'w') as f:
179
+ json.dump(response, f)
180
+
181
+ if __name__ == "__main__":
182
+ main()
183
+ '''
184
+
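For reference, the request file that call_module() (defined below) writes for this one-shot script, and the response the script writes back, have roughly this shape (paths and the call id are illustrative):

request = {
    "module": "my_module",
    "func": "my_function",
    "sys_path": ["/path/to/code"],
    "inputs_path": "/dev/shm/comfyui_venv_1a2b3c4d_in.pt",    # torch.save()d kwargs, or None
    "outputs_path": "/dev/shm/comfyui_venv_1a2b3c4d_out.pt",  # torch.save()d return value
}

response_ok = {"status": "ok"}
response_err = {"status": "error", "error": "...", "traceback": "..."}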
185
+
186
+ def _get_shm_dir() -> Path:
187
+ """Get shared memory directory for efficient tensor transfer."""
188
+ # Linux: /dev/shm is RAM-backed tmpfs
189
+ if sys.platform == 'linux' and os.path.isdir('/dev/shm'):
190
+ return Path('/dev/shm')
191
+ # Fallback to regular temp
192
+ return Path(tempfile.gettempdir())
193
+
194
+
195
+ class VenvWorker(Worker):
196
+ """
197
+ Worker using subprocess for cross-venv isolation.
198
+
199
+ This worker spawns a new Python process for each call, using
200
+ a different Python interpreter (from another venv). Tensors are
201
+ transferred via torch.save/load through shared memory.
202
+
203
+ For long-running workloads, consider PersistentVenvWorker, which
204
+ keeps the subprocess alive between calls.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ python: Union[str, Path],
210
+ working_dir: Optional[Union[str, Path]] = None,
211
+ sys_path: Optional[List[str]] = None,
212
+ env: Optional[Dict[str, str]] = None,
213
+ name: Optional[str] = None,
214
+ persistent: bool = True,
215
+ ):
216
+ """
217
+ Initialize the worker.
218
+
219
+ Args:
220
+ python: Path to Python executable in target venv.
221
+ working_dir: Working directory for subprocess.
222
+ sys_path: Additional paths to add to sys.path in subprocess.
223
+ env: Additional environment variables.
224
+ name: Optional name for logging.
225
+ persistent: If True, keep subprocess alive between calls (faster).
226
+ """
227
+ self.python = Path(python)
228
+ self.working_dir = Path(working_dir) if working_dir else Path.cwd()
229
+ self.sys_path = sys_path or []
230
+ self.extra_env = env or {}
231
+ self.name = name or f"VenvWorker({self.python.parent.parent.name})"
232
+ self.persistent = persistent
233
+
234
+ # Verify Python exists
235
+ if not self.python.exists():
236
+ raise FileNotFoundError(f"Python not found: {self.python}")
237
+
238
+ # Create temp directory for IPC files
239
+ self._temp_dir = Path(tempfile.mkdtemp(prefix='comfyui_venv_'))
240
+ self._shm_dir = _get_shm_dir()
241
+
242
+ # Persistent process state
243
+ self._process: Optional[subprocess.Popen] = None
244
+ self._shutdown = False
245
+
246
+ # Write worker script
247
+ self._worker_script = self._temp_dir / "worker.py"
248
+ self._worker_script.write_text(_WORKER_SCRIPT)
249
+
250
+ def call(
251
+ self,
252
+ func: Callable,
253
+ *args,
254
+ timeout: Optional[float] = None,
255
+ **kwargs
256
+ ) -> Any:
257
+ """
258
+ Execute a function - NOT SUPPORTED for VenvWorker.
259
+
260
+ VenvWorker cannot pickle arbitrary functions across venv boundaries.
261
+ Use call_module() instead to call functions by module path.
262
+
263
+ Raises:
264
+ NotImplementedError: Always.
265
+ """
266
+ raise NotImplementedError(
267
+ f"{self.name}: VenvWorker cannot call arbitrary functions. "
268
+ f"Use call_module(module='...', func='...', **kwargs) instead."
269
+ )
270
+
271
+ def call_module(
272
+ self,
273
+ module: str,
274
+ func: str,
275
+ timeout: Optional[float] = None,
276
+ **kwargs
277
+ ) -> Any:
278
+ """
279
+ Call a function by module path in the isolated venv.
280
+
281
+ Args:
282
+ module: Module name (e.g., "my_package.my_module").
283
+ func: Function name within the module.
284
+ timeout: Timeout in seconds (None = 600s default).
285
+ **kwargs: Keyword arguments passed to the function.
286
+ Must be torch.save-compatible (tensors, dicts, etc.).
287
+
288
+ Returns:
289
+ Return value of module.func(**kwargs).
290
+
291
+ Raises:
292
+ WorkerError: If function raises an exception.
293
+ TimeoutError: If execution exceeds timeout.
294
+ """
295
+ if self._shutdown:
296
+ raise RuntimeError(f"{self.name}: Worker has been shut down")
297
+
298
+ timeout = timeout or 600.0 # 10 minute default
299
+
300
+ # Create unique ID for this call
301
+ call_id = str(uuid.uuid4())[:8]
302
+
303
+ # Paths for IPC (use shm for tensors, temp for json)
304
+ inputs_path = self._shm_dir / f"comfyui_venv_{call_id}_in.pt"
305
+ outputs_path = self._shm_dir / f"comfyui_venv_{call_id}_out.pt"
306
+ request_path = self._temp_dir / f"request_{call_id}.json"
307
+ response_path = self._temp_dir / f"response_{call_id}.json"
308
+
309
+ try:
310
+ # Save inputs via torch.save (handles tensors natively)
311
+ # Serialize custom objects with broken __module__ paths first
312
+ import torch
313
+ if kwargs:
314
+ serialized_kwargs = _serialize_for_ipc(kwargs)
315
+ torch.save(serialized_kwargs, str(inputs_path))
316
+
317
+ # Build request
318
+ request = {
319
+ "module": module,
320
+ "func": func,
321
+ "sys_path": [str(self.working_dir)] + self.sys_path,
322
+ "inputs_path": str(inputs_path) if kwargs else None,
323
+ "outputs_path": str(outputs_path),
324
+ }
325
+
326
+ request_path.write_text(json.dumps(request))
327
+
328
+ # Build environment
329
+ env = os.environ.copy()
330
+ env.update(self.extra_env)
331
+ env["COMFYUI_ISOLATION_WORKER"] = "1"
332
+
333
+ # Run subprocess
334
+ cmd = [
335
+ str(self.python),
336
+ str(self._worker_script),
337
+ str(request_path),
338
+ str(response_path),
339
+ ]
340
+
341
+ process = subprocess.Popen(
342
+ cmd,
343
+ cwd=str(self.working_dir),
344
+ env=env,
345
+ stdout=subprocess.PIPE,
346
+ stderr=subprocess.PIPE,
347
+ )
348
+
349
+ try:
350
+ stdout, stderr = process.communicate(timeout=timeout)
351
+ except subprocess.TimeoutExpired:
352
+ process.kill()
353
+ process.wait()
354
+ raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")
355
+
356
+ # Check for process error
357
+ if process.returncode != 0:
358
+ raise WorkerError(
359
+ f"Subprocess failed with code {process.returncode}",
360
+ traceback=stderr.decode('utf-8', errors='replace'),
361
+ )
362
+
363
+ # Read response
364
+ if not response_path.exists():
365
+ raise WorkerError(
366
+ "Worker produced no response file",
367
+ traceback=stderr.decode('utf-8', errors='replace'),
368
+ )
369
+
370
+ response = json.loads(response_path.read_text())
371
+
372
+ if response["status"] == "error":
373
+ raise WorkerError(
374
+ response.get("error", "Unknown error"),
375
+ traceback=response.get("traceback"),
376
+ )
377
+
378
+ # Load result
379
+ if outputs_path.exists():
380
+ result = torch.load(str(outputs_path), weights_only=False)
381
+ return result
382
+ else:
383
+ return None
384
+
385
+ finally:
386
+ # Cleanup IPC files
387
+ for path in [inputs_path, outputs_path, request_path, response_path]:
388
+ try:
389
+ if path.exists():
390
+ path.unlink()
391
+ except:
392
+ pass
393
+
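A hedged usage sketch of the failure modes documented above; the paths and module names are placeholders:

worker = VenvWorker(
    python="/path/to/other/venv/bin/python",
    working_dir="/path/to/code",
)
try:
    out = worker.call_module(module="my_module", func="my_function", timeout=120.0)
except WorkerError as exc:
    # Raised when the function in the other venv raised; the remote traceback is attached
    print(f"remote failure: {exc}")
except TimeoutError:
    print("call exceeded 120s")
finally:
    worker.shutdown()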
394
+ def shutdown(self) -> None:
395
+ """Shut down the worker and clean up resources."""
396
+ if self._shutdown:
397
+ return
398
+
399
+ self._shutdown = True
400
+
401
+ # Clean up temp directory
402
+ try:
403
+ shutil.rmtree(self._temp_dir, ignore_errors=True)
404
+ except:
405
+ pass
406
+
407
+ def is_alive(self) -> bool:
408
+ """VenvWorker spawns a fresh process per call, so it is 'alive' unless shut down."""
409
+ return not self._shutdown
410
+
411
+ def __repr__(self):
412
+ return f"<VenvWorker name={self.name!r} python={self.python}>"
413
+
414
+
415
+ # Persistent worker script - runs as __main__ in the venv Python subprocess
416
+ # Uses stdin/stdout JSON for IPC - avoids Windows multiprocessing spawn issues entirely
417
+ _PERSISTENT_WORKER_SCRIPT = '''
418
+ import sys
419
+ import os
420
+ import json
421
+ import traceback
422
+ from types import SimpleNamespace
423
+
424
+ # On Windows, add host Python's DLL directories so packages like opencv can find VC++ runtime
425
+ if sys.platform == "win32":
426
+ _host_python_dir = os.environ.get("COMFYUI_HOST_PYTHON_DIR")
427
+ if _host_python_dir and hasattr(os, "add_dll_directory"):
428
+ try:
429
+ os.add_dll_directory(_host_python_dir)
430
+ # Also add DLLs subdirectory if it exists
431
+ _dlls_dir = os.path.join(_host_python_dir, "DLLs")
432
+ if os.path.isdir(_dlls_dir):
433
+ os.add_dll_directory(_dlls_dir)
434
+ except Exception:
435
+ pass
436
+
437
+ def _deserialize_isolated_objects(obj):
438
+ """Reconstruct objects serialized with __isolated_object__ marker."""
439
+ if isinstance(obj, dict):
440
+ if obj.get("__isolated_object__"):
441
+ attrs = {k: _deserialize_isolated_objects(v) for k, v in obj.get("__attrs__", {}).items()}
442
+ ns = SimpleNamespace(**attrs)
443
+ ns.__class_name__ = obj.get("__class_name__", "Unknown")
444
+ return ns
445
+ return {k: _deserialize_isolated_objects(v) for k, v in obj.items()}
446
+ elif isinstance(obj, list):
447
+ return [_deserialize_isolated_objects(v) for v in obj]
448
+ elif isinstance(obj, tuple):
449
+ return tuple(_deserialize_isolated_objects(v) for v in obj)
450
+ return obj
451
+
452
+ def main():
453
+ # Save original stdout for JSON IPC - redirect stdout to stderr for module prints
454
+ _ipc_out = sys.stdout
455
+ sys.stdout = sys.stderr # All print() calls go to stderr now
456
+
457
+ # Read config from first line
458
+ config_line = sys.stdin.readline()
459
+ if not config_line:
460
+ return
461
+ config = json.loads(config_line)
462
+
463
+ # Setup sys.path
464
+ for p in config.get("sys_paths", []):
465
+ if p not in sys.path:
466
+ sys.path.insert(0, p)
467
+
468
+ # Import torch after path setup
469
+ import torch
470
+
471
+ # Signal ready (use _ipc_out, not stdout)
472
+ _ipc_out.write(json.dumps({"status": "ready"}) + "\\n")
473
+ _ipc_out.flush()
474
+
475
+ # Process requests
476
+ while True:
477
+ try:
478
+ line = sys.stdin.readline()
479
+ if not line:
480
+ break
481
+ request = json.loads(line)
482
+ except Exception:
483
+ break
484
+
485
+ if request.get("method") == "shutdown":
486
+ break
487
+
488
+ try:
489
+ request_type = request.get("type", "call_module")
490
+ module_name = request["module"]
491
+ inputs_path = request.get("inputs_path")
492
+ outputs_path = request.get("outputs_path")
493
+
494
+ # Load inputs
495
+ if inputs_path:
496
+ inputs = torch.load(inputs_path, weights_only=False)
497
+ inputs = _deserialize_isolated_objects(inputs)
498
+ else:
499
+ inputs = {}
500
+
501
+ # Import module
502
+ module = __import__(module_name, fromlist=[""])
503
+
504
+ if request_type == "call_method":
505
+ class_name = request["class_name"]
506
+ method_name = request["method_name"]
507
+ self_state = request.get("self_state")
508
+
509
+ cls = getattr(module, class_name)
510
+ instance = object.__new__(cls)
511
+ if self_state:
512
+ instance.__dict__.update(self_state)
513
+ method = getattr(instance, method_name)
514
+ result = method(**inputs)
515
+ else:
516
+ func_name = request["func"]
517
+ func = getattr(module, func_name)
518
+ result = func(**inputs)
519
+
520
+ # Save result
521
+ if outputs_path:
522
+ torch.save(result, outputs_path)
523
+
524
+ _ipc_out.write(json.dumps({"status": "ok"}) + "\\n")
525
+ _ipc_out.flush()
526
+
527
+ except Exception as e:
528
+ _ipc_out.write(json.dumps({
529
+ "status": "error",
530
+ "error": str(e),
531
+ "traceback": traceback.format_exc(),
532
+ }) + "\\n")
533
+ _ipc_out.flush()
534
+
535
+ if __name__ == "__main__":
536
+ main()
537
+ '''
538
+
539
+
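The conversation with the script above is one JSON object per line over stdin/stdout; roughly (values illustrative, see _ensure_started() and _send_request() below):

config = {"sys_paths": ["/path/to/stubs", "/path/to/code"]}   # host -> worker, first line
ready = {"status": "ready"}                                   # worker -> host
call = {
    "type": "call_module",
    "module": "my_module",
    "func": "my_function",
    "inputs_path": "/dev/shm/comfyui_pvenv_1a2b3c4d_in.pt",   # or None when there are no kwargs
    "outputs_path": "/dev/shm/comfyui_pvenv_1a2b3c4d_out.pt",
}
ok = {"status": "ok"}                                         # result tensor is read from outputs_path
stop = {"method": "shutdown"}                                 # ends the worker loop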
540
+ class PersistentVenvWorker(Worker):
541
+ """
542
+ Persistent version of VenvWorker that keeps subprocess alive.
543
+
544
+ Uses subprocess.Popen with stdin/stdout JSON IPC instead of multiprocessing.
545
+ This avoids Windows multiprocessing spawn issues where the child process
546
+ tries to reimport __main__ (which fails when using a different Python).
547
+
548
+ Benefits:
549
+ - Works on Windows with different venv Python (full isolation)
550
+ - Compiled CUDA extensions load correctly in the venv
551
+ - ~50-100ms per call (vs ~300-500ms for VenvWorker per-call spawn)
552
+ - Tensor transfer via shared memory files
553
+
554
+ Use this for high-frequency calls to isolated venvs.
555
+ """
556
+
557
+ def __init__(
558
+ self,
559
+ python: Union[str, Path],
560
+ working_dir: Optional[Union[str, Path]] = None,
561
+ sys_path: Optional[List[str]] = None,
562
+ env: Optional[Dict[str, str]] = None,
563
+ name: Optional[str] = None,
564
+ share_torch: bool = True, # Kept for API compatibility
565
+ ):
566
+ """
567
+ Initialize persistent worker.
568
+
569
+ Args:
570
+ python: Path to Python executable in target venv.
571
+ working_dir: Working directory for subprocess.
572
+ sys_path: Additional paths to add to sys.path.
573
+ env: Additional environment variables.
574
+ name: Optional name for logging.
575
+ share_torch: Ignored (kept for API compatibility).
576
+ """
577
+ self.python = Path(python)
578
+ self.working_dir = Path(working_dir) if working_dir else Path.cwd()
579
+ self.sys_path = sys_path or []
580
+ self.extra_env = env or {}
581
+ self.name = name or f"PersistentVenvWorker({self.python.parent.parent.name})"
582
+
583
+ if not self.python.exists():
584
+ raise FileNotFoundError(f"Python not found: {self.python}")
585
+
586
+ self._temp_dir = Path(tempfile.mkdtemp(prefix='comfyui_pvenv_'))
587
+ self._shm_dir = _get_shm_dir()
588
+ self._process: Optional[subprocess.Popen] = None
589
+ self._shutdown = False
590
+ self._lock = threading.Lock()
591
+
592
+ # Write worker script to temp file
593
+ self._worker_script = self._temp_dir / "persistent_worker.py"
594
+ self._worker_script.write_text(_PERSISTENT_WORKER_SCRIPT)
595
+
596
+ def _find_comfyui_base(self) -> Optional[Path]:
597
+ """Find ComfyUI base directory by walking up from working_dir."""
598
+ current = self.working_dir.resolve()
599
+ for _ in range(10):
600
+ if (current / "main.py").exists() and (current / "comfy").exists():
601
+ return current
602
+ current = current.parent
603
+ return None
604
+
605
+ def _ensure_started(self):
606
+ """Start persistent worker subprocess if not running."""
607
+ if self._shutdown:
608
+ raise RuntimeError(f"{self.name}: Worker has been shut down")
609
+
610
+ if self._process is not None and self._process.poll() is None:
611
+ return # Already running
612
+
613
+ # Set up environment
614
+ env = os.environ.copy()
615
+ env.update(self.extra_env)
616
+ env["COMFYUI_ISOLATION_WORKER"] = "1"
617
+
618
+ # On Windows, pass host Python directory so worker can add it via os.add_dll_directory()
619
+ # This fixes "DLL load failed" errors for packages like opencv-python-headless
620
+ if sys.platform == "win32":
621
+ env["COMFYUI_HOST_PYTHON_DIR"] = str(Path(sys.executable).parent)
622
+
623
+ # Find ComfyUI base and set env var for folder_paths stub
624
+ comfyui_base = self._find_comfyui_base()
625
+ if comfyui_base:
626
+ env["COMFYUI_BASE"] = str(comfyui_base)
627
+
628
+ # Add stubs directory to sys_path for folder_paths etc.
629
+ stubs_dir = Path(__file__).parent.parent / "stubs"
630
+ all_sys_path = [str(stubs_dir), str(self.working_dir)] + self.sys_path
631
+
632
+ # Launch subprocess with the venv Python
633
+ # This runs _PERSISTENT_WORKER_SCRIPT as __main__ - no reimport issues!
634
+ self._process = subprocess.Popen(
635
+ [str(self.python), str(self._worker_script)],
636
+ stdin=subprocess.PIPE,
637
+ stdout=subprocess.PIPE,
638
+ stderr=subprocess.PIPE,
639
+ cwd=str(self.working_dir),
640
+ env=env,
641
+ bufsize=1, # Line buffered
642
+ text=True, # Text mode for JSON
643
+ )
644
+
645
+ # Start stderr forwarding thread to show worker output in real-time
646
+ def forward_stderr():
647
+ try:
648
+ for line in self._process.stderr:
649
+ # Forward to main process stderr (visible in console)
650
+ sys.stderr.write(f" {line}")
651
+ sys.stderr.flush()
652
+ except:
653
+ pass
654
+ self._stderr_thread = threading.Thread(target=forward_stderr, daemon=True)
655
+ self._stderr_thread.start()
656
+
657
+ # Send config
658
+ config = {"sys_paths": all_sys_path}
659
+ self._process.stdin.write(json.dumps(config) + "\n")
660
+ self._process.stdin.flush()
661
+
662
+ # Wait for ready signal with timeout
663
+ # select() works on pipes only on POSIX; Windows falls back to a reader thread
664
+ if sys.platform == "win32":
665
+ # Windows: can't use select on pipes, use thread with timeout
666
+ ready_line = [None]
667
+ def read_ready():
668
+ try:
669
+ ready_line[0] = self._process.stdout.readline()
670
+ except:
671
+ pass
672
+ t = threading.Thread(target=read_ready, daemon=True)
673
+ t.start()
674
+ t.join(timeout=60)
675
+ line = ready_line[0]
676
+ else:
677
+ # Unix: use select for timeout
678
+ import select
679
+ ready, _, _ = select.select([self._process.stdout], [], [], 60)
680
+ line = self._process.stdout.readline() if ready else None
681
+
682
+ if not line:
683
+ stderr = ""
684
+ try:
685
+ self._process.kill()
686
+ _, stderr = self._process.communicate(timeout=5)
687
+ except:
688
+ pass
689
+ raise RuntimeError(f"{self.name}: Worker failed to start (timeout). stderr: {stderr}")
690
+
691
+ try:
692
+ msg = json.loads(line)
693
+ except json.JSONDecodeError as e:
694
+ raise RuntimeError(f"{self.name}: Invalid ready message: {line!r}") from e
695
+
696
+ if msg.get("status") != "ready":
697
+ raise RuntimeError(f"{self.name}: Unexpected ready message: {msg}")
698
+
699
+ def call(
700
+ self,
701
+ func: Callable,
702
+ *args,
703
+ timeout: Optional[float] = None,
704
+ **kwargs
705
+ ) -> Any:
706
+ """Not supported - use call_module()."""
707
+ raise NotImplementedError(
708
+ f"{self.name}: Use call_module(module='...', func='...') instead."
709
+ )
710
+
711
+ def _send_request(self, request: dict, timeout: float) -> dict:
712
+ """Send request via stdin and read response from stdout with timeout."""
713
+ # Send request
714
+ self._process.stdin.write(json.dumps(request) + "\n")
715
+ self._process.stdin.flush()
716
+
717
+ # Read response with timeout
718
+ if sys.platform == "win32":
719
+ # Windows: use thread for timeout
720
+ response_line = [None]
721
+ def read_response():
722
+ try:
723
+ response_line[0] = self._process.stdout.readline()
724
+ except:
725
+ pass
726
+ t = threading.Thread(target=read_response, daemon=True)
727
+ t.start()
728
+ t.join(timeout=timeout)
729
+ line = response_line[0]
730
+ else:
731
+ # Unix: use select
732
+ import select
733
+ ready, _, _ = select.select([self._process.stdout], [], [], timeout)
734
+ line = self._process.stdout.readline() if ready else None
735
+
736
+ if not line:
737
+ # Timeout - kill process
738
+ try:
739
+ self._process.kill()
740
+ except:
741
+ pass
742
+ self._shutdown = True
743
+ raise TimeoutError(f"{self.name}: Call timed out after {timeout}s")
744
+
745
+ try:
746
+ return json.loads(line)
747
+ except json.JSONDecodeError as e:
748
+ raise WorkerError(f"Invalid response from worker: {line!r}") from e
749
+
750
+ def call_method(
751
+ self,
752
+ module_name: str,
753
+ class_name: str,
754
+ method_name: str,
755
+ self_state: Optional[Dict[str, Any]] = None,
756
+ kwargs: Optional[Dict[str, Any]] = None,
757
+ timeout: Optional[float] = None,
758
+ ) -> Any:
759
+ """
760
+ Call a class method by module/class/method path.
761
+
762
+ Args:
763
+ module_name: Module containing the class (e.g., "depth_estimate").
764
+ class_name: Class name (e.g., "SAM3D_DepthEstimate").
765
+ method_name: Method name (e.g., "estimate_depth").
766
+ self_state: Optional dict to populate instance __dict__.
767
+ kwargs: Keyword arguments for the method.
768
+ timeout: Timeout in seconds.
769
+
770
+ Returns:
771
+ Return value of the method.
772
+ """
773
+ with self._lock:
774
+ self._ensure_started()
775
+
776
+ timeout = timeout or 600.0
777
+ call_id = str(uuid.uuid4())[:8]
778
+
779
+ import torch
780
+ inputs_path = self._shm_dir / f"comfyui_pvenv_{call_id}_in.pt"
781
+ outputs_path = self._shm_dir / f"comfyui_pvenv_{call_id}_out.pt"
782
+
783
+ try:
784
+ # Serialize kwargs
785
+ if kwargs:
786
+ serialized_kwargs = _serialize_for_ipc(kwargs)
787
+ torch.save(serialized_kwargs, str(inputs_path))
788
+
789
+ # Send request with class info
790
+ request = {
791
+ "type": "call_method",
792
+ "module": module_name,
793
+ "class_name": class_name,
794
+ "method_name": method_name,
795
+ "self_state": self_state,
796
+ "inputs_path": str(inputs_path) if kwargs else None,
797
+ "outputs_path": str(outputs_path),
798
+ }
799
+ response = self._send_request(request, timeout)
800
+
801
+ if response.get("status") == "error":
802
+ raise WorkerError(
803
+ response.get("error", "Unknown"),
804
+ traceback=response.get("traceback"),
805
+ )
806
+
807
+ if outputs_path.exists():
808
+ return torch.load(str(outputs_path), weights_only=False)
809
+ return None
810
+
811
+ finally:
812
+ for p in [inputs_path, outputs_path]:
813
+ try:
814
+ p.unlink()
815
+ except:
816
+ pass
817
+
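A hedged usage sketch for call_method(); the module, class, method, tensor, and paths are hypothetical stand-ins for code that lives in the isolated venv:

import torch

worker = PersistentVenvWorker(
    python="/path/to/other/venv/bin/python",
    working_dir="/path/to/custom_node",
)
depth = worker.call_method(
    module_name="depth_estimate",
    class_name="SAM3D_DepthEstimate",
    method_name="estimate_depth",
    self_state={"device": "cuda"},              # copied into the instance __dict__ in the worker
    kwargs={"image": torch.rand(3, 512, 512)},  # torch.save()d into shared memory
    timeout=300.0,
)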
818
+ def call_module(
819
+ self,
820
+ module: str,
821
+ func: str,
822
+ timeout: Optional[float] = None,
823
+ **kwargs
824
+ ) -> Any:
825
+ """Call a function by module path."""
826
+ with self._lock:
827
+ self._ensure_started()
828
+
829
+ timeout = timeout or 600.0
830
+ call_id = str(uuid.uuid4())[:8]
831
+
832
+ # Save inputs
833
+ import torch
834
+ inputs_path = self._shm_dir / f"comfyui_pvenv_{call_id}_in.pt"
835
+ outputs_path = self._shm_dir / f"comfyui_pvenv_{call_id}_out.pt"
836
+
837
+ try:
838
+ if kwargs:
839
+ serialized_kwargs = _serialize_for_ipc(kwargs)
840
+ torch.save(serialized_kwargs, str(inputs_path))
841
+
842
+ # Send request
843
+ request = {
844
+ "type": "call_module",
845
+ "module": module,
846
+ "func": func,
847
+ "inputs_path": str(inputs_path) if kwargs else None,
848
+ "outputs_path": str(outputs_path),
849
+ }
850
+ response = self._send_request(request, timeout)
851
+
852
+ if response.get("status") == "error":
853
+ raise WorkerError(
854
+ response.get("error", "Unknown"),
855
+ traceback=response.get("traceback"),
856
+ )
857
+
858
+ # Load result
859
+ if outputs_path.exists():
860
+ return torch.load(str(outputs_path), weights_only=False)
861
+ return None
862
+
863
+ finally:
864
+ for p in [inputs_path, outputs_path]:
865
+ try:
866
+ p.unlink()
867
+ except:
868
+ pass
869
+
870
+ def shutdown(self) -> None:
871
+ """Shut down the persistent worker."""
872
+ if self._shutdown:
873
+ return
874
+ self._shutdown = True
875
+
876
+ # Send shutdown signal via stdin
877
+ if self._process and self._process.poll() is None:
878
+ try:
879
+ self._process.stdin.write(json.dumps({"method": "shutdown"}) + "\n")
880
+ self._process.stdin.flush()
881
+ self._process.stdin.close()
882
+ except:
883
+ pass
884
+
885
+ # Wait for process to exit
886
+ try:
887
+ self._process.wait(timeout=5)
888
+ except subprocess.TimeoutExpired:
889
+ self._process.kill()
890
+ self._process.wait(timeout=2)
891
+
892
+ shutil.rmtree(self._temp_dir, ignore_errors=True)
893
+
894
+ def is_alive(self) -> bool:
895
+ if self._shutdown:
896
+ return False
897
+ if self._process is None:
898
+ return False
899
+ return self._process.poll() is None
900
+
901
+ def __repr__(self):
902
+ status = "alive" if self.is_alive() else "stopped"
903
+ return f"<PersistentVenvWorker name={self.name!r} status={status}>"
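A rough decision guide, distilled from the two class docstrings: VenvWorker for one-off calls, PersistentVenvWorker when the same venv is called repeatedly. A minimal lifecycle sketch (the paths and the images iterable are placeholders):

worker = PersistentVenvWorker(
    python="/path/to/other/venv/bin/python",
    working_dir="/path/to/code",
)
try:
    for image in images:  # placeholder iterable of tensors
        out = worker.call_module(module="my_module", func="my_function", image=image)
finally:
    worker.shutdown()     # stops the subprocess and removes the temp directory
    assert not worker.is_alive()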