comfy_env-0.0.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- comfy_env/__init__.py +161 -0
- comfy_env/cli.py +388 -0
- comfy_env/decorator.py +422 -0
- comfy_env/env/__init__.py +30 -0
- comfy_env/env/config.py +144 -0
- comfy_env/env/config_file.py +592 -0
- comfy_env/env/detection.py +176 -0
- comfy_env/env/manager.py +657 -0
- comfy_env/env/platform/__init__.py +21 -0
- comfy_env/env/platform/base.py +96 -0
- comfy_env/env/platform/darwin.py +53 -0
- comfy_env/env/platform/linux.py +68 -0
- comfy_env/env/platform/windows.py +377 -0
- comfy_env/env/security.py +267 -0
- comfy_env/errors.py +325 -0
- comfy_env/install.py +539 -0
- comfy_env/ipc/__init__.py +55 -0
- comfy_env/ipc/bridge.py +512 -0
- comfy_env/ipc/protocol.py +265 -0
- comfy_env/ipc/tensor.py +371 -0
- comfy_env/ipc/torch_bridge.py +401 -0
- comfy_env/ipc/transport.py +318 -0
- comfy_env/ipc/worker.py +221 -0
- comfy_env/registry.py +252 -0
- comfy_env/resolver.py +399 -0
- comfy_env/runner.py +273 -0
- comfy_env/stubs/__init__.py +1 -0
- comfy_env/stubs/folder_paths.py +57 -0
- comfy_env/workers/__init__.py +49 -0
- comfy_env/workers/base.py +82 -0
- comfy_env/workers/pool.py +241 -0
- comfy_env/workers/tensor_utils.py +188 -0
- comfy_env/workers/torch_mp.py +375 -0
- comfy_env/workers/venv.py +903 -0
- comfy_env-0.0.8.dist-info/METADATA +228 -0
- comfy_env-0.0.8.dist-info/RECORD +39 -0
- comfy_env-0.0.8.dist-info/WHEEL +4 -0
- comfy_env-0.0.8.dist-info/entry_points.txt +2 -0
- comfy_env-0.0.8.dist-info/licenses/LICENSE +21 -0
comfy_env/runner.py
ADDED
@@ -0,0 +1,273 @@
"""
Generic runner for isolated subprocess execution.

This module is the entry point for subprocess execution. The runner handles
requests for ANY @isolated class in the environment, importing classes on demand.

Usage (Unix Domain Socket - recommended):
    python -m comfy_env.runner \
        --node-dir /path/to/ComfyUI-SAM3DObjects/nodes \
        --comfyui-base /path/to/ComfyUI \
        --import-paths ".,../vendor" \
        --socket /tmp/comfyui-isolation-myenv-12345.sock

Usage (Legacy stdin/stdout):
    python -m comfy_env.runner \
        --node-dir /path/to/ComfyUI-SAM3DObjects/nodes \
        --comfyui-base /path/to/ComfyUI \
        --import-paths ".,../vendor"

The runner:
1. Sets COMFYUI_ISOLATION_WORKER=1 (so @isolated decorator becomes no-op)
2. Adds paths to sys.path
3. Connects to Unix Domain Socket (or uses stdin/stdout)
4. Dynamically imports classes as needed (cached)
5. Calls methods and returns responses
"""

import os
import sys
import json
import argparse
import traceback
import warnings
import logging
import importlib
from typing import Any, Dict, Optional

# Suppress warnings that could interfere with JSON IPC
warnings.filterwarnings("ignore")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
logging.disable(logging.WARNING)

# Mark that we're in worker mode - this makes @isolated decorator a no-op
os.environ["COMFYUI_ISOLATION_WORKER"] = "1"


def setup_paths(node_dir: str, comfyui_base: Optional[str], import_paths: Optional[str]):
    """Setup sys.path for imports."""
    from pathlib import Path

    node_path = Path(node_dir)

    # Set COMFYUI_BASE env var for stubs to use
    if comfyui_base:
        os.environ["COMFYUI_BASE"] = comfyui_base

    # Add comfyui-isolation stubs directory (provides folder_paths, etc.)
    stubs_dir = Path(__file__).parent / "stubs"
    sys.path.insert(0, str(stubs_dir))

    # Add import paths
    if import_paths:
        for p in import_paths.split(","):
            p = p.strip()
            if p:
                full_path = node_path / p
                sys.path.insert(0, str(full_path))

    # Add node_dir itself
    sys.path.insert(0, str(node_path))


def serialize_result(obj: Any) -> Any:
    """Serialize result for JSON transport."""
    from comfy_env.ipc.protocol import encode_object
    return encode_object(obj)


def deserialize_arg(obj: Any) -> Any:
    """Deserialize argument from JSON transport."""
    from comfy_env.ipc.protocol import decode_object
    return decode_object(obj)


# Cache for imported classes and instances
_class_cache: Dict[str, type] = {}
_instance_cache: Dict[str, object] = {}


def get_instance(module_name: str, class_name: str) -> object:
    """Get or create an instance of a class."""
    cache_key = f"{module_name}.{class_name}"

    if cache_key not in _instance_cache:
        # Import the class if not cached
        if cache_key not in _class_cache:
            print(f"[Runner] Importing {class_name} from {module_name}...", file=sys.stderr)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            _class_cache[cache_key] = cls

        # Create instance
        cls = _class_cache[cache_key]
        _instance_cache[cache_key] = cls()
        print(f"[Runner] Created instance of {class_name}", file=sys.stderr)

    return _instance_cache[cache_key]


def run_worker(
    node_dir: str,
    comfyui_base: Optional[str],
    import_paths: Optional[str],
    socket_path: Optional[str] = None,
):
    """
    Main worker loop - handles JSON-RPC requests via transport.

    Args:
        node_dir: Path to node package directory
        comfyui_base: Path to ComfyUI base directory
        import_paths: Comma-separated import paths
        socket_path: Unix domain socket path (if None, uses stdin/stdout)
    """
    from comfy_env.ipc.transport import UnixSocketTransport, StdioTransport

    # Setup paths first
    setup_paths(node_dir, comfyui_base, import_paths)

    # Create transport
    if socket_path:
        # Unix Domain Socket transport (recommended)
        print(f"[Runner] Connecting to socket: {socket_path}", file=sys.stderr)
        transport = UnixSocketTransport.connect(socket_path)
        use_uds = True
    else:
        # Legacy stdin/stdout transport
        print("[Runner] Using stdin/stdout transport", file=sys.stderr)
        transport = StdioTransport()
        use_uds = False

    try:
        # Send ready signal
        transport.send({"status": "ready"})

        # Main loop - read requests, execute, respond
        while True:
            response = {"jsonrpc": "2.0", "id": None}

            try:
                request = transport.recv()
                response["id"] = request.get("id")

                method_name = request.get("method")
                params = request.get("params", {})

                if method_name == "shutdown":
                    # Clean shutdown
                    response["result"] = {"status": "shutdown"}
                    transport.send(response)
                    break

                # Get module/class from request
                module_name = request.get("module")
                class_name = request.get("class")

                if not module_name or not class_name:
                    response["error"] = {
                        "code": -32602,
                        "message": "Missing 'module' or 'class' in request",
                    }
                    transport.send(response)
                    continue

                # Get or create instance
                try:
                    instance = get_instance(module_name, class_name)
                except Exception as e:
                    response["error"] = {
                        "code": -32000,
                        "message": f"Failed to import {module_name}.{class_name}: {e}",
                        "data": {"traceback": traceback.format_exc()}
                    }
                    transport.send(response)
                    continue

                # Get the method
                method = getattr(instance, method_name, None)
                if method is None:
                    response["error"] = {
                        "code": -32601,
                        "message": f"Method not found: {method_name}",
                    }
                    transport.send(response)
                    continue

                # Deserialize arguments
                deserialized_params = {}
                for key, value in params.items():
                    deserialized_params[key] = deserialize_arg(value)

                # For legacy stdio transport, redirect stdout to stderr during execution
                # This prevents print() in node code from corrupting JSON protocol
                # (UDS transport doesn't need this since it uses a separate socket)
                if not use_uds:
                    original_stdout = sys.stdout
                    sys.stdout = sys.stderr

                    # Also redirect at file descriptor level for C libraries
                    stdout_fd = original_stdout.fileno()
                    stderr_fd = sys.stderr.fileno()
                    stdout_fd_copy = os.dup(stdout_fd)
                    os.dup2(stderr_fd, stdout_fd)

                # Call the method
                print(f"[Runner] Calling {class_name}.{method_name}...", file=sys.stderr)
                try:
                    result = method(**deserialized_params)
                finally:
                    if not use_uds:
                        # Restore file descriptor first, then Python stdout
                        os.dup2(stdout_fd_copy, stdout_fd)
                        os.close(stdout_fd_copy)
                        sys.stdout = original_stdout

                # Serialize result
                serialized_result = serialize_result(result)
                response["result"] = serialized_result

                print(f"[Runner] {class_name}.{method_name} completed", file=sys.stderr)

            except ConnectionError as e:
                # Socket closed - normal shutdown
                print(f"[Runner] Connection closed: {e}", file=sys.stderr)
                break
            except Exception as e:
                tb = traceback.format_exc()
                print(f"[Runner] Error: {e}", file=sys.stderr)
                print(tb, file=sys.stderr)
                response["error"] = {
                    "code": -32000,
                    "message": str(e),
                    "data": {"traceback": tb}
                }

            try:
                transport.send(response)
            except ConnectionError:
                break

    finally:
        transport.close()


def main():
    parser = argparse.ArgumentParser(description="Isolated node runner")
    parser.add_argument("--node-dir", required=True, help="Node package directory")
    parser.add_argument("--comfyui-base", help="ComfyUI base directory")
    parser.add_argument("--import-paths", help="Comma-separated import paths")
    parser.add_argument("--socket", help="Unix domain socket path (if not provided, uses stdin/stdout)")

    args = parser.parse_args()

    run_worker(
        node_dir=args.node_dir,
        comfyui_base=args.comfyui_base,
        import_paths=args.import_paths,
        socket_path=args.socket,
    )


if __name__ == "__main__":
    main()
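A parent-side sketch of the request/response cycle the runner implements (not part of the wheel): the request keys (id, method, module, class, params) come from the loop above, but the one-JSON-object-per-line framing is an assumption here, since the actual framing lives in comfy_env/ipc/transport.py, and the module/class/method names are hypothetical.

    import json
    import subprocess

    # Launch the runner with the target environment's Python (legacy stdio transport).
    proc = subprocess.Popen(
        ["/path/to/venv/bin/python", "-m", "comfy_env.runner",
         "--node-dir", "/path/to/ComfyUI-SAM3DObjects/nodes",
         "--comfyui-base", "/path/to/ComfyUI"],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
    )

    print(json.loads(proc.stdout.readline()))   # {"status": "ready"}

    # Ask the runner to instantiate a class and call a method on it.
    request = {
        "jsonrpc": "2.0", "id": 1,
        "module": "my_nodes", "class": "MyNode",   # hypothetical names
        "method": "run", "params": {"steps": 20},  # hypothetical method/params
    }
    proc.stdin.write(json.dumps(request) + "\n")
    proc.stdin.flush()
    print(json.loads(proc.stdout.readline()))   # {"jsonrpc": "2.0", "id": 1, "result": ...}

    # Clean shutdown.
    proc.stdin.write(json.dumps({"jsonrpc": "2.0", "id": 2, "method": "shutdown"}) + "\n")
    proc.stdin.flush()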
comfy_env/stubs/__init__.py
ADDED
@@ -0,0 +1 @@
# ComfyUI stubs for isolated workers
comfy_env/stubs/folder_paths.py
ADDED
@@ -0,0 +1,57 @@
"""
Minimal folder_paths stub for isolated worker processes.

Provides the same interface as ComfyUI's folder_paths module
without importing any ComfyUI dependencies.
"""

import os
from pathlib import Path

_comfyui_base = None

def _find_comfyui_base():
    """Find ComfyUI base from COMFYUI_BASE env var or by walking up."""
    global _comfyui_base
    if _comfyui_base:
        return _comfyui_base

    # Check env var first
    if os.environ.get("COMFYUI_BASE"):
        _comfyui_base = Path(os.environ["COMFYUI_BASE"])
        return _comfyui_base

    # Walk up from cwd looking for ComfyUI
    current = Path.cwd().resolve()
    for _ in range(10):
        if (current / "main.py").exists() and (current / "comfy").exists():
            _comfyui_base = current
            return _comfyui_base
        current = current.parent

    return None

# Models directory
@property
def models_dir():
    base = _find_comfyui_base()
    return str(base / "models") if base else None

# Make models_dir work as both attribute and property
class _ModuleProxy:
    @property
    def models_dir(self):
        base = _find_comfyui_base()
        return str(base / "models") if base else None

    def get_output_directory(self):
        base = _find_comfyui_base()
        return str(base / "output") if base else None

    def get_input_directory(self):
        base = _find_comfyui_base()
        return str(base / "input") if base else None

# Replace module with proxy instance
import sys
sys.modules[__name__] = _ModuleProxy()
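Because the stub replaces itself in sys.modules with a _ModuleProxy instance, models_dir resolves through the property at attribute-access time. A minimal usage sketch (not part of the wheel), assuming the stubs directory is on sys.path as arranged by setup_paths() in runner.py:

    import os
    os.environ.setdefault("COMFYUI_BASE", "/path/to/ComfyUI")  # normally set by the runner

    import folder_paths  # resolves to the _ModuleProxy instance

    print(folder_paths.models_dir)              # /path/to/ComfyUI/models
    print(folder_paths.get_output_directory())  # /path/to/ComfyUI/output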
comfy_env/workers/__init__.py
ADDED
@@ -0,0 +1,49 @@
"""
Workers - Simple, explicit process isolation for ComfyUI nodes.

This module provides three isolation tiers:

Tier 1: TorchMPWorker (same Python, fresh CUDA context)
    - Uses torch.multiprocessing.Queue
    - Zero-copy tensor transfer via CUDA IPC
    - ~30ms overhead per call
    - Use for: Memory isolation, fresh CUDA context

Tier 2: VenvWorker (different Python/venv)
    - Uses subprocess + torch.save/load via /dev/shm
    - One memcpy per tensor direction
    - ~100-500ms overhead per call
    - Use for: Different PyTorch versions, incompatible deps

Tier 3: ContainerWorker (full isolation) [future]
    - Docker with GPU passthrough
    - Use for: Different CUDA versions, hermetic environments

Usage:
    from comfy_env.workers import get_worker, TorchMPWorker

    # Get a named worker from the pool
    worker = get_worker("sam3d")
    result = worker.call(my_function, image=tensor)

    # Or create directly
    worker = TorchMPWorker()
    result = worker.call(my_function, arg1, arg2)
"""

from .base import Worker
from .torch_mp import TorchMPWorker
from .venv import VenvWorker, PersistentVenvWorker
from .pool import WorkerPool, get_worker, register_worker, shutdown_workers, list_workers

__all__ = [
    "Worker",
    "TorchMPWorker",
    "VenvWorker",
    "PersistentVenvWorker",
    "WorkerPool",
    "get_worker",
    "register_worker",
    "shutdown_workers",
    "list_workers",
]
comfy_env/workers/base.py
ADDED
@@ -0,0 +1,82 @@
"""
Base Worker Interface - Protocol for all worker implementations.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class Worker(ABC):
    """
    Abstract base class for process isolation workers.

    All workers must implement:
    - call(): Execute a function in the isolated process
    - shutdown(): Clean up resources

    Workers should be used as context managers when possible:

        with TorchMPWorker() as worker:
            result = worker.call(my_func, arg1, arg2)
    """

    @abstractmethod
    def call(
        self,
        func: Callable,
        *args,
        timeout: Optional[float] = None,
        **kwargs
    ) -> Any:
        """
        Execute a function in the isolated worker process.

        Args:
            func: The function to execute. Must be picklable (top-level or staticmethod).
            *args: Positional arguments passed to func.
            timeout: Optional timeout in seconds (None = no timeout).
            **kwargs: Keyword arguments passed to func.

        Returns:
            The return value of func(*args, **kwargs).

        Raises:
            TimeoutError: If execution exceeds timeout.
            RuntimeError: If worker process dies or raises exception.
        """
        pass

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shut down the worker and release resources.

        Safe to call multiple times. After shutdown, further calls to
        call() will raise RuntimeError.
        """
        pass

    @abstractmethod
    def is_alive(self) -> bool:
        """Check if the worker process is still running."""
        pass

    def __enter__(self) -> "Worker":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.shutdown()


class WorkerError(Exception):
    """Exception raised when a worker encounters an error."""

    def __init__(self, message: str, traceback: Optional[str] = None):
        super().__init__(message)
        self.worker_traceback = traceback

    def __str__(self):
        msg = super().__str__()
        if self.worker_traceback:
            msg += f"\n\nWorker traceback:\n{self.worker_traceback}"
        return msg
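A minimal sketch (not from the package) of a Worker subclass that satisfies the ABC above, useful as a stand-in when exercising pool plumbing without any real isolation; it runs the callable in-process and ignores the timeout parameter:

    from typing import Any, Callable, Optional

    from comfy_env.workers.base import Worker


    class InProcessWorker(Worker):
        """Hypothetical no-isolation worker for testing registration/lifecycle."""

        def __init__(self) -> None:
            self._alive = True

        def call(self, func: Callable, *args, timeout: Optional[float] = None, **kwargs) -> Any:
            if not self._alive:
                raise RuntimeError("worker has been shut down")
            # No isolation: execute directly in the calling process (timeout is ignored).
            return func(*args, **kwargs)

        def shutdown(self) -> None:
            self._alive = False

        def is_alive(self) -> bool:
            return self._alive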
comfy_env/workers/pool.py
ADDED
@@ -0,0 +1,241 @@
"""
WorkerPool - Global registry and management of named workers.

Provides a simple API for getting workers by name:

    from comfy_env.workers import get_worker

    worker = get_worker("sam3d")
    result = worker.call_module("my_module", "my_func", image=tensor)

Workers are registered at startup and reused across calls:

    from comfy_env.workers import register_worker, TorchMPWorker

    register_worker("default", TorchMPWorker())
    register_worker("sam3d", PersistentVenvWorker(
        python="/path/to/venv/bin/python",
        working_dir="/path/to/nodes",
    ))
"""

import atexit
import threading
from typing import Dict, Optional, Union
from pathlib import Path

from .base import Worker


class WorkerPool:
    """
    Singleton pool of named workers.

    Manages worker lifecycle, provides access by name, handles cleanup.
    """

    _instance: Optional["WorkerPool"] = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self._initialized = True
        self._workers: Dict[str, Worker] = {}
        self._factories: Dict[str, callable] = {}
        self._worker_lock = threading.Lock()

    def register(
        self,
        name: str,
        worker: Optional[Worker] = None,
        factory: Optional[callable] = None,
    ) -> None:
        """
        Register a worker or worker factory.

        Args:
            name: Name to register under.
            worker: Pre-created worker instance.
            factory: Callable that creates worker on first use (lazy).

        Only one of worker or factory should be provided.
        """
        if worker is not None and factory is not None:
            raise ValueError("Provide either worker or factory, not both")
        if worker is None and factory is None:
            raise ValueError("Must provide worker or factory")

        with self._worker_lock:
            # Shutdown existing worker if replacing
            if name in self._workers:
                try:
                    self._workers[name].shutdown()
                except:
                    pass

            if worker is not None:
                self._workers[name] = worker
                self._factories.pop(name, None)
            else:
                self._factories[name] = factory
                self._workers.pop(name, None)

    def get(self, name: str) -> Worker:
        """
        Get a worker by name.

        Args:
            name: Registered worker name.

        Returns:
            The worker instance.

        Raises:
            KeyError: If no worker registered with that name.
        """
        with self._worker_lock:
            # Check for existing worker
            if name in self._workers:
                worker = self._workers[name]
                if worker.is_alive():
                    return worker
                # Worker died, try to recreate from factory
                if name not in self._factories:
                    raise RuntimeError(f"Worker '{name}' died and no factory to recreate")

            # Create from factory
            if name in self._factories:
                worker = self._factories[name]()
                self._workers[name] = worker
                return worker

            raise KeyError(f"No worker registered with name: {name}")

    def shutdown(self, name: Optional[str] = None) -> None:
        """
        Shutdown workers.

        Args:
            name: If provided, shutdown only this worker.
                  If None, shutdown all workers.
        """
        with self._worker_lock:
            if name is not None:
                if name in self._workers:
                    try:
                        self._workers[name].shutdown()
                    except:
                        pass
                    del self._workers[name]
            else:
                for worker in self._workers.values():
                    try:
                        worker.shutdown()
                    except:
                        pass
                self._workers.clear()

    def list_workers(self) -> Dict[str, str]:
        """
        List all registered workers.

        Returns:
            Dict of name -> status string.
        """
        with self._worker_lock:
            result = {}
            for name, worker in self._workers.items():
                status = "alive" if worker.is_alive() else "dead"
                result[name] = f"{type(worker).__name__} ({status})"
            for name in self._factories:
                if name not in result:
                    result[name] = f"factory (not started)"
            return result


# Global pool instance
_pool = WorkerPool()


def get_worker(name: str) -> Worker:
    """
    Get a worker by name from the global pool.

    Args:
        name: Registered worker name.

    Returns:
        Worker instance.

    Example:
        worker = get_worker("sam3d")
        result = worker.call_module("my_module", "my_func", image=tensor)
    """
    return _pool.get(name)


def register_worker(
    name: str,
    worker: Optional[Worker] = None,
    factory: Optional[callable] = None,
) -> None:
    """
    Register a worker in the global pool.

    Args:
        name: Name to register under.
        worker: Pre-created worker instance.
        factory: Callable that creates worker on demand.

    Example:
        # Register pre-created worker
        register_worker("default", TorchMPWorker())

        # Register factory for lazy creation
        register_worker("sam3d", factory=lambda: PersistentVenvWorker(
            python="/path/to/venv/bin/python",
        ))
    """
    _pool.register(name, worker=worker, factory=factory)


def shutdown_workers(name: Optional[str] = None) -> None:
    """
    Shutdown workers in the global pool.

    Args:
        name: If provided, shutdown only this worker.
              If None, shutdown all workers.
    """
    _pool.shutdown(name)


def list_workers() -> Dict[str, str]:
    """
    List all registered workers.

    Returns:
        Dict of name -> status description.
    """
    return _pool.list_workers()


# Register default worker (TorchMPWorker) on import
def _register_default():
    from .torch_mp import TorchMPWorker
    register_worker("default", factory=lambda: TorchMPWorker(name="default"))


_register_default()

# Cleanup on exit
atexit.register(lambda: shutdown_workers())
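A usage sketch (not from the package) that follows the docstring examples above: registering a lazy factory, listing pool status, and letting get_worker() create the worker on first use. The PersistentVenvWorker arguments mirror the docstring example and may differ from the real constructor in comfy_env/workers/venv.py (not shown in this section).

    from comfy_env.workers import (
        PersistentVenvWorker,
        get_worker,
        list_workers,
        register_worker,
        shutdown_workers,
    )

    # Lazy registration: the venv worker process is only started on first get_worker() call.
    register_worker("sam3d", factory=lambda: PersistentVenvWorker(
        python="/path/to/venv/bin/python",   # per the docstring example
        working_dir="/path/to/nodes",
    ))

    print(list_workers())         # e.g. {'default': 'factory (not started)', 'sam3d': 'factory (not started)'}

    worker = get_worker("sam3d")  # created from the factory; recreated if it later dies
    shutdown_workers()            # or rely on the atexit hook registered above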