more-compute 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontend/app/globals.css +734 -27
- frontend/app/layout.tsx +13 -3
- frontend/components/Notebook.tsx +2 -14
- frontend/components/cell/MonacoCell.tsx +99 -5
- frontend/components/layout/Sidebar.tsx +39 -4
- frontend/components/panels/ClaudePanel.tsx +461 -0
- frontend/components/popups/ComputePopup.tsx +738 -447
- frontend/components/popups/FilterPopup.tsx +305 -189
- frontend/components/popups/MetricsPopup.tsx +20 -1
- frontend/components/popups/ProviderConfigModal.tsx +322 -0
- frontend/components/popups/ProviderDropdown.tsx +398 -0
- frontend/components/popups/SettingsPopup.tsx +1 -1
- frontend/contexts/ClaudeContext.tsx +392 -0
- frontend/contexts/PodWebSocketContext.tsx +16 -21
- frontend/hooks/useInlineDiff.ts +269 -0
- frontend/lib/api.ts +323 -12
- frontend/lib/settings.ts +5 -0
- frontend/lib/websocket-native.ts +4 -8
- frontend/lib/websocket.ts +1 -2
- frontend/package-lock.json +733 -36
- frontend/package.json +2 -0
- frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
- frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
- frontend/public/assets/icons/providers/runpod.svg +9 -0
- frontend/public/assets/icons/providers/vastai.svg +1 -0
- frontend/settings.md +54 -0
- frontend/tsconfig.tsbuildinfo +1 -0
- frontend/types/claude.ts +194 -0
- kernel_run.py +13 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
- morecompute/__init__.py +1 -1
- morecompute/__version__.py +1 -1
- morecompute/execution/executor.py +24 -67
- morecompute/execution/worker.py +6 -72
- morecompute/models/api_models.py +62 -0
- morecompute/notebook.py +11 -0
- morecompute/server.py +641 -133
- morecompute/services/claude_service.py +392 -0
- morecompute/services/pod_manager.py +168 -67
- morecompute/services/pod_monitor.py +67 -39
- morecompute/services/prime_intellect.py +0 -4
- morecompute/services/providers/__init__.py +92 -0
- morecompute/services/providers/base_provider.py +336 -0
- morecompute/services/providers/lambda_labs_provider.py +394 -0
- morecompute/services/providers/provider_factory.py +194 -0
- morecompute/services/providers/runpod_provider.py +504 -0
- morecompute/services/providers/vastai_provider.py +407 -0
- morecompute/utils/cell_magics.py +0 -3
- morecompute/utils/config_util.py +93 -3
- morecompute/utils/special_commands.py +5 -32
- morecompute/utils/version_check.py +117 -0
- frontend/styling_README.md +0 -23
- {more_compute-0.4.4.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,35 +1,44 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import subprocess
|
|
3
3
|
import os
|
|
4
|
-
import sys
|
|
5
4
|
import tempfile
|
|
6
5
|
import tarfile
|
|
7
6
|
from pathlib import Path
|
|
8
7
|
from typing import TYPE_CHECKING
|
|
9
8
|
|
|
10
|
-
from .
|
|
9
|
+
from .providers.base_provider import BaseGPUProvider
|
|
10
|
+
from ..models.api_models import PodResponse
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from ..execution.executor import NextZmqExecutor
|
|
14
14
|
|
|
15
|
+
# Type alias for supported provider services
|
|
16
|
+
ProviderService = BaseGPUProvider
|
|
17
|
+
|
|
18
|
+
|
|
15
19
|
class PodKernelManager:
|
|
16
20
|
"""
|
|
17
|
-
Manages remote GPU pod connections
|
|
18
|
-
|
|
21
|
+
Manages remote GPU pod connections and SSH tunnels for ZMQ execution.
|
|
22
|
+
|
|
23
|
+
Supports multiple GPU providers:
|
|
24
|
+
- RunPod
|
|
25
|
+
- Lambda Labs
|
|
26
|
+
- Vast.ai
|
|
19
27
|
"""
|
|
20
|
-
|
|
28
|
+
provider_service: ProviderService
|
|
29
|
+
provider_type: str
|
|
21
30
|
pod: PodResponse | None
|
|
22
31
|
ssh_tunnel_proc: subprocess.Popen[bytes] | None
|
|
23
32
|
local_cmd_port: int
|
|
24
33
|
local_pub_port: int
|
|
25
|
-
remote_cmd_port
|
|
34
|
+
remote_cmd_port: int
|
|
26
35
|
remote_pub_port: int
|
|
27
36
|
executor: "NextZmqExecutor | None"
|
|
28
37
|
_ssh_key_cache: str | None
|
|
29
38
|
|
|
30
39
|
def __init__(
|
|
31
40
|
self,
|
|
32
|
-
|
|
41
|
+
provider_service: ProviderService,
|
|
33
42
|
local_cmd_port: int = 15555,
|
|
34
43
|
local_pub_port: int = 15556,
|
|
35
44
|
remote_cmd_port: int = 5555,
|
|
@@ -39,13 +48,15 @@ class PodKernelManager:
|
|
|
39
48
|
Initialize pod manager
|
|
40
49
|
|
|
41
50
|
args:
|
|
42
|
-
|
|
51
|
+
provider_service: GPU provider service implementing BaseGPUProvider
|
|
43
52
|
local_cmd_port: Local port for REQ/REP tunnel
|
|
44
53
|
local_pub_port: Local port for PUB/SUB tunnel
|
|
45
54
|
remote_cmd_port: Remote port for REQ/REP socket
|
|
46
55
|
remote_pub_port: Remote port for PUB/SUB socket
|
|
47
56
|
"""
|
|
48
|
-
self.
|
|
57
|
+
self.provider_service = provider_service
|
|
58
|
+
self.provider_type = getattr(provider_service, 'PROVIDER_NAME', 'unknown')
|
|
59
|
+
|
|
49
60
|
self.pod = None
|
|
50
61
|
self.ssh_tunnel_proc = None
|
|
51
62
|
self.local_cmd_port = local_cmd_port
|
|
@@ -71,9 +82,8 @@ class PodKernelManager:
|
|
|
71
82
|
self._ssh_key_cache = expanded
|
|
72
83
|
return expanded
|
|
73
84
|
|
|
74
|
-
# Try common SSH key paths
|
|
85
|
+
# Try common SSH key paths
|
|
75
86
|
common_keys = [
|
|
76
|
-
"~/.ssh/primeintellect_ed25519", # Prime Intellect's recommended name
|
|
77
87
|
"~/.ssh/id_ed25519",
|
|
78
88
|
"~/.ssh/id_rsa",
|
|
79
89
|
]
|
|
@@ -85,6 +95,84 @@ class PodKernelManager:
|
|
|
85
95
|
|
|
86
96
|
return None
|
|
87
97
|
|
|
98
|
+
def _is_key_encrypted(self, key_path: str) -> bool:
|
|
99
|
+
"""Check if an SSH private key is encrypted with a passphrase."""
|
|
100
|
+
try:
|
|
101
|
+
with open(key_path, 'r') as f:
|
|
102
|
+
content = f.read(500) # Read first 500 bytes
|
|
103
|
+
# OpenSSH encrypted keys contain these markers
|
|
104
|
+
return 'aes256-ctr' in content or 'aes128-ctr' in content or 'bcrypt' in content
|
|
105
|
+
except Exception:
|
|
106
|
+
return False
|
|
107
|
+
|
|
108
|
+
def _is_key_in_agent(self, key_path: str) -> bool:
    """Check if the SSH key is loaded in the ssh-agent.

    Lists the agent's identities with ``ssh-add -l``, computes the
    fingerprint of *key_path* with ``ssh-keygen -lf`` and reports whether
    that fingerprint appears in the agent listing.

    Args:
        key_path: Path to the key file whose fingerprint is looked up.

    Returns:
        True if the fingerprint is found in the agent listing. False if it
        is not, if either command exits non-zero, or on any error
        (including a timeout or a missing executable).
    """
    try:
        agent_listing = subprocess.run(
            ["ssh-add", "-l"],
            capture_output=True,
            text=True,
            timeout=5
        )
        if agent_listing.returncode != 0:
            return False

        keygen = subprocess.run(
            ["ssh-keygen", "-lf", key_path],
            capture_output=True,
            text=True,
            timeout=5
        )
        if keygen.returncode != 0:
            return False

        # The fingerprint (e.g., SHA256:xxx) is the second
        # whitespace-separated field of the ssh-keygen output.
        fields = keygen.stdout.split()
        if len(fields) < 2:
            return False
        return fields[1] in agent_listing.stdout
    except Exception:
        # Best-effort probe: any failure means "not known to be in the agent".
        return False
|
|
136
|
+
|
|
137
|
+
def _get_ssh_setup_instructions(self) -> str:
|
|
138
|
+
"""Get provider-specific SSH setup instructions."""
|
|
139
|
+
provider_instructions = {
|
|
140
|
+
"runpod": (
|
|
141
|
+
"SSH authentication failed. Please add your SSH public key to RunPod:\n"
|
|
142
|
+
"1. Visit https://www.runpod.io/console/user/settings\n"
|
|
143
|
+
"2. Go to 'SSH Public Keys' section\n"
|
|
144
|
+
"3. Add your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
|
|
145
|
+
"4. Try connecting again"
|
|
146
|
+
),
|
|
147
|
+
"lambda_labs": (
|
|
148
|
+
"SSH authentication failed. Lambda Labs SSH key mismatch.\n\n"
|
|
149
|
+
"IMPORTANT: Lambda Labs assigns SSH keys at instance creation time.\n"
|
|
150
|
+
"Your current instance may have been created with a different key.\n\n"
|
|
151
|
+
"To fix this:\n"
|
|
152
|
+
"1. Terminate the current instance\n"
|
|
153
|
+
"2. Visit https://cloud.lambdalabs.com/ssh-keys\n"
|
|
154
|
+
"3. Add your public key (~/.ssh/id_ed25519.pub) if not already there\n"
|
|
155
|
+
"4. Create a new instance - it will use your registered key\n\n"
|
|
156
|
+
"To view your public key, run: cat ~/.ssh/id_ed25519.pub"
|
|
157
|
+
),
|
|
158
|
+
"vastai": (
|
|
159
|
+
"SSH authentication failed. Please add your SSH public key to Vast.ai:\n"
|
|
160
|
+
"1. Visit https://cloud.vast.ai/account/\n"
|
|
161
|
+
"2. Go to 'SSH Keys' section\n"
|
|
162
|
+
"3. Add your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
|
|
163
|
+
"4. Try connecting again"
|
|
164
|
+
),
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
return provider_instructions.get(
|
|
168
|
+
self.provider_type,
|
|
169
|
+
(
|
|
170
|
+
"SSH authentication failed. Please ensure your SSH public key is added to your GPU provider.\n"
|
|
171
|
+
"Upload your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub) to your provider's dashboard.\n"
|
|
172
|
+
"Then try connecting again."
|
|
173
|
+
)
|
|
174
|
+
)
|
|
175
|
+
|
|
88
176
|
async def connect_to_pod(self, pod_id:str) -> dict[str, object]:
|
|
89
177
|
"""
|
|
90
178
|
Connects to existing pod and set up ssh tunnel
|
|
@@ -94,8 +182,6 @@ class PodKernelManager:
|
|
|
94
182
|
Response:
|
|
95
183
|
dict with connection status
|
|
96
184
|
"""
|
|
97
|
-
import sys
|
|
98
|
-
|
|
99
185
|
# Check if already connected to this pod
|
|
100
186
|
if self.pod and self.pod.id == pod_id:
|
|
101
187
|
# Check if tunnel is still alive
|
|
@@ -105,18 +191,13 @@ class PodKernelManager:
|
|
|
105
191
|
"message": f"Already connected to pod {pod_id}"
|
|
106
192
|
}
|
|
107
193
|
# Tunnel died, clean up and reconnect
|
|
108
|
-
print(f"[POD MANAGER] Existing tunnel dead, reconnecting...", file=sys.stderr, flush=True)
|
|
109
194
|
await self.disconnect()
|
|
110
195
|
|
|
111
196
|
# If connected to different pod, disconnect first
|
|
112
197
|
if self.pod and self.pod.id != pod_id:
|
|
113
|
-
print(f"[POD MANAGER] Disconnecting from {self.pod.id} to connect to {pod_id}", file=sys.stderr, flush=True)
|
|
114
198
|
await self.disconnect()
|
|
115
199
|
|
|
116
|
-
self.pod = await self.
|
|
117
|
-
|
|
118
|
-
print(f"[POD MANAGER] Pod status: {self.pod.status}", file=sys.stderr, flush=True)
|
|
119
|
-
print(f"[POD MANAGER] SSH connection: {self.pod.sshConnection}", file=sys.stderr, flush=True)
|
|
200
|
+
self.pod = await self.provider_service.get_pod(pod_id)
|
|
120
201
|
|
|
121
202
|
if not self.pod.sshConnection:
|
|
122
203
|
return{
|
|
@@ -132,7 +213,7 @@ class PodKernelManager:
|
|
|
132
213
|
}
|
|
133
214
|
|
|
134
215
|
# Parse SSH connection string
|
|
135
|
-
# Format can be: "ssh root@ip -p port" OR "root@ip -p port"
|
|
216
|
+
# Format can be: "ssh root@ip -p port" OR "root@ip -p port" OR "ssh ubuntu@ip"
|
|
136
217
|
ssh_parts = self.pod.sshConnection.split()
|
|
137
218
|
|
|
138
219
|
# Find the part containing @ (user@host)
|
|
@@ -148,8 +229,8 @@ class PodKernelManager:
|
|
|
148
229
|
"message": f"Invalid SSH connection format (no user@host found): {self.pod.sshConnection}"
|
|
149
230
|
}
|
|
150
231
|
|
|
151
|
-
# Extract host from user@host
|
|
152
|
-
ssh_host = host_part.split("@")
|
|
232
|
+
# Extract user and host from user@host
|
|
233
|
+
ssh_user, ssh_host = host_part.split("@")
|
|
153
234
|
ssh_port = "22"
|
|
154
235
|
|
|
155
236
|
# Extract port if specified with -p flag
|
|
@@ -158,33 +239,22 @@ class PodKernelManager:
|
|
|
158
239
|
if port_idx + 1 < len(ssh_parts):
|
|
159
240
|
ssh_port = ssh_parts[port_idx + 1]
|
|
160
241
|
|
|
161
|
-
print(f"[POD MANAGER] Parsed SSH host: {ssh_host}, port: {ssh_port}", file=sys.stderr, flush=True)
|
|
162
|
-
|
|
163
242
|
#deploy worker code to pod
|
|
164
|
-
|
|
165
|
-
deploy_result = await self._deploy_worker(ssh_host, ssh_port)
|
|
166
|
-
print(f"[POD MANAGER] Deploy result: {deploy_result}", file=sys.stderr, flush=True)
|
|
243
|
+
deploy_result = await self._deploy_worker(ssh_user, ssh_host, ssh_port)
|
|
167
244
|
if deploy_result.get("status") == "error":
|
|
168
245
|
return deploy_result
|
|
169
246
|
|
|
170
247
|
#create ssh tunnel for ZMQ ports
|
|
171
|
-
|
|
172
|
-
tunnel_result = await self._create_ssh_tunnel(ssh_host, ssh_port)
|
|
173
|
-
print(f"[POD MANAGER] Tunnel result: {tunnel_result}", file=sys.stderr, flush=True)
|
|
248
|
+
tunnel_result = await self._create_ssh_tunnel(ssh_user, ssh_host, ssh_port)
|
|
174
249
|
if tunnel_result.get("status") == "error":
|
|
175
250
|
return tunnel_result
|
|
176
251
|
|
|
177
252
|
#start remote worker
|
|
178
|
-
worker_result = await self._start_remote_worker(ssh_host, ssh_port)
|
|
253
|
+
worker_result = await self._start_remote_worker(ssh_user, ssh_host, ssh_port)
|
|
179
254
|
if worker_result.get("status") == "error":
|
|
180
255
|
await self.disconnect()
|
|
181
256
|
return worker_result
|
|
182
257
|
|
|
183
|
-
# Note: Worker may take a few seconds to start and install matplotlib
|
|
184
|
-
# The connection should work even if verification fails
|
|
185
|
-
print(f"[POD MANAGER] Remote worker is starting (matplotlib install may take a few seconds)", file=sys.stderr, flush=True)
|
|
186
|
-
print(f"[POD MANAGER] Connection established - try running code in ~5 seconds", file=sys.stderr, flush=True)
|
|
187
|
-
|
|
188
258
|
return {
|
|
189
259
|
"status": "ok",
|
|
190
260
|
"message": f"Connected to pod {pod_id}",
|
|
@@ -196,11 +266,12 @@ class PodKernelManager:
|
|
|
196
266
|
}
|
|
197
267
|
}
|
|
198
268
|
|
|
199
|
-
async def _deploy_worker(self, ssh_host: str, ssh_port: str) -> dict[str,object]:
|
|
269
|
+
async def _deploy_worker(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str,object]:
|
|
200
270
|
"""
|
|
201
271
|
Deploy worker code to remote pod via Secure Copy Protocol.
|
|
202
272
|
|
|
203
273
|
args:
|
|
274
|
+
ssh_user: SSH username (e.g., 'root' or 'ubuntu')
|
|
204
275
|
ssh_host: SSH host address
|
|
205
276
|
ssh_port: SSH port
|
|
206
277
|
|
|
@@ -208,6 +279,21 @@ class PodKernelManager:
|
|
|
208
279
|
dict with deployment status
|
|
209
280
|
"""
|
|
210
281
|
try:
|
|
282
|
+
# Check if SSH key is encrypted and not in agent
|
|
283
|
+
ssh_key = self._get_ssh_key()
|
|
284
|
+
if ssh_key and self._is_key_encrypted(ssh_key):
|
|
285
|
+
if not self._is_key_in_agent(ssh_key):
|
|
286
|
+
key_name = os.path.basename(ssh_key)
|
|
287
|
+
return {
|
|
288
|
+
"status": "error",
|
|
289
|
+
"message": (
|
|
290
|
+
f"Your SSH key ({key_name}) is protected with a passphrase but not loaded in ssh-agent.\n\n"
|
|
291
|
+
f"To fix this, run:\n"
|
|
292
|
+
f" ssh-add {ssh_key}\n\n"
|
|
293
|
+
f"Enter your passphrase when prompted, then try connecting again."
|
|
294
|
+
)
|
|
295
|
+
}
|
|
296
|
+
|
|
211
297
|
# Create temporary tarball of morecompute package
|
|
212
298
|
project_root = Path(__file__).parent.parent.parent
|
|
213
299
|
morecompute_dir = project_root / "morecompute"
|
|
@@ -231,7 +317,7 @@ class PodKernelManager:
|
|
|
231
317
|
"-o", "BatchMode=yes", # Prevent password prompts, fail fast if key auth doesn't work
|
|
232
318
|
"-o", "ConnectTimeout=10",
|
|
233
319
|
tmp_path,
|
|
234
|
-
f"
|
|
320
|
+
f"{ssh_user}@{ssh_host}:/tmp/morecompute.tar.gz"
|
|
235
321
|
])
|
|
236
322
|
|
|
237
323
|
result = subprocess.run(
|
|
@@ -244,14 +330,11 @@ class PodKernelManager:
|
|
|
244
330
|
if result.returncode != 0:
|
|
245
331
|
error_msg = result.stderr.lower()
|
|
246
332
|
if "permission denied" in error_msg or "publickey" in error_msg:
|
|
333
|
+
# Get provider-specific SSH setup instructions
|
|
334
|
+
ssh_help = self._get_ssh_setup_instructions()
|
|
247
335
|
return {
|
|
248
336
|
"status": "error",
|
|
249
|
-
"message":
|
|
250
|
-
"SSH authentication failed. Please add your SSH public key to Prime Intellect:\n"
|
|
251
|
-
"1. Visit https://app.primeintellect.ai/dashboard/tokens\n"
|
|
252
|
-
"2. Upload your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
|
|
253
|
-
"3. Try connecting again"
|
|
254
|
-
)
|
|
337
|
+
"message": ssh_help
|
|
255
338
|
}
|
|
256
339
|
elif "host key verification failed" in error_msg:
|
|
257
340
|
return {
|
|
@@ -265,6 +348,9 @@ class PodKernelManager:
|
|
|
265
348
|
}
|
|
266
349
|
|
|
267
350
|
# Extract on remote and install dependencies
|
|
351
|
+
# Use sudo for non-root users to run pip install
|
|
352
|
+
pip_cmd = "pip install --quiet pyzmq matplotlib" if ssh_user == "root" else "sudo pip install --quiet pyzmq matplotlib"
|
|
353
|
+
|
|
268
354
|
ssh_cmd = ["ssh", "-p", ssh_port]
|
|
269
355
|
|
|
270
356
|
if ssh_key:
|
|
@@ -275,11 +361,11 @@ class PodKernelManager:
|
|
|
275
361
|
"-o", "UserKnownHostsFile=/dev/null",
|
|
276
362
|
"-o", "BatchMode=yes",
|
|
277
363
|
"-o", "ConnectTimeout=10",
|
|
278
|
-
f"
|
|
364
|
+
f"{ssh_user}@{ssh_host}",
|
|
279
365
|
(
|
|
280
366
|
"cd /tmp && "
|
|
281
367
|
"tar -xzf morecompute.tar.gz && "
|
|
282
|
-
"
|
|
368
|
+
f"{pip_cmd} && "
|
|
283
369
|
"echo 'Deployment complete'"
|
|
284
370
|
)
|
|
285
371
|
])
|
|
@@ -308,11 +394,12 @@ class PodKernelManager:
|
|
|
308
394
|
"message": f"Deployment error: {str(e)}"
|
|
309
395
|
}
|
|
310
396
|
|
|
311
|
-
async def _create_ssh_tunnel(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
|
|
397
|
+
async def _create_ssh_tunnel(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str, object]:
|
|
312
398
|
"""
|
|
313
399
|
Create SSH tunnel for ZMQ ports.
|
|
314
400
|
|
|
315
401
|
args:
|
|
402
|
+
ssh_user: SSH username (e.g., 'root' or 'ubuntu')
|
|
316
403
|
ssh_host: SSH host address
|
|
317
404
|
ssh_port: SSH port
|
|
318
405
|
|
|
@@ -336,7 +423,7 @@ class PodKernelManager:
|
|
|
336
423
|
"-N", # No command execution
|
|
337
424
|
"-L", f"{self.local_cmd_port}:localhost:{self.remote_cmd_port}",
|
|
338
425
|
"-L", f"{self.local_pub_port}:localhost:{self.remote_pub_port}",
|
|
339
|
-
f"
|
|
426
|
+
f"{ssh_user}@{ssh_host}"
|
|
340
427
|
])
|
|
341
428
|
|
|
342
429
|
self.ssh_tunnel_proc = subprocess.Popen(
|
|
@@ -369,17 +456,17 @@ class PodKernelManager:
|
|
|
369
456
|
}
|
|
370
457
|
|
|
371
458
|
except Exception as e:
|
|
372
|
-
print(f"[POD MANAGER] Exception creating tunnel: {e}", file=sys.stderr, flush=True)
|
|
373
459
|
return {
|
|
374
460
|
"status": "error",
|
|
375
461
|
"message": f"Tunnel creation error: {str(e)}"
|
|
376
462
|
}
|
|
377
463
|
|
|
378
|
-
async def _start_remote_worker(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
|
|
464
|
+
async def _start_remote_worker(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str, object]:
|
|
379
465
|
"""
|
|
380
466
|
Start ZMQ worker on remote pod.
|
|
381
467
|
|
|
382
468
|
args:
|
|
469
|
+
ssh_user: SSH username (e.g., 'root' or 'ubuntu')
|
|
383
470
|
ssh_host: SSH host address
|
|
384
471
|
ssh_port: SSH port
|
|
385
472
|
|
|
@@ -387,31 +474,46 @@ class PodKernelManager:
|
|
|
387
474
|
dict with worker start status
|
|
388
475
|
"""
|
|
389
476
|
try:
|
|
390
|
-
print(f"[POD MANAGER] Starting remote worker on {ssh_host}:{ssh_port}", file=sys.stderr, flush=True)
|
|
391
|
-
|
|
392
477
|
# Start worker in background on remote pod
|
|
393
478
|
# Use 'python3' instead of sys.executable since remote pod may have different Python path
|
|
479
|
+
# Use sudo for non-root users
|
|
480
|
+
python_cmd = "python3" if ssh_user == "root" else "sudo python3"
|
|
481
|
+
|
|
394
482
|
ssh_key = self._get_ssh_key()
|
|
395
483
|
worker_cmd = ["ssh", "-p", ssh_port]
|
|
396
484
|
|
|
397
485
|
if ssh_key:
|
|
398
486
|
worker_cmd.extend(["-i", ssh_key])
|
|
399
487
|
|
|
488
|
+
# Build the command - for non-root users, we need to pass env vars through sudo
|
|
489
|
+
if ssh_user == "root":
|
|
490
|
+
remote_cmd = (
|
|
491
|
+
f"cd /tmp && "
|
|
492
|
+
f"MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
|
|
493
|
+
f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
|
|
494
|
+
f"setsid python3 -u /tmp/morecompute/execution/worker.py "
|
|
495
|
+
f"</dev/null >/tmp/worker.log 2>&1 & "
|
|
496
|
+
f"echo $!"
|
|
497
|
+
)
|
|
498
|
+
else:
|
|
499
|
+
# For non-root, use sudo with env vars passed via sudo's env mechanism
|
|
500
|
+
remote_cmd = (
|
|
501
|
+
f"cd /tmp && "
|
|
502
|
+
f"sudo MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
|
|
503
|
+
f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
|
|
504
|
+
f"setsid python3 -u /tmp/morecompute/execution/worker.py "
|
|
505
|
+
f"</dev/null >/tmp/worker.log 2>&1 & "
|
|
506
|
+
f"echo $!"
|
|
507
|
+
)
|
|
508
|
+
|
|
400
509
|
worker_cmd.extend([
|
|
401
510
|
"-o", "StrictHostKeyChecking=no",
|
|
402
511
|
"-o", "UserKnownHostsFile=/dev/null",
|
|
403
512
|
"-o", "BatchMode=yes",
|
|
404
513
|
"-o", "ConnectTimeout=10",
|
|
405
|
-
f"
|
|
514
|
+
f"{ssh_user}@{ssh_host}",
|
|
406
515
|
"sh", "-c",
|
|
407
|
-
|
|
408
|
-
f"'cd /tmp && "
|
|
409
|
-
f"MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
|
|
410
|
-
f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
|
|
411
|
-
f"setsid python3 -u /tmp/morecompute/execution/worker.py "
|
|
412
|
-
f"</dev/null >/tmp/worker.log 2>&1 & "
|
|
413
|
-
f"echo $!'"
|
|
414
|
-
)
|
|
516
|
+
f"'{remote_cmd}'"
|
|
415
517
|
])
|
|
416
518
|
|
|
417
519
|
result = subprocess.run(
|
|
@@ -428,13 +530,10 @@ class PodKernelManager:
|
|
|
428
530
|
}
|
|
429
531
|
|
|
430
532
|
remote_pid = result.stdout.strip()
|
|
431
|
-
print(f"[POD MANAGER] Remote worker PID: {remote_pid}", file=sys.stderr, flush=True)
|
|
432
533
|
|
|
433
534
|
# Wait for worker to be ready
|
|
434
535
|
await asyncio.sleep(2)
|
|
435
536
|
|
|
436
|
-
print(f"[POD MANAGER] Remote worker should be ready now", file=sys.stderr, flush=True)
|
|
437
|
-
|
|
438
537
|
return {
|
|
439
538
|
"status": "ok",
|
|
440
539
|
"message": "Remote worker started",
|
|
@@ -580,7 +679,7 @@ class PodKernelManager:
|
|
|
580
679
|
|
|
581
680
|
# Get updated pod info
|
|
582
681
|
try:
|
|
583
|
-
updated_pod = await self.
|
|
682
|
+
updated_pod = await self.provider_service.get_pod(pod.id)
|
|
584
683
|
pod_status = updated_pod.status
|
|
585
684
|
except Exception:
|
|
586
685
|
pod_status = "unknown"
|
|
@@ -594,12 +693,14 @@ class PodKernelManager:
|
|
|
594
693
|
"gpu_type": pod.gpuName,
|
|
595
694
|
"gpu_count": pod.gpuCount,
|
|
596
695
|
"price_hr": pod.priceHr,
|
|
597
|
-
"ssh_connection": pod.sshConnection
|
|
696
|
+
"ssh_connection": pod.sshConnection,
|
|
697
|
+
"provider": self.provider_type
|
|
598
698
|
},
|
|
599
699
|
"tunnel": {
|
|
600
700
|
"alive": tunnel_alive,
|
|
601
701
|
"local_cmd_port": self.local_cmd_port,
|
|
602
702
|
"local_pub_port": self.local_pub_port
|
|
603
703
|
},
|
|
604
|
-
"executor_attached": self.executor is not None
|
|
704
|
+
"executor_attached": self.executor is not None,
|
|
705
|
+
"provider": self.provider_type
|
|
605
706
|
}
|
|
@@ -1,36 +1,55 @@
|
|
|
1
1
|
"""Service for monitoring GPU pod status updates."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
import
|
|
5
|
-
from typing import Callable, Awaitable
|
|
4
|
+
from typing import Callable, Awaitable, Union
|
|
6
5
|
from cachetools import TTLCache
|
|
7
6
|
|
|
8
7
|
from .prime_intellect import PrimeIntellectService
|
|
8
|
+
from .providers.base_provider import BaseGPUProvider
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
PodUpdateCallback = Callable[[dict], Awaitable[None]]
|
|
12
12
|
|
|
13
|
+
# Type alias for supported provider services
|
|
14
|
+
ProviderService = Union[PrimeIntellectService, BaseGPUProvider]
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
class PodMonitor:
|
|
15
|
-
"""Monitors GPU pod status and broadcasts updates.
|
|
18
|
+
"""Monitors GPU pod status and broadcasts updates.
|
|
19
|
+
|
|
20
|
+
Supports monitoring pods from any GPU provider that implements
|
|
21
|
+
the BaseGPUProvider interface.
|
|
22
|
+
"""
|
|
16
23
|
|
|
17
24
|
POLL_INTERVAL_SECONDS = 5
|
|
18
25
|
|
|
19
26
|
def __init__(
|
|
20
27
|
self,
|
|
21
|
-
prime_intellect: PrimeIntellectService,
|
|
22
28
|
pod_cache: TTLCache,
|
|
23
|
-
update_callback: PodUpdateCallback
|
|
29
|
+
update_callback: PodUpdateCallback,
|
|
30
|
+
prime_intellect: PrimeIntellectService | None = None,
|
|
31
|
+
provider_service: BaseGPUProvider | None = None,
|
|
24
32
|
):
|
|
25
33
|
"""
|
|
26
34
|
Initialize pod monitor.
|
|
27
35
|
|
|
28
36
|
Args:
|
|
29
|
-
prime_intellect: Prime Intellect API service
|
|
30
37
|
pod_cache: Cache to clear on updates
|
|
31
38
|
update_callback: Async callback for broadcasting updates
|
|
39
|
+
prime_intellect: Legacy Prime Intellect API service (deprecated, use provider_service)
|
|
40
|
+
provider_service: GPU provider service implementing BaseGPUProvider
|
|
32
41
|
"""
|
|
33
|
-
|
|
42
|
+
# Support both old and new interface
|
|
43
|
+
if provider_service is not None:
|
|
44
|
+
self.provider = provider_service
|
|
45
|
+
self.provider_name = provider_service.PROVIDER_NAME
|
|
46
|
+
elif prime_intellect is not None:
|
|
47
|
+
# Backwards compatibility
|
|
48
|
+
self.provider = prime_intellect
|
|
49
|
+
self.provider_name = "prime_intellect"
|
|
50
|
+
else:
|
|
51
|
+
raise ValueError("Either prime_intellect or provider_service must be provided")
|
|
52
|
+
|
|
34
53
|
self.pod_cache = pod_cache
|
|
35
54
|
self.update_callback = update_callback
|
|
36
55
|
self.monitoring_tasks: dict[str, asyncio.Task] = {}
|
|
@@ -44,12 +63,10 @@ class PodMonitor:
|
|
|
44
63
|
"""
|
|
45
64
|
# Don't start duplicate monitors
|
|
46
65
|
if pod_id in self.monitoring_tasks:
|
|
47
|
-
print(f"[POD MONITOR] Already monitoring pod {pod_id}", file=sys.stderr, flush=True)
|
|
48
66
|
return
|
|
49
67
|
|
|
50
68
|
task = asyncio.create_task(self._monitor_loop(pod_id))
|
|
51
69
|
self.monitoring_tasks[pod_id] = task
|
|
52
|
-
print(f"[POD MONITOR] Started monitoring pod {pod_id}", file=sys.stderr, flush=True)
|
|
53
70
|
|
|
54
71
|
async def stop_monitoring(self, pod_id: str) -> None:
|
|
55
72
|
"""
|
|
@@ -65,7 +82,26 @@ class PodMonitor:
|
|
|
65
82
|
await task
|
|
66
83
|
except asyncio.CancelledError:
|
|
67
84
|
pass
|
|
68
|
-
|
|
85
|
+
|
|
86
|
+
def _normalize_status(self, status: str) -> str:
|
|
87
|
+
"""Normalize status across different providers."""
|
|
88
|
+
# Common status normalization
|
|
89
|
+
status_map = {
|
|
90
|
+
# Common statuses
|
|
91
|
+
"running": "ACTIVE",
|
|
92
|
+
"active": "ACTIVE",
|
|
93
|
+
"ready": "ACTIVE",
|
|
94
|
+
"starting": "STARTING",
|
|
95
|
+
"pending": "PENDING",
|
|
96
|
+
"stopped": "STOPPED",
|
|
97
|
+
"terminated": "TERMINATED",
|
|
98
|
+
"error": "ERROR",
|
|
99
|
+
# Provider-specific
|
|
100
|
+
"exited": "TERMINATED",
|
|
101
|
+
"loading": "STARTING",
|
|
102
|
+
"booting": "STARTING",
|
|
103
|
+
}
|
|
104
|
+
return status_map.get(status.lower(), status.upper())
|
|
69
105
|
|
|
70
106
|
async def _monitor_loop(self, pod_id: str) -> None:
|
|
71
107
|
"""
|
|
@@ -78,61 +114,53 @@ class PodMonitor:
|
|
|
78
114
|
while True:
|
|
79
115
|
try:
|
|
80
116
|
# Fetch current pod status
|
|
81
|
-
pod = await self.
|
|
117
|
+
pod = await self.provider.get_pod(pod_id)
|
|
82
118
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
file=sys.stderr,
|
|
86
|
-
flush=True
|
|
87
|
-
)
|
|
119
|
+
# Normalize the status
|
|
120
|
+
normalized_status = self._normalize_status(pod.status)
|
|
88
121
|
|
|
89
122
|
# Clear cache to force fresh data
|
|
90
123
|
self.pod_cache.clear()
|
|
91
124
|
|
|
92
|
-
# Broadcast update
|
|
125
|
+
# Broadcast update with provider info
|
|
93
126
|
await self.update_callback({
|
|
94
127
|
"type": "pod_status_update",
|
|
95
128
|
"data": {
|
|
96
129
|
"pod_id": pod_id,
|
|
97
130
|
"name": pod.name,
|
|
98
|
-
"status":
|
|
131
|
+
"status": normalized_status,
|
|
99
132
|
"ssh_connection": pod.sshConnection,
|
|
100
133
|
"ip": pod.ip,
|
|
101
134
|
"gpu_name": pod.gpuName,
|
|
102
|
-
"
|
|
135
|
+
"gpu_count": pod.gpuCount,
|
|
136
|
+
"price_hr": pod.priceHr,
|
|
137
|
+
"provider": self.provider_name
|
|
103
138
|
}
|
|
104
139
|
})
|
|
105
140
|
|
|
106
141
|
# Stop monitoring if ERROR or TERMINATED
|
|
107
|
-
if
|
|
108
|
-
print(
|
|
109
|
-
f"[POD MONITOR] Pod {pod_id} reached terminal state: {pod.status}",
|
|
110
|
-
file=sys.stderr,
|
|
111
|
-
flush=True
|
|
112
|
-
)
|
|
142
|
+
if normalized_status in {"ERROR", "TERMINATED"}:
|
|
113
143
|
break
|
|
114
144
|
|
|
115
145
|
# If ACTIVE and has SSH connection, pod is fully ready - stop monitoring
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
flush=True
|
|
121
|
-
)
|
|
122
|
-
break
|
|
146
|
+
# Note: Modal doesn't support SSH, so we just check for ACTIVE
|
|
147
|
+
if normalized_status == "ACTIVE":
|
|
148
|
+
if pod.sshConnection or self.provider_name == "modal":
|
|
149
|
+
break
|
|
123
150
|
|
|
124
151
|
# Wait before next check
|
|
125
152
|
await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
|
|
126
153
|
|
|
127
|
-
except Exception
|
|
128
|
-
print(
|
|
129
|
-
f"[POD MONITOR] Error checking pod {pod_id}: {e}",
|
|
130
|
-
file=sys.stderr,
|
|
131
|
-
flush=True
|
|
132
|
-
)
|
|
154
|
+
except Exception:
|
|
133
155
|
await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
|
|
134
156
|
|
|
135
157
|
finally:
|
|
136
158
|
# Clean up
|
|
137
159
|
self.monitoring_tasks.pop(pod_id, None)
|
|
138
|
-
|
|
160
|
+
|
|
161
|
+
def stop_all(self) -> None:
    """Stop monitoring all pods.

    Empties ``self.monitoring_tasks`` and cancels every task that has not
    already finished. Cancelled tasks are not awaited here.
    """
    # Snapshot first so the registry can be emptied before cancelling.
    tracked = list(self.monitoring_tasks.values())
    self.monitoring_tasks.clear()
    for task in tracked:
        if task and not task.done():
            task.cancel()
|
|
@@ -97,9 +97,7 @@ class PrimeIntellectService:
|
|
|
97
97
|
"""
|
|
98
98
|
Create a new pod
|
|
99
99
|
"""
|
|
100
|
-
import sys
|
|
101
100
|
payload = pod_request.model_dump(exclude_none=True)
|
|
102
|
-
print(f"[PI SERVICE] Creating pod with payload: {payload}", file=sys.stderr, flush=True)
|
|
103
101
|
|
|
104
102
|
response = await self._make_request(
|
|
105
103
|
"POST",
|
|
@@ -294,9 +292,7 @@ class PrimeIntellectService:
|
|
|
294
292
|
Disk response with disk details
|
|
295
293
|
|
|
296
294
|
"""
|
|
297
|
-
import sys
|
|
298
295
|
payload = disk_request.model_dump(exclude_none=True)
|
|
299
|
-
print(f"[PI SERVICE] Creating disk with payload: + {payload}", file=sys.stderr, flush=True)
|
|
300
296
|
response = await self._make_request("POST", "/disks/", json_data=payload)
|
|
301
297
|
return DiskResponse.model_validate(response)
|
|
302
298
|
|