more-compute 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. frontend/app/globals.css +734 -27
  2. frontend/app/layout.tsx +13 -3
  3. frontend/components/Notebook.tsx +2 -14
  4. frontend/components/cell/MonacoCell.tsx +99 -5
  5. frontend/components/layout/Sidebar.tsx +39 -4
  6. frontend/components/panels/ClaudePanel.tsx +461 -0
  7. frontend/components/popups/ComputePopup.tsx +738 -447
  8. frontend/components/popups/FilterPopup.tsx +305 -189
  9. frontend/components/popups/MetricsPopup.tsx +20 -1
  10. frontend/components/popups/ProviderConfigModal.tsx +322 -0
  11. frontend/components/popups/ProviderDropdown.tsx +398 -0
  12. frontend/components/popups/SettingsPopup.tsx +1 -1
  13. frontend/contexts/ClaudeContext.tsx +392 -0
  14. frontend/contexts/PodWebSocketContext.tsx +16 -21
  15. frontend/hooks/useInlineDiff.ts +269 -0
  16. frontend/lib/api.ts +323 -12
  17. frontend/lib/settings.ts +5 -0
  18. frontend/lib/websocket-native.ts +4 -8
  19. frontend/lib/websocket.ts +1 -2
  20. frontend/package-lock.json +733 -36
  21. frontend/package.json +2 -0
  22. frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
  23. frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
  24. frontend/public/assets/icons/providers/runpod.svg +9 -0
  25. frontend/public/assets/icons/providers/vastai.svg +1 -0
  26. frontend/settings.md +54 -0
  27. frontend/tsconfig.tsbuildinfo +1 -0
  28. frontend/types/claude.ts +194 -0
  29. kernel_run.py +13 -0
  30. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
  31. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
  32. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
  33. morecompute/__init__.py +1 -1
  34. morecompute/__version__.py +1 -1
  35. morecompute/execution/executor.py +24 -67
  36. morecompute/execution/worker.py +6 -72
  37. morecompute/models/api_models.py +62 -0
  38. morecompute/notebook.py +11 -0
  39. morecompute/server.py +641 -133
  40. morecompute/services/claude_service.py +392 -0
  41. morecompute/services/pod_manager.py +168 -67
  42. morecompute/services/pod_monitor.py +67 -39
  43. morecompute/services/prime_intellect.py +0 -4
  44. morecompute/services/providers/__init__.py +92 -0
  45. morecompute/services/providers/base_provider.py +336 -0
  46. morecompute/services/providers/lambda_labs_provider.py +394 -0
  47. morecompute/services/providers/provider_factory.py +194 -0
  48. morecompute/services/providers/runpod_provider.py +504 -0
  49. morecompute/services/providers/vastai_provider.py +407 -0
  50. morecompute/utils/cell_magics.py +0 -3
  51. morecompute/utils/config_util.py +93 -3
  52. morecompute/utils/special_commands.py +5 -32
  53. morecompute/utils/version_check.py +117 -0
  54. frontend/styling_README.md +0 -23
  55. {more_compute-0.4.4.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
  56. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
  57. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
@@ -1,35 +1,44 @@
1
1
  import asyncio
2
2
  import subprocess
3
3
  import os
4
- import sys
5
4
  import tempfile
6
5
  import tarfile
7
6
  from pathlib import Path
8
7
  from typing import TYPE_CHECKING
9
8
 
10
- from .prime_intellect import PrimeIntellectService, PodResponse
9
+ from .providers.base_provider import BaseGPUProvider
10
+ from ..models.api_models import PodResponse
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from ..execution.executor import NextZmqExecutor
14
14
 
15
+ # Type alias for supported provider services; every provider service is a BaseGPUProvider implementation
16
+ ProviderService = BaseGPUProvider
17
+
18
+
15
19
  class PodKernelManager:
16
20
  """
17
- Manages remote GPU pod connections (currently PI as provider, hope to provide other providers in the future)
18
- and SSH tunnels for ZMQ execution
21
+ Manages remote GPU pod connections and SSH tunnels for ZMQ execution.
22
+
23
+ Supports multiple GPU providers:
24
+ - RunPod
25
+ - Lambda Labs
26
+ - Vast.ai
19
27
  """
20
- pi_service: PrimeIntellectService
28
+ provider_service: ProviderService
29
+ provider_type: str
21
30
  pod: PodResponse | None
22
31
  ssh_tunnel_proc: subprocess.Popen[bytes] | None
23
32
  local_cmd_port: int
24
33
  local_pub_port: int
25
- remote_cmd_port : int
34
+ remote_cmd_port: int
26
35
  remote_pub_port: int
27
36
  executor: "NextZmqExecutor | None"
28
37
  _ssh_key_cache: str | None
29
38
 
30
39
  def __init__(
31
40
  self,
32
- pi_service: PrimeIntellectService,
41
+ provider_service: ProviderService,
33
42
  local_cmd_port: int = 15555,
34
43
  local_pub_port: int = 15556,
35
44
  remote_cmd_port: int = 5555,
@@ -39,13 +48,15 @@ class PodKernelManager:
39
48
  Initialize pod manager
40
49
 
41
50
  args:
42
- pi_service : Prime Intellect API service
51
+ provider_service: GPU provider service implementing BaseGPUProvider
43
52
  local_cmd_port: Local port for REQ/REP tunnel
44
53
  local_pub_port: Local port for PUB/SUB tunnel
45
54
  remote_cmd_port: Remote port for REQ/REP socket
46
55
  remote_pub_port: Remote port for PUB/SUB socket
47
56
  """
48
- self.pi_service = pi_service
57
+ self.provider_service = provider_service
58
+ self.provider_type = getattr(provider_service, 'PROVIDER_NAME', 'unknown')
59
+
49
60
  self.pod = None
50
61
  self.ssh_tunnel_proc = None
51
62
  self.local_cmd_port = local_cmd_port
@@ -71,9 +82,8 @@ class PodKernelManager:
71
82
  self._ssh_key_cache = expanded
72
83
  return expanded
73
84
 
74
- # Try common SSH key paths (including Prime Intellect's recommended name)
85
+ # Try common SSH key paths
75
86
  common_keys = [
76
- "~/.ssh/primeintellect_ed25519", # Prime Intellect's recommended name
77
87
  "~/.ssh/id_ed25519",
78
88
  "~/.ssh/id_rsa",
79
89
  ]
@@ -85,6 +95,84 @@ class PodKernelManager:
85
95
 
86
96
  return None
87
97
 
98
def _is_key_encrypted(self, key_path: str) -> bool:
    """Return True if the SSH private key at ``key_path`` is passphrase-protected.

    Recognised on-disk formats:

    * OpenSSH (``-----BEGIN OPENSSH PRIVATE KEY-----``): the body is base64,
      so the cipher name is NOT visible as plain text. The decoded blob starts
      with ``openssh-key-v1\\0`` followed by a length-prefixed cipher name,
      which is ``none`` for an unencrypted key.
    * Legacy PEM (``-----BEGIN RSA/EC/DSA PRIVATE KEY-----``): encrypted keys
      carry a ``Proc-Type: 4,ENCRYPTED`` header.
    * PKCS#8: encrypted keys use ``-----BEGIN ENCRYPTED PRIVATE KEY-----``.

    This is a best-effort check. An unreadable file or an unrecognised or
    malformed body is reported as "not encrypted", so the caller simply
    proceeds with the normal connection attempt.

    Args:
        key_path: Filesystem path of the private key.

    Returns:
        True when the key is known to be encrypted, otherwise False.
    """
    import base64
    import struct

    try:
        with open(key_path, "r", encoding="utf-8", errors="replace") as f:
            content = f.read()
    except OSError:
        return False

    # PEM / PKCS#8 markers are plain text in the header.
    if "Proc-Type: 4,ENCRYPTED" in content or "-----BEGIN ENCRYPTED PRIVATE KEY-----" in content:
        return True

    begin = "-----BEGIN OPENSSH PRIVATE KEY-----"
    end = "-----END OPENSSH PRIVATE KEY-----"
    start = content.find(begin)
    if start == -1:
        return False
    stop = content.find(end, start)
    body = content[start + len(begin):stop if stop != -1 else None]

    try:
        blob = base64.b64decode("".join(body.split()))
    except ValueError:  # binascii.Error subclasses ValueError
        return False

    magic = b"openssh-key-v1\x00"
    if not blob.startswith(magic):
        return False
    offset = len(magic)
    # Next field: uint32 big-endian length, then the cipher name bytes.
    if len(blob) < offset + 4:
        return False
    (name_len,) = struct.unpack_from(">I", blob, offset)
    cipher = blob[offset + 4:offset + 4 + name_len]
    return cipher not in (b"", b"none")
107
+
108
def _is_key_in_agent(self, key_path: str) -> bool:
    """Report whether the key stored at ``key_path`` is loaded in ssh-agent.

    Lists the agent's identities with ``ssh-add -l``. It then asks
    ``ssh-keygen -lf`` for the key's fingerprint and looks for that
    fingerprint in the agent listing. Any failure yields False: a non-zero
    exit from either tool, unexpected output, a timeout, or a missing
    executable.
    """
    try:
        agent_listing = subprocess.run(
            ["ssh-add", "-l"],
            capture_output=True,
            text=True,
            timeout=5,
        )
        # A non-zero exit means no usable identity listing (e.g. no agent / no keys).
        if agent_listing.returncode != 0:
            return False

        keygen = subprocess.run(
            ["ssh-keygen", "-lf", key_path],
            capture_output=True,
            text=True,
            timeout=5,
        )
        if keygen.returncode != 0:
            return False

        # The second whitespace-separated field is the fingerprint (e.g. SHA256:...).
        fields = keygen.stdout.split()
        if len(fields) < 2:
            return False
        return fields[1] in agent_listing.stdout
    except Exception:
        return False
136
+
137
def _get_ssh_setup_instructions(self) -> str:
    """Return the help text shown after an SSH public-key authentication failure.

    The message is selected by ``self.provider_type``. RunPod, Lambda Labs
    and Vast.ai each have dedicated steps; every other provider gets a
    generic instruction.
    """
    generic_help = (
        "SSH authentication failed. Please ensure your SSH public key is added to your GPU provider.\n"
        "Upload your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub) to your provider's dashboard.\n"
        "Then try connecting again."
    )

    help_by_provider = {
        "runpod": (
            "SSH authentication failed. Please add your SSH public key to RunPod:\n"
            "1. Visit https://www.runpod.io/console/user/settings\n"
            "2. Go to 'SSH Public Keys' section\n"
            "3. Add your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
            "4. Try connecting again"
        ),
        # Lambda Labs binds keys at instance creation, so the fix is to recreate.
        "lambda_labs": (
            "SSH authentication failed. Lambda Labs SSH key mismatch.\n\n"
            "IMPORTANT: Lambda Labs assigns SSH keys at instance creation time.\n"
            "Your current instance may have been created with a different key.\n\n"
            "To fix this:\n"
            "1. Terminate the current instance\n"
            "2. Visit https://cloud.lambdalabs.com/ssh-keys\n"
            "3. Add your public key (~/.ssh/id_ed25519.pub) if not already there\n"
            "4. Create a new instance - it will use your registered key\n\n"
            "To view your public key, run: cat ~/.ssh/id_ed25519.pub"
        ),
        "vastai": (
            "SSH authentication failed. Please add your SSH public key to Vast.ai:\n"
            "1. Visit https://cloud.vast.ai/account/\n"
            "2. Go to 'SSH Keys' section\n"
            "3. Add your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
            "4. Try connecting again"
        ),
    }

    provider = self.provider_type
    if provider in help_by_provider:
        return help_by_provider[provider]
    return generic_help
175
+
88
176
  async def connect_to_pod(self, pod_id:str) -> dict[str, object]:
89
177
  """
90
178
  Connects to existing pod and set up ssh tunnel
@@ -94,8 +182,6 @@ class PodKernelManager:
94
182
  Response:
95
183
  dict with connection status
96
184
  """
97
- import sys
98
-
99
185
  # Check if already connected to this pod
100
186
  if self.pod and self.pod.id == pod_id:
101
187
  # Check if tunnel is still alive
@@ -105,18 +191,13 @@ class PodKernelManager:
105
191
  "message": f"Already connected to pod {pod_id}"
106
192
  }
107
193
  # Tunnel died, clean up and reconnect
108
- print(f"[POD MANAGER] Existing tunnel dead, reconnecting...", file=sys.stderr, flush=True)
109
194
  await self.disconnect()
110
195
 
111
196
  # If connected to different pod, disconnect first
112
197
  if self.pod and self.pod.id != pod_id:
113
- print(f"[POD MANAGER] Disconnecting from {self.pod.id} to connect to {pod_id}", file=sys.stderr, flush=True)
114
198
  await self.disconnect()
115
199
 
116
- self.pod = await self.pi_service.get_pod(pod_id)
117
-
118
- print(f"[POD MANAGER] Pod status: {self.pod.status}", file=sys.stderr, flush=True)
119
- print(f"[POD MANAGER] SSH connection: {self.pod.sshConnection}", file=sys.stderr, flush=True)
200
+ self.pod = await self.provider_service.get_pod(pod_id)
120
201
 
121
202
  if not self.pod.sshConnection:
122
203
  return{
@@ -132,7 +213,7 @@ class PodKernelManager:
132
213
  }
133
214
 
134
215
  # Parse SSH connection string
135
- # Format can be: "ssh root@ip -p port" OR "root@ip -p port"
216
+ # Format can be: "ssh root@ip -p port" OR "root@ip -p port" OR "ssh ubuntu@ip"
136
217
  ssh_parts = self.pod.sshConnection.split()
137
218
 
138
219
  # Find the part containing @ (user@host)
@@ -148,8 +229,8 @@ class PodKernelManager:
148
229
  "message": f"Invalid SSH connection format (no user@host found): {self.pod.sshConnection}"
149
230
  }
150
231
 
151
- # Extract host from user@host
152
- ssh_host = host_part.split("@")[1]
232
+ # Extract user and host from user@host
233
+ ssh_user, ssh_host = host_part.split("@")
153
234
  ssh_port = "22"
154
235
 
155
236
  # Extract port if specified with -p flag
@@ -158,33 +239,22 @@ class PodKernelManager:
158
239
  if port_idx + 1 < len(ssh_parts):
159
240
  ssh_port = ssh_parts[port_idx + 1]
160
241
 
161
- print(f"[POD MANAGER] Parsed SSH host: {ssh_host}, port: {ssh_port}", file=sys.stderr, flush=True)
162
-
163
242
  #deploy worker code to pod
164
- print(f"[POD MANAGER] Deploying worker code to pod...", file=sys.stderr, flush=True)
165
- deploy_result = await self._deploy_worker(ssh_host, ssh_port)
166
- print(f"[POD MANAGER] Deploy result: {deploy_result}", file=sys.stderr, flush=True)
243
+ deploy_result = await self._deploy_worker(ssh_user, ssh_host, ssh_port)
167
244
  if deploy_result.get("status") == "error":
168
245
  return deploy_result
169
246
 
170
247
  #create ssh tunnel for ZMQ ports
171
- print(f"[POD MANAGER] Creating SSH tunnel...", file=sys.stderr, flush=True)
172
- tunnel_result = await self._create_ssh_tunnel(ssh_host, ssh_port)
173
- print(f"[POD MANAGER] Tunnel result: {tunnel_result}", file=sys.stderr, flush=True)
248
+ tunnel_result = await self._create_ssh_tunnel(ssh_user, ssh_host, ssh_port)
174
249
  if tunnel_result.get("status") == "error":
175
250
  return tunnel_result
176
251
 
177
252
  #start remote worker
178
- worker_result = await self._start_remote_worker(ssh_host, ssh_port)
253
+ worker_result = await self._start_remote_worker(ssh_user, ssh_host, ssh_port)
179
254
  if worker_result.get("status") == "error":
180
255
  await self.disconnect()
181
256
  return worker_result
182
257
 
183
- # Note: Worker may take a few seconds to start and install matplotlib
184
- # The connection should work even if verification fails
185
- print(f"[POD MANAGER] Remote worker is starting (matplotlib install may take a few seconds)", file=sys.stderr, flush=True)
186
- print(f"[POD MANAGER] Connection established - try running code in ~5 seconds", file=sys.stderr, flush=True)
187
-
188
258
  return {
189
259
  "status": "ok",
190
260
  "message": f"Connected to pod {pod_id}",
@@ -196,11 +266,12 @@ class PodKernelManager:
196
266
  }
197
267
  }
198
268
 
199
- async def _deploy_worker(self, ssh_host: str, ssh_port: str) -> dict[str,object]:
269
+ async def _deploy_worker(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str,object]:
200
270
  """
201
271
  Deploy worker code to remote pod via Secure Copy Protocol.
202
272
 
203
273
  args:
274
+ ssh_user: SSH username (e.g., 'root' or 'ubuntu')
204
275
  ssh_host: SSH host address
205
276
  ssh_port: SSH port
206
277
 
@@ -208,6 +279,21 @@ class PodKernelManager:
208
279
  dict with deployment status
209
280
  """
210
281
  try:
282
+ # Check if SSH key is encrypted and not in agent
283
+ ssh_key = self._get_ssh_key()
284
+ if ssh_key and self._is_key_encrypted(ssh_key):
285
+ if not self._is_key_in_agent(ssh_key):
286
+ key_name = os.path.basename(ssh_key)
287
+ return {
288
+ "status": "error",
289
+ "message": (
290
+ f"Your SSH key ({key_name}) is protected with a passphrase but not loaded in ssh-agent.\n\n"
291
+ f"To fix this, run:\n"
292
+ f" ssh-add {ssh_key}\n\n"
293
+ f"Enter your passphrase when prompted, then try connecting again."
294
+ )
295
+ }
296
+
211
297
  # Create temporary tarball of morecompute package
212
298
  project_root = Path(__file__).parent.parent.parent
213
299
  morecompute_dir = project_root / "morecompute"
@@ -231,7 +317,7 @@ class PodKernelManager:
231
317
  "-o", "BatchMode=yes", # Prevent password prompts, fail fast if key auth doesn't work
232
318
  "-o", "ConnectTimeout=10",
233
319
  tmp_path,
234
- f"root@{ssh_host}:/tmp/morecompute.tar.gz"
320
+ f"{ssh_user}@{ssh_host}:/tmp/morecompute.tar.gz"
235
321
  ])
236
322
 
237
323
  result = subprocess.run(
@@ -244,14 +330,11 @@ class PodKernelManager:
244
330
  if result.returncode != 0:
245
331
  error_msg = result.stderr.lower()
246
332
  if "permission denied" in error_msg or "publickey" in error_msg:
333
+ # Get provider-specific SSH setup instructions
334
+ ssh_help = self._get_ssh_setup_instructions()
247
335
  return {
248
336
  "status": "error",
249
- "message": (
250
- "SSH authentication failed. Please add your SSH public key to Prime Intellect:\n"
251
- "1. Visit https://app.primeintellect.ai/dashboard/tokens\n"
252
- "2. Upload your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
253
- "3. Try connecting again"
254
- )
337
+ "message": ssh_help
255
338
  }
256
339
  elif "host key verification failed" in error_msg:
257
340
  return {
@@ -265,6 +348,9 @@ class PodKernelManager:
265
348
  }
266
349
 
267
350
  # Extract on remote and install dependencies
351
+ # Use sudo for non-root users to run pip install
352
+ pip_cmd = "pip install --quiet pyzmq matplotlib" if ssh_user == "root" else "sudo pip install --quiet pyzmq matplotlib"
353
+
268
354
  ssh_cmd = ["ssh", "-p", ssh_port]
269
355
 
270
356
  if ssh_key:
@@ -275,11 +361,11 @@ class PodKernelManager:
275
361
  "-o", "UserKnownHostsFile=/dev/null",
276
362
  "-o", "BatchMode=yes",
277
363
  "-o", "ConnectTimeout=10",
278
- f"root@{ssh_host}",
364
+ f"{ssh_user}@{ssh_host}",
279
365
  (
280
366
  "cd /tmp && "
281
367
  "tar -xzf morecompute.tar.gz && "
282
- "pip install --quiet pyzmq matplotlib && "
368
+ f"{pip_cmd} && "
283
369
  "echo 'Deployment complete'"
284
370
  )
285
371
  ])
@@ -308,11 +394,12 @@ class PodKernelManager:
308
394
  "message": f"Deployment error: {str(e)}"
309
395
  }
310
396
 
311
- async def _create_ssh_tunnel(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
397
+ async def _create_ssh_tunnel(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str, object]:
312
398
  """
313
399
  Create SSH tunnel for ZMQ ports.
314
400
 
315
401
  args:
402
+ ssh_user: SSH username (e.g., 'root' or 'ubuntu')
316
403
  ssh_host: SSH host address
317
404
  ssh_port: SSH port
318
405
 
@@ -336,7 +423,7 @@ class PodKernelManager:
336
423
  "-N", # No command execution
337
424
  "-L", f"{self.local_cmd_port}:localhost:{self.remote_cmd_port}",
338
425
  "-L", f"{self.local_pub_port}:localhost:{self.remote_pub_port}",
339
- f"root@{ssh_host}"
426
+ f"{ssh_user}@{ssh_host}"
340
427
  ])
