comfy-env 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
comfy_env/runner.py ADDED
@@ -0,0 +1,273 @@
+ """
+ Generic runner for isolated subprocess execution.
+
+ This module is the entry point for subprocess execution. The runner handles
+ requests for ANY @isolated class in the environment, importing classes on demand.
+
+ Usage (Unix Domain Socket - recommended):
+     python -m comfy_env.runner \
+         --node-dir /path/to/ComfyUI-SAM3DObjects/nodes \
+         --comfyui-base /path/to/ComfyUI \
+         --import-paths ".,../vendor" \
+         --socket /tmp/comfyui-isolation-myenv-12345.sock
+
+ Usage (Legacy stdin/stdout):
+     python -m comfy_env.runner \
+         --node-dir /path/to/ComfyUI-SAM3DObjects/nodes \
+         --comfyui-base /path/to/ComfyUI \
+         --import-paths ".,../vendor"
+
+ The runner:
+ 1. Sets COMFYUI_ISOLATION_WORKER=1 (so @isolated decorator becomes no-op)
+ 2. Adds paths to sys.path
+ 3. Connects to Unix Domain Socket (or uses stdin/stdout)
+ 4. Dynamically imports classes as needed (cached)
+ 5. Calls methods and returns responses
+ """
+
+ import os
+ import sys
+ import json
+ import argparse
+ import traceback
+ import warnings
+ import logging
+ import importlib
+ from typing import Any, Dict, Optional
+
+ # Suppress warnings that could interfere with JSON IPC
+ warnings.filterwarnings("ignore")
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
+ logging.disable(logging.WARNING)
+
+ # Mark that we're in worker mode - this makes @isolated decorator a no-op
+ os.environ["COMFYUI_ISOLATION_WORKER"] = "1"
+
+
+ def setup_paths(node_dir: str, comfyui_base: Optional[str], import_paths: Optional[str]):
+     """Setup sys.path for imports."""
+     from pathlib import Path
+
+     node_path = Path(node_dir)
+
+     # Set COMFYUI_BASE env var for stubs to use
+     if comfyui_base:
+         os.environ["COMFYUI_BASE"] = comfyui_base
+
+     # Add comfyui-isolation stubs directory (provides folder_paths, etc.)
+     stubs_dir = Path(__file__).parent / "stubs"
+     sys.path.insert(0, str(stubs_dir))
+
+     # Add import paths
+     if import_paths:
+         for p in import_paths.split(","):
+             p = p.strip()
+             if p:
+                 full_path = node_path / p
+                 sys.path.insert(0, str(full_path))
+
+     # Add node_dir itself
+     sys.path.insert(0, str(node_path))
+
+
+ def serialize_result(obj: Any) -> Any:
+     """Serialize result for JSON transport."""
+     from comfy_env.ipc.protocol import encode_object
+     return encode_object(obj)
+
+
+ def deserialize_arg(obj: Any) -> Any:
+     """Deserialize argument from JSON transport."""
+     from comfy_env.ipc.protocol import decode_object
+     return decode_object(obj)
+
+
+ # Cache for imported classes and instances
+ _class_cache: Dict[str, type] = {}
+ _instance_cache: Dict[str, object] = {}
+
+
+ def get_instance(module_name: str, class_name: str) -> object:
+     """Get or create an instance of a class."""
+     cache_key = f"{module_name}.{class_name}"
+
+     if cache_key not in _instance_cache:
+         # Import the class if not cached
+         if cache_key not in _class_cache:
+             print(f"[Runner] Importing {class_name} from {module_name}...", file=sys.stderr)
+             module = importlib.import_module(module_name)
+             cls = getattr(module, class_name)
+             _class_cache[cache_key] = cls
+
+         # Create instance
+         cls = _class_cache[cache_key]
+         _instance_cache[cache_key] = cls()
+         print(f"[Runner] Created instance of {class_name}", file=sys.stderr)
+
+     return _instance_cache[cache_key]
+
+
+ def run_worker(
+     node_dir: str,
+     comfyui_base: Optional[str],
+     import_paths: Optional[str],
+     socket_path: Optional[str] = None,
+ ):
+     """
+     Main worker loop - handles JSON-RPC requests via transport.
+
+     Args:
+         node_dir: Path to node package directory
+         comfyui_base: Path to ComfyUI base directory
+         import_paths: Comma-separated import paths
+         socket_path: Unix domain socket path (if None, uses stdin/stdout)
+     """
+     from comfy_env.ipc.transport import UnixSocketTransport, StdioTransport
+
+     # Setup paths first
+     setup_paths(node_dir, comfyui_base, import_paths)
+
+     # Create transport
+     if socket_path:
+         # Unix Domain Socket transport (recommended)
+         print(f"[Runner] Connecting to socket: {socket_path}", file=sys.stderr)
+         transport = UnixSocketTransport.connect(socket_path)
+         use_uds = True
+     else:
+         # Legacy stdin/stdout transport
+         print("[Runner] Using stdin/stdout transport", file=sys.stderr)
+         transport = StdioTransport()
+         use_uds = False
+
+     try:
+         # Send ready signal
+         transport.send({"status": "ready"})
+
+         # Main loop - read requests, execute, respond
+         while True:
+             response = {"jsonrpc": "2.0", "id": None}
+
+             try:
+                 request = transport.recv()
+                 response["id"] = request.get("id")
+
+                 method_name = request.get("method")
+                 params = request.get("params", {})
+
+                 if method_name == "shutdown":
+                     # Clean shutdown
+                     response["result"] = {"status": "shutdown"}
+                     transport.send(response)
+                     break
+
+                 # Get module/class from request
+                 module_name = request.get("module")
+                 class_name = request.get("class")
+
+                 if not module_name or not class_name:
+                     response["error"] = {
+                         "code": -32602,
+                         "message": "Missing 'module' or 'class' in request",
+                     }
+                     transport.send(response)
+                     continue
+
+                 # Get or create instance
+                 try:
+                     instance = get_instance(module_name, class_name)
+                 except Exception as e:
+                     response["error"] = {
+                         "code": -32000,
+                         "message": f"Failed to import {module_name}.{class_name}: {e}",
+                         "data": {"traceback": traceback.format_exc()}
+                     }
+                     transport.send(response)
+                     continue
+
+                 # Get the method
+                 method = getattr(instance, method_name, None)
+                 if method is None:
+                     response["error"] = {
+                         "code": -32601,
+                         "message": f"Method not found: {method_name}",
+                     }
+                     transport.send(response)
+                     continue
+
+                 # Deserialize arguments
+                 deserialized_params = {}
+                 for key, value in params.items():
+                     deserialized_params[key] = deserialize_arg(value)
+
+                 # For legacy stdio transport, redirect stdout to stderr during execution
+                 # This prevents print() in node code from corrupting JSON protocol
+                 # (UDS transport doesn't need this since it uses a separate socket)
+                 if not use_uds:
+                     original_stdout = sys.stdout
+                     sys.stdout = sys.stderr
+
+                     # Also redirect at file descriptor level for C libraries
+                     stdout_fd = original_stdout.fileno()
+                     stderr_fd = sys.stderr.fileno()
+                     stdout_fd_copy = os.dup(stdout_fd)
+                     os.dup2(stderr_fd, stdout_fd)
+
+                 # Call the method
+                 print(f"[Runner] Calling {class_name}.{method_name}...", file=sys.stderr)
+                 try:
+                     result = method(**deserialized_params)
+                 finally:
+                     if not use_uds:
+                         # Restore file descriptor first, then Python stdout
+                         os.dup2(stdout_fd_copy, stdout_fd)
+                         os.close(stdout_fd_copy)
+                         sys.stdout = original_stdout
+
+                 # Serialize result
+                 serialized_result = serialize_result(result)
+                 response["result"] = serialized_result
+
+                 print(f"[Runner] {class_name}.{method_name} completed", file=sys.stderr)
+
+             except ConnectionError as e:
+                 # Socket closed - normal shutdown
+                 print(f"[Runner] Connection closed: {e}", file=sys.stderr)
+                 break
+             except Exception as e:
+                 tb = traceback.format_exc()
+                 print(f"[Runner] Error: {e}", file=sys.stderr)
+                 print(tb, file=sys.stderr)
+                 response["error"] = {
+                     "code": -32000,
+                     "message": str(e),
+                     "data": {"traceback": tb}
+                 }
+
+             try:
+                 transport.send(response)
+             except ConnectionError:
+                 break
+
+     finally:
+         transport.close()
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Isolated node runner")
+     parser.add_argument("--node-dir", required=True, help="Node package directory")
+     parser.add_argument("--comfyui-base", help="ComfyUI base directory")
+     parser.add_argument("--import-paths", help="Comma-separated import paths")
+     parser.add_argument("--socket", help="Unix domain socket path (if not provided, uses stdin/stdout)")
+
+     args = parser.parse_args()
+
+     run_worker(
+         node_dir=args.node_dir,
+         comfyui_base=args.comfyui_base,
+         import_paths=args.import_paths,
+         socket_path=args.socket,
+     )
+
+
+ if __name__ == "__main__":
+     main()
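
Note: below is a minimal sketch of the JSON-RPC message shapes the loop above exchanges. The field names come from the code in this hunk; the module/class/method names are hypothetical, and the actual wire encoding (encode_object/decode_object, framing, transport) lives in comfy_env.ipc, which is not part of this diff.

    # Request the host process would send (illustrative node class and method):
    request = {
        "jsonrpc": "2.0",
        "id": 1,
        "module": "my_nodes.sam3d",     # hypothetical module
        "class": "SAM3DNode",           # hypothetical @isolated class
        "method": "segment",
        "params": {"threshold": 0.5},   # each value passes through decode_object()
    }

    # Success: {"jsonrpc": "2.0", "id": 1, "result": <encode_object(return_value)>}
    # Failure: {"jsonrpc": "2.0", "id": 1,
    #           "error": {"code": -32601, "message": "Method not found: segment"}}
    # A request with "method": "shutdown" makes the loop reply and exit cleanly.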
@@ -0,0 +1 @@
+ # ComfyUI stubs for isolated workers
@@ -0,0 +1,52 @@
+ """
+ Minimal folder_paths stub for isolated worker processes.
+
+ Provides the same interface as ComfyUI's folder_paths module
+ without importing any ComfyUI dependencies.
+ """
+
+ import os
+ import sys
+ from pathlib import Path
+
+ _comfyui_base = None
+
+ def _find_comfyui_base():
+     """Find ComfyUI base from COMFYUI_BASE env var or by walking up."""
+     global _comfyui_base
+     if _comfyui_base:
+         return _comfyui_base
+
+     # Check env var first
+     if os.environ.get("COMFYUI_BASE"):
+         _comfyui_base = Path(os.environ["COMFYUI_BASE"])
+         return _comfyui_base
+
+     # Walk up from cwd looking for ComfyUI
+     current = Path.cwd().resolve()
+     for _ in range(10):
+         if (current / "main.py").exists() and (current / "comfy").exists():
+             _comfyui_base = current
+             return _comfyui_base
+         current = current.parent
+
+     return None
+
+ # Module-level names cannot be properties, so the module is replaced below with
+ # a proxy instance whose properties (models_dir, etc.) compute paths lazily.
+ class _ModuleProxy:
+     @property
+     def models_dir(self):
+         base = _find_comfyui_base()
+         return str(base / "models") if base else None
+
+     def get_output_directory(self):
+         base = _find_comfyui_base()
+         return str(base / "output") if base else None
+
+     def get_input_directory(self):
+         base = _find_comfyui_base()
+         return str(base / "input") if base else None
+
+ # Replace module with proxy instance
+ sys.modules[__name__] = _ModuleProxy()
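
Note: because setup_paths() in runner.py puts this stubs directory on sys.path, node code inside the worker keeps importing folder_paths exactly as it would inside ComfyUI. A minimal sketch of the worker-side view (the ComfyUI path is a placeholder; normally COMFYUI_BASE is set from --comfyui-base):

    import os
    os.environ["COMFYUI_BASE"] = "/path/to/ComfyUI"   # placeholder path

    import folder_paths   # resolves to this stub in the worker process

    print(folder_paths.models_dir)               # /path/to/ComfyUI/models
    print(folder_paths.get_output_directory())   # /path/to/ComfyUI/output
    print(folder_paths.get_input_directory())    # /path/to/ComfyUI/input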
@@ -0,0 +1,49 @@
+ """
+ Workers - Simple, explicit process isolation for ComfyUI nodes.
+
+ This module provides three isolation tiers:
+
+ Tier 1: TorchMPWorker (same Python, fresh CUDA context)
+     - Uses torch.multiprocessing.Queue
+     - Zero-copy tensor transfer via CUDA IPC
+     - ~30ms overhead per call
+     - Use for: Memory isolation, fresh CUDA context
+
+ Tier 2: VenvWorker (different Python/venv)
+     - Uses subprocess + torch.save/load via /dev/shm
+     - One memcpy per tensor direction
+     - ~100-500ms overhead per call
+     - Use for: Different PyTorch versions, incompatible deps
+
+ Tier 3: ContainerWorker (full isolation) [future]
+     - Docker with GPU passthrough
+     - Use for: Different CUDA versions, hermetic environments
+
+ Usage:
+     from comfy_env.workers import get_worker, TorchMPWorker
+
+     # Get a named worker from the pool
+     worker = get_worker("sam3d")
+     result = worker.call(my_function, image=tensor)
+
+     # Or create directly
+     worker = TorchMPWorker()
+     result = worker.call(my_function, arg1, arg2)
+ """
+
+ from .base import Worker
+ from .torch_mp import TorchMPWorker
+ from .venv import VenvWorker, PersistentVenvWorker
+ from .pool import WorkerPool, get_worker, register_worker, shutdown_workers, list_workers
+
+ __all__ = [
+     "Worker",
+     "TorchMPWorker",
+     "VenvWorker",
+     "PersistentVenvWorker",
+     "WorkerPool",
+     "get_worker",
+     "register_worker",
+     "shutdown_workers",
+     "list_workers",
+ ]
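
Note: a small end-to-end sketch of the Tier 1 path described in the docstring, using only the names re-exported here. The callable must be picklable (defined at module top level); TorchMPWorker's internals live in torch_mp.py, which is not shown in this diff, so the example sticks to the pooled "default" worker and the Worker.call() signature.

    from comfy_env.workers import get_worker, shutdown_workers

    def heavy_op(x, scale=2):
        # Runs inside the worker process (fresh CUDA context, per the docstring above).
        return x * scale

    worker = get_worker("default")                    # lazily created TorchMPWorker
    result = worker.call(heavy_op, 21, scale=2, timeout=60.0)
    assert result == 42

    shutdown_workers("default")                       # or rely on the atexit hook in pool.py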
@@ -0,0 +1,82 @@
+ """
+ Base Worker Interface - Protocol for all worker implementations.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Callable, Optional
+
+
+ class Worker(ABC):
+     """
+     Abstract base class for process isolation workers.
+
+     All workers must implement:
+     - call(): Execute a function in the isolated process
+     - shutdown(): Clean up resources
+
+     Workers should be used as context managers when possible:
+
+         with TorchMPWorker() as worker:
+             result = worker.call(my_func, arg1, arg2)
+     """
+
+     @abstractmethod
+     def call(
+         self,
+         func: Callable,
+         *args,
+         timeout: Optional[float] = None,
+         **kwargs
+     ) -> Any:
+         """
+         Execute a function in the isolated worker process.
+
+         Args:
+             func: The function to execute. Must be picklable (top-level or staticmethod).
+             *args: Positional arguments passed to func.
+             timeout: Optional timeout in seconds (None = no timeout).
+             **kwargs: Keyword arguments passed to func.
+
+         Returns:
+             The return value of func(*args, **kwargs).
+
+         Raises:
+             TimeoutError: If execution exceeds timeout.
+             RuntimeError: If worker process dies or raises exception.
+         """
+         pass
+
+     @abstractmethod
+     def shutdown(self) -> None:
+         """
+         Shut down the worker and release resources.
+
+         Safe to call multiple times. After shutdown, further calls to
+         call() will raise RuntimeError.
+         """
+         pass
+
+     @abstractmethod
+     def is_alive(self) -> bool:
+         """Check if the worker process is still running."""
+         pass
+
+     def __enter__(self) -> "Worker":
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+         self.shutdown()
+
+
+ class WorkerError(Exception):
+     """Exception raised when a worker encounters an error."""
+
+     def __init__(self, message: str, traceback: Optional[str] = None):
+         super().__init__(message)
+         self.worker_traceback = traceback
+
+     def __str__(self):
+         msg = super().__str__()
+         if self.worker_traceback:
+             msg += f"\n\nWorker traceback:\n{self.worker_traceback}"
+         return msg
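
Note: for anyone adding a new tier, a hypothetical in-process Worker that satisfies the ABC above. It performs no real isolation and exists only to show the required surface (call/shutdown/is_alive plus the inherited context-manager support); it is not part of the package.

    import traceback
    from typing import Any, Callable, Optional

    from comfy_env.workers.base import Worker, WorkerError


    class InlineWorker(Worker):
        """Toy implementation: runs the callable in the current process."""

        def __init__(self):
            self._alive = True

        def call(self, func: Callable, *args, timeout: Optional[float] = None, **kwargs) -> Any:
            if not self._alive:
                raise RuntimeError("InlineWorker has been shut down")
            try:
                return func(*args, **kwargs)      # timeout is ignored in this toy version
            except Exception as e:
                raise WorkerError(str(e), traceback.format_exc()) from e

        def shutdown(self) -> None:
            self._alive = False                   # safe to call repeatedly

        def is_alive(self) -> bool:
            return self._alive


    with InlineWorker() as worker:                # __enter__/__exit__ come from Worker
        assert worker.call(len, [1, 2, 3]) == 3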
@@ -0,0 +1,240 @@
+ """
+ WorkerPool - Global registry and management of named workers.
+
+ Provides a simple API for getting workers by name:
+
+     from comfy_env.workers import get_worker
+
+     worker = get_worker("sam3d")
+     result = worker.call_module("my_module", "my_func", image=tensor)
+
+ Workers are registered at startup and reused across calls:
+
+     from comfy_env.workers import register_worker, TorchMPWorker
+
+     register_worker("default", TorchMPWorker())
+     register_worker("sam3d", PersistentVenvWorker(
+         python="/path/to/venv/bin/python",
+         working_dir="/path/to/nodes",
+     ))
+ """
+
+ import atexit
+ import threading
+ from typing import Callable, Dict, Optional
+
+ from .base import Worker
+
+
+ class WorkerPool:
+     """
+     Singleton pool of named workers.
+
+     Manages worker lifecycle, provides access by name, handles cleanup.
+     """
+
+     _instance: Optional["WorkerPool"] = None
+     _lock = threading.Lock()
+
+     def __new__(cls):
+         if cls._instance is None:
+             with cls._lock:
+                 if cls._instance is None:
+                     cls._instance = super().__new__(cls)
+                     cls._instance._initialized = False
+         return cls._instance
+
+     def __init__(self):
+         if self._initialized:
+             return
+         self._initialized = True
+         self._workers: Dict[str, Worker] = {}
+         self._factories: Dict[str, Callable] = {}
+         self._worker_lock = threading.Lock()
+
+     def register(
+         self,
+         name: str,
+         worker: Optional[Worker] = None,
+         factory: Optional[Callable] = None,
+     ) -> None:
+         """
+         Register a worker or worker factory.
+
+         Args:
+             name: Name to register under.
+             worker: Pre-created worker instance.
+             factory: Callable that creates worker on first use (lazy).
+
+         Only one of worker or factory should be provided.
+         """
+         if worker is not None and factory is not None:
+             raise ValueError("Provide either worker or factory, not both")
+         if worker is None and factory is None:
+             raise ValueError("Must provide worker or factory")
+
+         with self._worker_lock:
+             # Shutdown existing worker if replacing
+             if name in self._workers:
+                 try:
+                     self._workers[name].shutdown()
+                 except Exception:
+                     pass
+
+             if worker is not None:
+                 self._workers[name] = worker
+                 self._factories.pop(name, None)
+             else:
+                 self._factories[name] = factory
+                 self._workers.pop(name, None)
+
+     def get(self, name: str) -> Worker:
+         """
+         Get a worker by name.
+
+         Args:
+             name: Registered worker name.
+
+         Returns:
+             The worker instance.
+
+         Raises:
+             KeyError: If no worker registered with that name.
+         """
+         with self._worker_lock:
+             # Check for existing worker
+             if name in self._workers:
+                 worker = self._workers[name]
+                 if worker.is_alive():
+                     return worker
+                 # Worker died, try to recreate from factory
+                 if name not in self._factories:
+                     raise RuntimeError(f"Worker '{name}' died and no factory to recreate")
+
+             # Create from factory
+             if name in self._factories:
+                 worker = self._factories[name]()
+                 self._workers[name] = worker
+                 return worker
+
+             raise KeyError(f"No worker registered with name: {name}")
+
+     def shutdown(self, name: Optional[str] = None) -> None:
+         """
+         Shutdown workers.
+
+         Args:
+             name: If provided, shutdown only this worker.
+                 If None, shutdown all workers.
+         """
+         with self._worker_lock:
+             if name is not None:
+                 if name in self._workers:
+                     try:
+                         self._workers[name].shutdown()
+                     except Exception:
+                         pass
+                     del self._workers[name]
+             else:
+                 for worker in self._workers.values():
+                     try:
+                         worker.shutdown()
+                     except Exception:
+                         pass
+                 self._workers.clear()
+
+     def list_workers(self) -> Dict[str, str]:
+         """
+         List all registered workers.
+
+         Returns:
+             Dict of name -> status string.
+         """
+         with self._worker_lock:
+             result = {}
+             for name, worker in self._workers.items():
+                 status = "alive" if worker.is_alive() else "dead"
+                 result[name] = f"{type(worker).__name__} ({status})"
+             for name in self._factories:
+                 if name not in result:
+                     result[name] = "factory (not started)"
+             return result
+
+
+ # Global pool instance
+ _pool = WorkerPool()
+
+
+ def get_worker(name: str) -> Worker:
+     """
+     Get a worker by name from the global pool.
+
+     Args:
+         name: Registered worker name.
+
+     Returns:
+         Worker instance.
+
+     Example:
+         worker = get_worker("sam3d")
+         result = worker.call_module("my_module", "my_func", image=tensor)
+     """
+     return _pool.get(name)
+
+
+ def register_worker(
+     name: str,
+     worker: Optional[Worker] = None,
+     factory: Optional[Callable] = None,
+ ) -> None:
+     """
+     Register a worker in the global pool.
+
+     Args:
+         name: Name to register under.
+         worker: Pre-created worker instance.
+         factory: Callable that creates worker on demand.
+
+     Example:
+         # Register pre-created worker
+         register_worker("default", TorchMPWorker())
+
+         # Register factory for lazy creation
+         register_worker("sam3d", factory=lambda: PersistentVenvWorker(
+             python="/path/to/venv/bin/python",
+         ))
+     """
+     _pool.register(name, worker=worker, factory=factory)
+
+
+ def shutdown_workers(name: Optional[str] = None) -> None:
+     """
+     Shutdown workers in the global pool.
+
+     Args:
+         name: If provided, shutdown only this worker.
+             If None, shutdown all workers.
+     """
+     _pool.shutdown(name)
+
+
+ def list_workers() -> Dict[str, str]:
+     """
+     List all registered workers.
+
+     Returns:
+         Dict of name -> status description.
+     """
+     return _pool.list_workers()
+
+
+ # Register default worker (TorchMPWorker) on import
+ def _register_default():
+     from .torch_mp import TorchMPWorker
+     register_worker("default", factory=lambda: TorchMPWorker(name="default"))
+
+
+ _register_default()
+
+ # Cleanup on exit
+ atexit.register(lambda: shutdown_workers())
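
Note: putting the pool API together, a hedged sketch of how a node package might register and use a named worker. The venv path and working_dir are placeholders, and PersistentVenvWorker's constructor is defined in venv.py (not shown in this diff), so its keyword arguments are taken from the docstring above rather than verified against the implementation.

    from comfy_env.workers import (
        PersistentVenvWorker,
        get_worker,
        list_workers,
        register_worker,
    )

    # Lazy registration: the subprocess starts on the first get_worker() call, and the
    # factory is reused to respawn the worker if it dies (see WorkerPool.get()).
    register_worker(
        "sam3d",
        factory=lambda: PersistentVenvWorker(
            python="/path/to/venv/bin/python",    # placeholder paths
            working_dir="/path/to/nodes",
        ),
    )

    print(list_workers())
    # e.g. {'default': 'factory (not started)', 'sam3d': 'factory (not started)'}

    worker = get_worker("sam3d")
    # result = worker.call_module("my_module", "my_func", image=tensor)   # per the docstring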