comfygit-deploy 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,266 @@
1
+ """Worker CLI command handlers.
2
+
3
+ Commands for setting up and managing the worker server on GPU machines.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import secrets
9
+ from pathlib import Path
10
+
11
# Where `worker setup` persists the worker settings (API key, workspace, ports).
WORKER_CONFIG_PATH = Path.home().joinpath(".config", "comfygit", "deploy", "worker.json")
12
+
13
+
14
def generate_api_key() -> str:
    """Return a fresh random worker API key: ``cg_wk_`` followed by 32 hex characters."""
    key_prefix = "cg_wk_"
    random_part = secrets.token_hex(16)  # 16 random bytes -> 32 hex chars
    return key_prefix + random_part
17
+
18
+
19
+ def load_worker_config() -> dict | None:
20
+ """Load worker config from disk."""
21
+ if not WORKER_CONFIG_PATH.exists():
22
+ return None
23
+ try:
24
+ return json.loads(WORKER_CONFIG_PATH.read_text())
25
+ except (json.JSONDecodeError, OSError):
26
+ return None
27
+
28
+
29
+ def get_validated_workspace() -> Path | None:
30
+ """Get workspace path from env or config.
31
+
32
+ Checks COMFYGIT_HOME env first, then falls back to worker config.
33
+
34
+ Returns:
35
+ Workspace Path if configured and exists, None otherwise.
36
+ """
37
+ import os
38
+
39
+ # Check environment variable first (takes precedence)
40
+ env_home = os.environ.get("COMFYGIT_HOME")
41
+ if env_home:
42
+ workspace = Path(env_home)
43
+ if workspace.exists():
44
+ return workspace
45
+
46
+ # Fall back to worker config
47
+ config = load_worker_config()
48
+ if config and config.get("workspace_path"):
49
+ workspace = Path(config["workspace_path"])
50
+ if workspace.exists():
51
+ return workspace
52
+
53
+ return None
54
+
55
+
56
def save_worker_config(config: dict) -> None:
    """Persist *config* to ``WORKER_CONFIG_PATH`` as indented JSON.

    The file holds the worker API key, so it is created owner-only (0600)
    and swapped into place atomically. A failure mid-write leaves the
    previous config intact rather than a truncated file.

    Raises:
        TypeError: if *config* is not JSON-serializable (nothing is written).
        OSError: if the directory or file cannot be written.
    """
    import os
    import tempfile

    target = WORKER_CONFIG_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(config, indent=2)

    # mkstemp creates the file with mode 0600 in the target directory, so the
    # os.replace below is a same-filesystem (atomic) rename.
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=".worker-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        os.replace(tmp_name, target)
    except BaseException:
        try:
            os.unlink(tmp_name)
        except OSError:
            pass  # temp file already gone; the original error matters more
        raise
60
+
61
+
62
def is_worker_running() -> bool:
    """Return True if the PID recorded in ``worker.pid`` belongs to a live process.

    A missing, unreadable or malformed PID file counts as not running. So
    does a non-positive PID, which ``os.kill`` would treat as a process group.
    """
    import os

    pid_file = WORKER_CONFIG_PATH.parent / "worker.pid"
    try:
        pid = int(pid_file.read_text().strip())
    except (OSError, ValueError):
        return False
    if pid <= 0:
        return False

    # NOTE(review): on POSIX, signal 0 only checks existence/permission. On
    # Windows, os.kill(pid, 0) terminates the process instead, so this probe
    # is unsafe there.
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but is owned by another user.
        return True
    except OSError:
        return False
    return True
77
+
78
+
79
def handle_setup(args: argparse.Namespace) -> int:
    """Handle 'worker setup': write a fresh worker config and print a summary.

    ``args.workspace`` is stored as an absolute path with ``~`` expanded. A
    later ``worker up`` started from a different directory then resolves the
    same workspace. It defaults to ``~/comfygit``. ``args.api_key`` is used
    when given; otherwise a new key is generated. Any existing config is
    overwritten.

    Returns:
        Always 0.
    """
    if args.workspace:
        # Relative paths would otherwise be re-interpreted against whatever
        # cwd `worker up` happens to run in.
        workspace = str(Path(args.workspace).expanduser().resolve())
    else:
        workspace = str(Path.home() / "comfygit")
    api_key = args.api_key or generate_api_key()

    config = {
        "version": "1",
        "api_key": api_key,
        "workspace_path": workspace,
        "default_mode": "docker",
        "server_port": 9090,
        "instance_port_range": {"start": 8200, "end": 8210},
    }

    save_worker_config(config)

    print("Worker setup complete!")
    print(f" Workspace: {workspace}")
    print(f" API Key: {api_key}")
    print(f" Config: {WORKER_CONFIG_PATH}")
    print()
    print("To start the worker server:")
    print(" cg-deploy worker up")

    return 0
104
+
105
+
106
def _parse_port_range(spec: str) -> tuple[int, int]:
    """Parse ``"START[:END]"`` into an inclusive ``(start, end)`` port pair.

    END defaults to START + 10 when omitted.

    Raises:
        ValueError: non-integer parts, more than one ``:``, ports outside
            1-65535, or END < START.
    """
    parts = str(spec).split(":")
    if len(parts) > 2:
        raise ValueError(f"expected START or START:END, got {spec!r}")
    start = int(parts[0])
    end = int(parts[1]) if len(parts) > 1 else start + 10
    if not 1 <= start <= end <= 65535:
        raise ValueError(f"{spec!r} must satisfy 1 <= START <= END <= 65535")
    return start, end


def _link_dev_manager(workspace: Path, dev_manager: str) -> None:
    """Point ``<workspace>/.metadata/system_nodes/comfygit-manager`` at *dev_manager*.

    An existing symlink at that location is replaced. An existing real
    directory there is deleted first.
    """
    import shutil

    system_nodes = workspace / ".metadata" / "system_nodes"
    system_nodes.mkdir(parents=True, exist_ok=True)
    manager_link = system_nodes / "comfygit-manager"

    if manager_link.is_symlink():
        manager_link.unlink()
    elif manager_link.is_dir():
        shutil.rmtree(manager_link)

    manager_link.symlink_to(dev_manager)


def _release_pid_file(pid_file: Path, pid: str) -> None:
    """Delete *pid_file* only if it still records *pid*.

    A second ``worker up`` writes the same PID file. Removing it blindly on
    exit would hide the still-running worker from ``worker down``/``status``.
    """
    try:
        if pid_file.read_text().strip() == pid:
            pid_file.unlink(missing_ok=True)
    except OSError:
        pass  # best-effort cleanup during shutdown (file may already be gone)


def handle_up(args: argparse.Namespace) -> int:
    """Handle 'worker up': run the worker HTTP server in the foreground.

    Steps:
      1. Load the saved worker config.
      2. Validate ``args.port_range``.
      3. Apply optional dev-mode overrides: the core path via
         ``COMFYGIT_DEV_CORE_PATH``, and the manager via a symlink into the
         workspace.
      4. Serve until the server stops.

    While serving, this process's PID is recorded in ``worker.pid`` next to
    the config file.

    Returns:
        0 once the server stops; 1 if the worker is not configured or the
        port range is invalid.
    """
    import os

    config = load_worker_config()
    if not config:
        print("Worker not configured. Run 'cg-deploy worker setup' first.")
        return 1

    try:
        port_start, port_end = _parse_port_range(args.port_range)
    except ValueError as exc:
        print(f"Invalid port range: {exc}")
        return 1

    # Dev mode: explicit paths override saved config
    dev_core = getattr(args, "dev_core", None)
    dev_manager = getattr(args, "dev_manager", None)

    # --dev flag loads saved dev config for any paths not given explicitly
    if getattr(args, "dev", False):
        from .dev import load_dev_config

        dev_config = load_dev_config()
        dev_core = dev_core or dev_config.get("core_path")
        dev_manager = dev_manager or dev_config.get("manager_path")

    if dev_core:
        dev_core = str(Path(dev_core).resolve())
        os.environ["COMFYGIT_DEV_CORE_PATH"] = dev_core
        print(f"Dev mode: core -> {dev_core}")

    if dev_manager:
        dev_manager = str(Path(dev_manager).resolve())
        _link_dev_manager(Path(config["workspace_path"]), dev_manager)
        print(f"Dev mode: manager -> {dev_manager}")

    print(f"Starting worker server on {args.host}:{args.port}...")
    print(f" Mode: {args.mode}")
    print(f" Instance ports: {port_start}-{port_end}")
    print(f" Broadcast: {args.broadcast}")
    print()
    print("Press Ctrl+C to stop.")

    from aiohttp import web

    from ..worker.server import create_worker_app

    app = create_worker_app(
        api_key=config["api_key"],
        workspace_path=Path(config["workspace_path"]),
        default_mode=args.mode,
        port_range_start=port_start,
        port_range_end=port_end,
    )

    pid_file = WORKER_CONFIG_PATH.parent / "worker.pid"
    own_pid = str(os.getpid())
    pid_file.write_text(own_pid)

    broadcaster = None
    try:
        # Started inside the try so a failing mDNS start still cleans up the PID file.
        if args.broadcast:
            from ..worker.mdns import MDNSBroadcaster

            candidate = MDNSBroadcaster(port=args.port, mode=args.mode)
            candidate.start()
            broadcaster = candidate  # only stop() a broadcaster that actually started
        web.run_app(app, host=args.host, port=args.port, print=lambda _: None)
    finally:
        try:
            if broadcaster:
                broadcaster.stop()
        finally:
            _release_pid_file(pid_file, own_pid)

    return 0
194
+
195
+
196
def handle_down(args: argparse.Namespace) -> int:
    """Handle 'worker down': send SIGTERM to the PID recorded in ``worker.pid``.

    The PID file is removed whenever it existed.

    Returns:
        0 when the worker was signalled, when no PID file exists, or when the
        PID file was stale (the recorded process no longer exists). 1 when
        the PID file is malformed or the signal could not be sent.
    """
    import os
    import signal

    pid_file = WORKER_CONFIG_PATH.parent / "worker.pid"

    if not pid_file.exists():
        print("Worker server is not running.")
        return 0

    try:
        pid = int(pid_file.read_text().strip())
        if pid <= 0:
            # os.kill treats 0 / negative PIDs as process groups (possibly our own);
            # never signal those.
            raise ValueError(f"invalid PID {pid} in {pid_file}")
        os.kill(pid, signal.SIGTERM)
        pid_file.unlink(missing_ok=True)
        print("Worker server stopped.")
    except ProcessLookupError:
        # The recorded process already exited; just clean up the leftover file.
        pid_file.unlink(missing_ok=True)
        print("Worker server is not running (removed stale PID file).")
        return 0
    except (OSError, ValueError) as e:
        print(f"Failed to stop worker: {e}")
        pid_file.unlink(missing_ok=True)
        return 1

    return 0
218
+
219
+
220
def handle_status(args: argparse.Namespace) -> int:
    """Handle 'worker status': report whether the worker runs plus its saved settings.

    Missing settings fall back to the defaults `worker setup` would write.
    Returns 0 in every case, including when no config exists yet.
    """
    config = load_worker_config()
    if not config:
        print("Worker not configured. Run 'cg-deploy worker setup' first.")
        return 0

    state = "RUNNING" if is_worker_running() else "NOT RUNNING"
    print(f"Worker Status: {state}")

    for label, key, fallback in (
        ("Workspace", "workspace_path", "N/A"),
        ("Default Mode", "default_mode", "docker"),
        ("Server Port", "server_port", 9090),
    ):
        print(f" {label}: {config.get(key, fallback)}")

    ports = config.get("instance_port_range", {})
    print(f" Instance Ports: {ports.get('start', 8200)}-{ports.get('end', 8210)}")

    return 0
242
+
243
+
244
def handle_regenerate_key(args: argparse.Namespace) -> int:
    """Handle 'worker regenerate-key': replace the saved worker API key.

    Prints the first 20 characters of the old key and the full new key. A
    missing or null ``api_key`` in the config is treated as an empty string
    rather than crashing.

    Returns:
        0 on success, 1 if the worker has not been set up.
    """
    config = load_worker_config()
    if not config:
        print("Worker not configured. Run 'cg-deploy worker setup' first.")
        return 1

    # `or ""` also covers an explicit JSON null, which .get's default would not.
    previous_key = config.get("api_key") or ""
    old_key = str(previous_key)[:20] + "..."
    new_key = generate_api_key()

    config["api_key"] = new_key
    save_worker_config(config)

    print("API key regenerated!")
    print(f" Old: {old_key}")
    print(f" New: {new_key}")
    print()
    print("Note: Update any clients using the old key.")

    # The running server was started with the old key.
    if is_worker_running():
        print("Warning: Restart the worker server for the new key to take effect.")

    return 0
@@ -0,0 +1,122 @@
1
+ """Configuration storage for comfygit-deploy.
2
+
3
+ Stores RunPod API keys and custom worker registry in ~/.config/comfygit/deploy/config.json
4
+ """
5
+
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+
12
def _get_default_config_path() -> Path:
    """Return ``<home>/.config/comfygit/deploy/config.json``.

    A non-empty ``$HOME`` takes priority; otherwise ``Path.home()`` is used.
    """
    home_dir = os.environ.get("HOME")
    base = Path(home_dir) if home_dir else Path(str(Path.home()))
    return base.joinpath(".config", "comfygit", "deploy", "config.json")
16
+
17
+
18
+ class DeployConfig:
19
+ """Configuration storage for deploy CLI.
20
+
21
+ Stores:
22
+ - RunPod API key
23
+ - Custom worker registry (name -> {host, port, api_key, ...})
24
+ """
25
+
26
+ def __init__(self, path: Path | None = None):
27
+ """Initialize config.
28
+
29
+ Args:
30
+ path: Config file path. Defaults to ~/.config/comfygit/deploy/config.json
31
+ """
32
+ self.path = path or _get_default_config_path()
33
+ self._data: dict[str, Any] = {"version": "1", "providers": {}, "workers": {}}
34
+ self._load()
35
+
36
+ def _load(self) -> None:
37
+ """Load config from disk if it exists."""
38
+ if self.path.exists():
39
+ try:
40
+ self._data = json.loads(self.path.read_text())
41
+ except (json.JSONDecodeError, OSError):
42
+ pass # Use defaults
43
+
44
+ def save(self) -> None:
45
+ """Save config to disk."""
46
+ self.path.parent.mkdir(parents=True, exist_ok=True)
47
+ self.path.write_text(json.dumps(self._data, indent=2))
48
+
49
+ @property
50
+ def runpod_api_key(self) -> str | None:
51
+ """Get RunPod API key."""
52
+ return self._data.get("providers", {}).get("runpod", {}).get("api_key")
53
+
54
+ @runpod_api_key.setter
55
+ def runpod_api_key(self, value: str | None) -> None:
56
+ """Set RunPod API key."""
57
+ if "providers" not in self._data:
58
+ self._data["providers"] = {}
59
+ if "runpod" not in self._data["providers"]:
60
+ self._data["providers"]["runpod"] = {}
61
+
62
+ if value is None:
63
+ self._data["providers"]["runpod"].pop("api_key", None)
64
+ else:
65
+ self._data["providers"]["runpod"]["api_key"] = value
66
+
67
+ @property
68
+ def workers(self) -> dict[str, dict[str, Any]]:
69
+ """Get custom workers registry."""
70
+ return self._data.get("workers", {})
71
+
72
+ def add_worker(
73
+ self,
74
+ name: str,
75
+ host: str,
76
+ port: int,
77
+ api_key: str,
78
+ mode: str = "docker",
79
+ ) -> None:
80
+ """Add a custom worker to the registry.
81
+
82
+ Args:
83
+ name: Worker name (unique identifier)
84
+ host: Worker host/IP
85
+ port: Worker API port
86
+ api_key: Worker API key
87
+ mode: Worker mode (docker or native)
88
+ """
89
+ if "workers" not in self._data:
90
+ self._data["workers"] = {}
91
+
92
+ self._data["workers"][name] = {
93
+ "host": host,
94
+ "port": port,
95
+ "api_key": api_key,
96
+ "mode": mode,
97
+ }
98
+
99
+ def remove_worker(self, name: str) -> bool:
100
+ """Remove a worker from the registry.
101
+
102
+ Args:
103
+ name: Worker name
104
+
105
+ Returns:
106
+ True if worker was removed, False if not found
107
+ """
108
+ if name in self._data.get("workers", {}):
109
+ del self._data["workers"][name]
110
+ return True
111
+ return False
112
+
113
+ def get_worker(self, name: str) -> dict[str, Any] | None:
114
+ """Get a worker by name.
115
+
116
+ Args:
117
+ name: Worker name
118
+
119
+ Returns:
120
+ Worker config dict or None if not found
121
+ """
122
+ return self._data.get("workers", {}).get(name)
@@ -0,0 +1,11 @@
1
+ """Provider clients for deployment backends."""
2
+
3
+ from .custom import CustomWorkerClient, CustomWorkerError
4
+ from .runpod import RunPodAPIError, RunPodClient
5
+
6
# Public API of the providers package: the provider clients and their error types.
__all__ = ["RunPodClient", "RunPodAPIError", "CustomWorkerClient", "CustomWorkerError"]
@@ -0,0 +1,238 @@
1
+ """Custom worker HTTP client for connecting to self-hosted workers.
2
+
3
+ Provides async interface for worker server REST API.
4
+ """
5
+
6
+ import json
7
+ from collections.abc import AsyncIterator
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
+ import aiohttp
12
+
13
+
14
@dataclass
class CustomWorkerError(Exception):
    """Error from custom worker API.

    Attributes:
        message: Error text taken from the worker's response (or a fallback).
        status_code: HTTP status code of the failing response.
    """

    message: str
    status_code: int

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ never calls Exception.__init__. With
        # keyword construction `args` would stay empty, and unpickling (which
        # replays `args`) would fail. Populate it explicitly.
        super().__init__(self.message, self.status_code)

    def __str__(self) -> str:
        return f"Worker Error ({self.status_code}): {self.message}"
23
+
24
+
25
@dataclass
class LogEntry:
    """One log line received from an instance's WebSocket log stream."""

    timestamp: str  # as sent by the worker; the client substitutes "" when absent
    level: str  # severity label; the client substitutes "INFO" when absent
    message: str  # the log text itself ("" when absent)