341
428
 
342
429
  self.ssh_tunnel_proc = subprocess.Popen(
@@ -369,17 +456,17 @@ class PodKernelManager:
369
456
  }
370
457
 
371
458
  except Exception as e:
372
- print(f"[POD MANAGER] Exception creating tunnel: {e}", file=sys.stderr, flush=True)
373
459
  return {
374
460
  "status": "error",
375
461
  "message": f"Tunnel creation error: {str(e)}"
376
462
  }
377
463
 
378
- async def _start_remote_worker(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
464
+ async def _start_remote_worker(self, ssh_user: str, ssh_host: str, ssh_port: str) -> dict[str, object]:
379
465
  """
380
466
  Start ZMQ worker on remote pod.
381
467
 
382
468
  args:
469
+ ssh_user: SSH username (e.g., 'root' or 'ubuntu')
383
470
  ssh_host: SSH host address
384
471
  ssh_port: SSH port
385
472
 
@@ -387,31 +474,46 @@ class PodKernelManager:
387
474
  dict with worker start status
388
475
  """
389
476
  try:
390
- print(f"[POD MANAGER] Starting remote worker on {ssh_host}:{ssh_port}", file=sys.stderr, flush=True)
391
-
392
477
  # Start worker in background on remote pod
393
478
  # Use 'python3' instead of sys.executable since remote pod may have different Python path
479
+ # Use sudo for non-root users
480
+ python_cmd = "python3" if ssh_user == "root" else "sudo python3"
481
+
394
482
  ssh_key = self._get_ssh_key()
395
483
  worker_cmd = ["ssh", "-p", ssh_port]
396
484
 
397
485
  if ssh_key:
398
486
  worker_cmd.extend(["-i", ssh_key])
399
487
 
488
+ # Build the command - for non-root users, we need to pass env vars through sudo
489
+ if ssh_user == "root":
490
+ remote_cmd = (
491
+ f"cd /tmp && "
492
+ f"MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
493
+ f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
494
+ f"setsid python3 -u /tmp/morecompute/execution/worker.py "
495
+ f"</dev/null >/tmp/worker.log 2>&1 & "
496
+ f"echo $!"
497
+ )
498
+ else:
499
+ # For non-root, use sudo with env vars passed via sudo's env mechanism
500
+ remote_cmd = (
501
+ f"cd /tmp && "
502
+ f"sudo MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
503
+ f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
504
+ f"setsid python3 -u /tmp/morecompute/execution/worker.py "
505
+ f"</dev/null >/tmp/worker.log 2>&1 & "
506
+ f"echo $!"
507
+ )
508
+
400
509
  worker_cmd.extend([
401
510
  "-o", "StrictHostKeyChecking=no",
402
511
  "-o", "UserKnownHostsFile=/dev/null",
403
512
  "-o", "BatchMode=yes",
404
513
  "-o", "ConnectTimeout=10",
405
- f"root@{ssh_host}",
514
+ f"{ssh_user}@{ssh_host}",
406
515
  "sh", "-c",
407
- (
408
- f"'cd /tmp && "
409
- f"MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} "
410
- f"MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} "
411
- f"setsid python3 -u /tmp/morecompute/execution/worker.py "
412
- f"</dev/null >/tmp/worker.log 2>&1 & "
413
- f"echo $!'"
414
- )
516
+ f"'{remote_cmd}'"
415
517
  ])
416
518
 
417
519
  result = subprocess.run(
@@ -428,13 +530,10 @@ class PodKernelManager:
428
530
  }
429
531
 
430
532
  remote_pid = result.stdout.strip()
431
- print(f"[POD MANAGER] Remote worker PID: {remote_pid}", file=sys.stderr, flush=True)
432
533
 
433
534
  # Wait for worker to be ready
434
535
  await asyncio.sleep(2)
435
536
 
436
- print(f"[POD MANAGER] Remote worker should be ready now", file=sys.stderr, flush=True)
437
-
438
537
  return {
439
538
  "status": "ok",
440
539
  "message": "Remote worker started",
@@ -580,7 +679,7 @@ class PodKernelManager:
580
679
 
581
680
  # Get updated pod info
582
681
  try:
583
- updated_pod = await self.pi_service.get_pod(pod.id)
682
+ updated_pod = await self.provider_service.get_pod(pod.id)
584
683
  pod_status = updated_pod.status
585
684
  except Exception:
586
685
  pod_status = "unknown"
@@ -594,12 +693,14 @@ class PodKernelManager:
594
693
  "gpu_type": pod.gpuName,
595
694
  "gpu_count": pod.gpuCount,
596
695
  "price_hr": pod.priceHr,
597
- "ssh_connection": pod.sshConnection
696
+ "ssh_connection": pod.sshConnection,
697
+ "provider": self.provider_type
598
698
  },
599
699
  "tunnel": {
600
700
  "alive": tunnel_alive,
601
701
  "local_cmd_port": self.local_cmd_port,
602
702
  "local_pub_port": self.local_pub_port
603
703
  },
604
- "executor_attached": self.executor is not None
704
+ "executor_attached": self.executor is not None,
705
+ "provider": self.provider_type
605
706
  }
@@ -1,36 +1,55 @@
1
1
  """Service for monitoring GPU pod status updates."""
2
2
 
3
3
  import asyncio
4
- import sys
5
- from typing import Callable, Awaitable
4
+ from typing import Callable, Awaitable, Union
6
5
  from cachetools import TTLCache
7
6
 
8
7
  from .prime_intellect import PrimeIntellectService
8
+ from .providers.base_provider import BaseGPUProvider
9
9
 
10
10
 
11
11
  # Type of the async callback PodMonitor uses to broadcast pod status updates (called with one dict).
  PodUpdateCallback = Callable[[dict], Awaitable[None]]
12
12
 
13
+ # Type alias for supported provider services: the legacy PrimeIntellectService or any BaseGPUProvider
14
+ ProviderService = Union[PrimeIntellectService, BaseGPUProvider]
15
+
13
16
 
14
17
  class PodMonitor:
15
- """Monitors GPU pod status and broadcasts updates."""
18
+ """Monitors GPU pod status and broadcasts updates.
19
+
20
+ Supports monitoring pods from any GPU provider that implements
21
+ the BaseGPUProvider interface.
22
+ """
16
23
 
17
24
  POLL_INTERVAL_SECONDS = 5
18
25
 
19
26
  def __init__(
20
27
  self,
21
- prime_intellect: PrimeIntellectService,
22
28
  pod_cache: TTLCache,
23
- update_callback: PodUpdateCallback
29
+ update_callback: PodUpdateCallback,
30
+ prime_intellect: PrimeIntellectService | None = None,
31
+ provider_service: BaseGPUProvider | None = None,
24
32
  ):