32
+
33
+
34
class CustomWorkerClient:
    """Async client for custom worker REST API.

    Every call opens its own short-lived ``aiohttp.ClientSession`` and sends
    ``Authorization: Bearer <api_key>``. HTTP responses with status >= 400 are
    raised as :class:`CustomWorkerError`.
    """

    def __init__(self, host: str, port: int, api_key: str):
        """Initialize client.

        Args:
            host: Worker host name, IPv4 address, or IPv6 address (bare or
                already wrapped in ``[...]``)
            port: Worker API port
            api_key: Worker API key
        """
        self.host = host
        self.port = port
        self.api_key = api_key
        self.base_url = f"http://{self._netloc()}"

    def _netloc(self) -> str:
        """Return ``host:port`` for URLs, bracketing bare IPv6 literals (RFC 3986)."""
        host = str(self.host)
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"{host}:{self.port}"

    @staticmethod
    def _instance_path(instance_id: str, suffix: str = "") -> str:
        """Return ``/api/v1/instances/<id><suffix>`` with *instance_id* percent-encoded.

        Encoding stops ids containing ``/``, ``?`` or ``#`` from escaping their
        path segment. Ids made of letters, digits, ``-``, ``_``, ``.`` and ``~``
        are unchanged.
        """
        from urllib.parse import quote

        return f"/api/v1/instances/{quote(str(instance_id), safe='')}{suffix}"

    def _headers(self) -> dict[str, str]:
        """Get request headers with authorization."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    async def _get(self, path: str) -> Any:
        """Make GET request and return JSON response."""
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.base_url}{path}",
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                return await response.json()

    async def _post(self, path: str, data: dict | None = None) -> Any:
        """Make POST request and return JSON response."""
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.base_url}{path}",
                json=data,
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                return await response.json()

    async def _delete(self, path: str) -> Any:
        """Make DELETE request and return JSON response."""
        async with aiohttp.ClientSession() as session:
            async with session.delete(
                f"{self.base_url}{path}",
                headers=self._headers(),
            ) as response:
                if response.status >= 400:
                    await self._handle_error(response)
                return await response.json()

    async def _handle_error(self, response: aiohttp.ClientResponse) -> None:
        """Raise CustomWorkerError for an error response (never returns normally).

        Uses the ``"error"`` field of a JSON object body. Otherwise it falls
        back to the raw body text, then to ``"HTTP <status>"``.
        """
        try:
            error_body = await response.json()
            message = error_body.get("error", str(error_body))
        except Exception:
            # Non-JSON body, wrong content type, or a JSON value without .get().
            message = await response.text() or f"HTTP {response.status}"
        raise CustomWorkerError(message, response.status)

    async def test_connection(self) -> dict[str, Any]:
        """Test connection to worker.

        Returns:
            {"success": True, "worker_version": "..."} on success
            {"success": False, "error": "..."} on failure
        """
        try:
            health = await self._get("/api/v1/health")
            return {
                "success": True,
                "worker_version": health.get("worker_version"),
            }
        except CustomWorkerError as e:
            return {"success": False, "error": e.message}
        except Exception as e:
            # Connection refused, DNS failure, timeouts, bad JSON, ...
            return {"success": False, "error": str(e)}

    async def get_system_info(self) -> dict[str, Any]:
        """Get worker system information."""
        return await self._get("/api/v1/system/info")

    async def list_instances(self) -> list[dict]:
        """List all instances on worker."""
        result = await self._get("/api/v1/instances")
        return result.get("instances", [])

    async def create_instance(
        self,
        import_source: str,
        name: str | None = None,
        branch: str | None = None,
        mode: str | None = None,
    ) -> dict:
        """Create new instance.

        Args:
            import_source: Git URL or local path
            name: Optional instance name
            branch: Optional git branch
            mode: docker or native

        Returns:
            Instance data
        """
        data: dict[str, str] = {"import_source": import_source}
        if name:
            data["name"] = name
        if branch:
            data["branch"] = branch
        if mode:
            data["mode"] = mode

        return await self._post("/api/v1/instances", data)

    async def get_instance(self, instance_id: str) -> dict:
        """Get instance details."""
        return await self._get(self._instance_path(instance_id))

    async def stop_instance(self, instance_id: str) -> dict:
        """Stop a running instance."""
        return await self._post(self._instance_path(instance_id, "/stop"))

    async def start_instance(self, instance_id: str) -> dict:
        """Start a stopped instance."""
        return await self._post(self._instance_path(instance_id, "/start"))

    async def terminate_instance(self, instance_id: str, keep_env: bool = False) -> dict:
        """Terminate and remove instance.

        Args:
            instance_id: Instance ID
            keep_env: When true, sends ``keep_env=true`` to the worker.
        """
        path = self._instance_path(instance_id)
        if keep_env:
            path += "?keep_env=true"
        return await self._delete(path)

    async def get_logs(self, instance_id: str, lines: int = 100) -> list[dict]:
        """Get recent logs for an instance.

        Args:
            instance_id: Instance ID
            lines: Number of lines to fetch

        Returns:
            List of log entries
        """
        result = await self._get(self._instance_path(instance_id, f"/logs?lines={lines}"))
        return result.get("logs", [])

    def _connect_ws(self, url: str):
        """Create a WebSocket connection context manager.

        Returns an async context manager that yields the WebSocket connection.
        """
        return _WSConnectionManager(url, self.api_key)

    async def stream_logs(self, instance_id: str) -> AsyncIterator[LogEntry]:
        """Stream logs from an instance via WebSocket.

        Frames that are not valid JSON objects are skipped, as are non-"log"
        messages. The stream ends when the socket closes or reports an error.

        Args:
            instance_id: Instance ID

        Yields:
            LogEntry objects as they arrive
        """
        url = f"ws://{self._netloc()}" + self._instance_path(instance_id, "/logs")
        async with self._connect_ws(url) as ws:
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                    except json.JSONDecodeError:
                        continue  # one malformed frame should not end the whole stream
                    if isinstance(data, dict) and data.get("type") == "log":
                        yield LogEntry(
                            timestamp=data.get("timestamp", ""),
                            level=data.get("level", "INFO"),
                            message=data.get("message", ""),
                        )
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    break
216
+
217
+
218
class _WSConnectionManager:
    """Async context manager owning an aiohttp session and one WebSocket on it.

    ``async with`` yields the WebSocket. On exit both the WebSocket and the
    session are closed. If the handshake fails, the session opened for it is
    closed before the error propagates.
    """

    def __init__(self, url: str, api_key: str):
        self.url = url
        self.api_key = api_key
        self._session: aiohttp.ClientSession | None = None
        self._ws = None

    async def __aenter__(self):
        session = aiohttp.ClientSession()
        try:
            self._ws = await session.ws_connect(
                self.url, headers={"Authorization": f"Bearer {self.api_key}"}
            )
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release the
            # session here; otherwise it leaks ("Unclosed client session").
            await session.close()
            raise
        self._session = session
        return self._ws

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Close the session even if closing the WebSocket itself fails.
        try:
            if self._ws is not None:
                await self._ws.close()
        finally:
            if self._session is not None:
                await self._session.close()
            self._ws = None
            self._session = None