25
33
  """
26
34
  Initialize pod monitor.
27
35
 
28
36
  Args:
29
- prime_intellect: Prime Intellect API service
30
37
  pod_cache: Cache to clear on updates
31
38
  update_callback: Async callback for broadcasting updates
39
+ prime_intellect: Legacy Prime Intellect API service (deprecated, use provider_service)
40
+ provider_service: GPU provider service implementing BaseGPUProvider
32
41
  """
33
- self.pi_service = prime_intellect
42
+ # Support both old and new interface
43
+ if provider_service is not None:
44
+ self.provider = provider_service
45
+ self.provider_name = provider_service.PROVIDER_NAME
46
+ elif prime_intellect is not None:
47
+ # Backwards compatibility
48
+ self.provider = prime_intellect
49
+ self.provider_name = "prime_intellect"
50
+ else:
51
+ raise ValueError("Either prime_intellect or provider_service must be provided")
52
+
34
53
  self.pod_cache = pod_cache
35
54
  self.update_callback = update_callback
36
55
  self.monitoring_tasks: dict[str, asyncio.Task] = {}
@@ -44,12 +63,10 @@ class PodMonitor:
44
63
  """
45
64
  # Don't start duplicate monitors
46
65
  if pod_id in self.monitoring_tasks:
47
- print(f"[POD MONITOR] Already monitoring pod {pod_id}", file=sys.stderr, flush=True)
48
66
  return
49
67
 
50
68
  task = asyncio.create_task(self._monitor_loop(pod_id))
51
69
  self.monitoring_tasks[pod_id] = task
52
- print(f"[POD MONITOR] Started monitoring pod {pod_id}", file=sys.stderr, flush=True)
53
70
 
54
71
  async def stop_monitoring(self, pod_id: str) -> None:
55
72
  """
@@ -65,7 +82,26 @@ class PodMonitor:
65
82
  await task
66
83
  except asyncio.CancelledError:
67
84
  pass
68
- print(f"[POD MONITOR] Stopped monitoring pod {pod_id}", file=sys.stderr, flush=True)
85
+
86
+ def _normalize_status(self, status: str) -> str:
87
+ """Normalize status across different providers."""
88
+ # Common status normalization
89
+ status_map = {
90
+ # Common statuses
91
+ "running": "ACTIVE",
92
+ "active": "ACTIVE",
93
+ "ready": "ACTIVE",
94
+ "starting": "STARTING",
95
+ "pending": "PENDING",
96
+ "stopped": "STOPPED",
97
+ "terminated": "TERMINATED",
98
+ "error": "ERROR",
99
+ # Provider-specific
100
+ "exited": "TERMINATED",
101
+ "loading": "STARTING",
102
+ "booting": "STARTING",
103
+ }
104
+ return status_map.get(status.lower(), status.upper())
69
105
 
70
106
  async def _monitor_loop(self, pod_id: str) -> None:
71
107
  """
@@ -78,61 +114,53 @@ class PodMonitor:
78
114
  while True:
79
115
  try:
80
116
  # Fetch current pod status
81
- pod = await self.pi_service.get_pod(pod_id)
117
+ pod = await self.provider.get_pod(pod_id)
82
118
 
83
- print(
84
- f"[POD MONITOR] Pod {pod_id} status: {pod.status}",
85
- file=sys.stderr,
86
- flush=True
87
- )
119
+ # Normalize the status
120
+ normalized_status = self._normalize_status(pod.status)
88
121
 
89
122
  # Clear cache to force fresh data
90
123
  self.pod_cache.clear()
91
124
 
92
- # Broadcast update
125
+ # Broadcast update with provider info
93
126
  await self.update_callback({
94
127
  "type": "pod_status_update",
95
128
  "data": {
96
129
  "pod_id": pod_id,
97
130
  "name": pod.name,
98
- "status": pod.status,
131
+ "status": normalized_status,
99
132
  "ssh_connection": pod.sshConnection,
100
133
  "ip": pod.ip,
101
134
  "gpu_name": pod.gpuName,
102
- "price_hr": pod.priceHr
135
+ "gpu_count": pod.gpuCount,
136
+ "price_hr": pod.priceHr,
137
+ "provider": self.provider_name
103
138
  }
104
139
  })
105
140
 
106
141
  # Stop monitoring if ERROR or TERMINATED
107
- if pod.status in {"ERROR", "TERMINATED"}:
108
- print(
109
- f"[POD MONITOR] Pod {pod_id} reached terminal state: {pod.status}",
110
- file=sys.stderr,
111
- flush=True
112
- )
142
+ if normalized_status in {"ERROR", "TERMINATED"}:
113
143
  break
114
144
 
115
145
  # If ACTIVE and has SSH connection, pod is fully ready - stop monitoring
116
- if pod.status == "ACTIVE" and pod.sshConnection:
117
- print(
118
- f"[POD MONITOR] Pod {pod_id} is ACTIVE with SSH connection: {pod.sshConnection}",
119
- file=sys.stderr,
120
- flush=True
121
- )
122
- break
146
+ # Note: Modal doesn't support SSH, so we just check for ACTIVE
147
+ if normalized_status == "ACTIVE":
148
+ if pod.sshConnection or self.provider_name == "modal":
149
+ break
123
150
 
124
151
  # Wait before next check
125
152
  await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
126
153
 
127
- except Exception as e:
128
- print(
129
- f"[POD MONITOR] Error checking pod {pod_id}: {e}",
130
- file=sys.stderr,
131
- flush=True
132
- )
154
+ except Exception:
133
155
  await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
134
156
 
135
157
  finally:
136
158
  # Clean up
137
159
  self.monitoring_tasks.pop(pod_id, None)
138
- print(f"[POD MONITOR] Stopped monitoring pod {pod_id}", file=sys.stderr, flush=True)
160
+
161
def stop_all(self) -> None:
    """Cancel every still-running pod monitor and empty the task registry.

    Synchronous: cancellation is only requested, never awaited. Entries
    whose task is missing or already finished are just removed.
    """
    snapshot = list(self.monitoring_tasks)
    for key in snapshot:
        running = self.monitoring_tasks.pop(key, None)
        if not running or running.done():
            continue
        running.cancel()
@@ -97,9 +97,7 @@ class PrimeIntellectService:
97
97
  """
98
98
  Create a new pod
99
99
  """
100
- import sys
101
100
  payload = pod_request.model_dump(exclude_none=True)
102
- print(f"[PI SERVICE] Creating pod with payload: {payload}", file=sys.stderr, flush=True)
103
101
 
104
102
  response = await self._make_request(
105
103
  "POST",
@@ -294,9 +292,7 @@ class PrimeIntellectService:
294
292
  Disk response with disk details
295
293
 
296
294
  """
297
- import sys
298
295
  payload = disk_request.model_dump(exclude_none=True)
299
- print(f"[PI SERVICE] Creating disk with payload: + {payload}", file=sys.stderr, flush=True)
300
296
  response = await self._make_request("POST", "/disks/", json_data=payload)
301
297
  return DiskResponse.model_validate(response)
302
298