more-compute 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. frontend/app/globals.css +322 -77
  2. frontend/app/layout.tsx +98 -82
  3. frontend/components/Cell.tsx +234 -95
  4. frontend/components/Notebook.tsx +430 -199
  5. frontend/components/{AddCellButton.tsx → cell/AddCellButton.tsx} +0 -2
  6. frontend/components/cell/MonacoCell.tsx +726 -0
  7. frontend/components/layout/ConnectionBanner.tsx +41 -0
  8. frontend/components/{Sidebar.tsx → layout/Sidebar.tsx} +16 -11
  9. frontend/components/modals/ConfirmModal.tsx +154 -0
  10. frontend/components/modals/SuccessModal.tsx +140 -0
  11. frontend/components/output/MarkdownRenderer.tsx +116 -0
  12. frontend/components/popups/ComputePopup.tsx +674 -365
  13. frontend/components/popups/MetricsPopup.tsx +11 -7
  14. frontend/components/popups/SettingsPopup.tsx +11 -13
  15. frontend/contexts/PodWebSocketContext.tsx +247 -0
  16. frontend/eslint.config.mjs +11 -0
  17. frontend/lib/monaco-themes.ts +160 -0
  18. frontend/lib/settings.ts +128 -26
  19. frontend/lib/themes.json +9973 -0
  20. frontend/lib/websocket-native.ts +19 -8
  21. frontend/lib/websocket.ts +59 -11
  22. frontend/next.config.ts +8 -0
  23. frontend/package-lock.json +1705 -3
  24. frontend/package.json +8 -1
  25. frontend/styling_README.md +18 -0
  26. kernel_run.py +159 -42
  27. more_compute-0.2.0.dist-info/METADATA +126 -0
  28. more_compute-0.2.0.dist-info/RECORD +100 -0
  29. morecompute/__version__.py +1 -1
  30. morecompute/execution/executor.py +31 -20
  31. morecompute/execution/worker.py +68 -7
  32. morecompute/models/__init__.py +31 -0
  33. morecompute/models/api_models.py +197 -0
  34. morecompute/notebook.py +50 -7
  35. morecompute/server.py +574 -94
  36. morecompute/services/data_manager.py +379 -0
  37. morecompute/services/lsp_service.py +335 -0
  38. morecompute/services/pod_manager.py +122 -20
  39. morecompute/services/pod_monitor.py +138 -0
  40. morecompute/services/prime_intellect.py +87 -63
  41. morecompute/utils/config_util.py +59 -0
  42. morecompute/utils/special_commands.py +11 -5
  43. morecompute/utils/zmq_util.py +51 -0
  44. frontend/components/MarkdownRenderer.tsx +0 -84
  45. frontend/components/popups/PythonPopup.tsx +0 -292
  46. more_compute-0.1.4.dist-info/METADATA +0 -173
  47. more_compute-0.1.4.dist-info/RECORD +0 -86
  48. /frontend/components/{CellButton.tsx → cell/CellButton.tsx} +0 -0
  49. /frontend/components/{ErrorModal.tsx → modals/ErrorModal.tsx} +0 -0
  50. /frontend/components/{CellOutput.tsx → output/CellOutput.tsx} +0 -0
  51. /frontend/components/{ErrorDisplay.tsx → output/ErrorDisplay.tsx} +0 -0
  52. {more_compute-0.1.4.dist-info → more_compute-0.2.0.dist-info}/WHEEL +0 -0
  53. {more_compute-0.1.4.dist-info → more_compute-0.2.0.dist-info}/entry_points.txt +0 -0
  54. {more_compute-0.1.4.dist-info → more_compute-0.2.0.dist-info}/licenses/LICENSE +0 -0
  55. {more_compute-0.1.4.dist-info → more_compute-0.2.0.dist-info}/top_level.txt +0 -0
morecompute/server.py CHANGED
@@ -3,6 +3,8 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPExcept
3
3
  from fastapi.responses import PlainTextResponse
4
4
  from fastapi.staticfiles import StaticFiles
5
5
  import os
6
+ import sys
7
+ import asyncio
6
8
  from datetime import datetime, timezone
7
9
  from pathlib import Path
8
10
  import importlib.metadata as importlib_metadata
@@ -16,8 +18,20 @@ from .utils.system_environment_util import DeviceMetrics
16
18
  from .utils.error_utils import ErrorUtils
17
19
  from .utils.cache_util import make_cache_key
18
20
  from .utils.notebook_util import coerce_cell_source
19
- from .services.prime_intellect import PrimeIntellectService, CreatePodRequest, PodResponse
21
+ from .utils.config_util import load_api_key_from_env, save_api_key_to_env
22
+ from .utils.zmq_util import reconnect_zmq_sockets, reset_to_local_zmq
23
+ from .services.prime_intellect import PrimeIntellectService
20
24
  from .services.pod_manager import PodKernelManager
25
+ from .services.data_manager import DataManager
26
+ from .services.pod_monitor import PodMonitor
27
+ from .services.lsp_service import LSPService
28
+ from .models.api_models import (
29
+ ApiKeyRequest,
30
+ ApiKeyResponse,
31
+ ConfigStatusResponse,
32
+ CreatePodRequest,
33
+ PodResponse,
34
+ )
21
35
 
22
36
 
23
37
  BASE_DIR = Path(os.getenv("MORECOMPUTE_ROOT", Path.cwd())).resolve()
@@ -45,7 +59,6 @@ environments_cache = TTLCache(maxsize=1, ttl=300) # 5 minutes cache for environ
45
59
  if ASSETS_DIR.exists():
46
60
  app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
47
61
 
48
- # Global instances for the application state
49
62
  notebook_path_env = os.getenv("MORECOMPUTE_NOTEBOOK_PATH")
50
63
  if notebook_path_env:
51
64
  notebook = Notebook(file_path=notebook_path_env)
@@ -54,34 +67,56 @@ else:
54
67
  error_utils = ErrorUtils()
55
68
  executor = NextZmqExecutor(error_utils=error_utils)
56
69
  metrics = DeviceMetrics()
57
-
58
- # Initialize Prime Intellect service if API key is provided
59
- # Check environment variable first, then .env file (commonly gitignored)
60
- prime_api_key = os.getenv("PRIME_INTELLECT_API_KEY")
61
- if not prime_api_key:
62
- env_path = BASE_DIR / ".env"
63
- if env_path.exists():
64
- try:
65
- with env_path.open("r", encoding="utf-8") as f:
66
- for line in f:
67
- line = line.strip()
68
- if line.startswith("PRIME_INTELLECT_API_KEY="):
69
- prime_api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
70
- break
71
- except Exception:
72
- pass
73
-
70
+ prime_api_key = load_api_key_from_env("PRIME_INTELLECT_API_KEY", BASE_DIR / ".env")
74
71
  prime_intellect = PrimeIntellectService(api_key=prime_api_key) if prime_api_key else None
75
72
  pod_manager: PodKernelManager | None = None
73
+ data_manager = DataManager(prime_intellect=prime_intellect)
74
+ pod_monitor: PodMonitor | None = None
75
+ if prime_intellect:
76
+ pod_monitor = PodMonitor(
77
+ prime_intellect=prime_intellect,
78
+ pod_cache=pod_cache,
79
+ update_callback=lambda msg: manager.broadcast_pod_update(msg)
80
+ )
81
+
82
+ # LSP service for Python autocomplete
83
+ lsp_service: LSPService | None = None
84
+
85
+
86
+ @app.on_event("startup")
87
+ async def startup_event():
88
+ """Initialize services on startup."""
89
+ global lsp_service
90
+ try:
91
+ lsp_service = LSPService(workspace_root=BASE_DIR)
92
+ await lsp_service.start()
93
+ print("[LSP] Pyright language server started successfully", file=sys.stderr, flush=True)
94
+ except Exception as e:
95
+ print(f"[LSP] Failed to start language server: {e}", file=sys.stderr, flush=True)
96
+ lsp_service = None
97
+
98
+
99
+ @app.on_event("shutdown")
100
+ async def shutdown_event():
101
+ """Cleanup services on shutdown."""
102
+ global lsp_service
103
+ if lsp_service:
104
+ try:
105
+ await lsp_service.shutdown()
106
+ print("[LSP] Language server shutdown complete", file=sys.stderr, flush=True)
107
+ except Exception as e:
108
+ print(f"[LSP] Error during shutdown: {e}", file=sys.stderr, flush=True)
76
109
 
77
110
 
78
111
  @app.get("/api/packages")
79
112
  async def list_installed_packages(force_refresh: bool = False):
80
113
  """
81
114
  Return installed packages for the current Python runtime.
115
+ Fetches from remote pod if connected, otherwise from local environment.
82
116
  Args:
83
117
  force_refresh: If True, bypass cache and fetch fresh data
84
118
  """
119
+ global pod_manager
85
120
  cache_key = "packages_list"
86
121
 
87
122
  # Check cache first unless force refresh is requested
@@ -89,6 +124,29 @@ async def list_installed_packages(force_refresh: bool = False):
89
124
  return packages_cache[cache_key]
90
125
 
91
126
  try:
127
+ # If connected to remote pod, fetch packages from there
128
+ if pod_manager and pod_manager.pod:
129
+ try:
130
+ stdout, stderr, returncode = await pod_manager.execute_ssh_command(
131
+ "python3 -m pip list --format=json 2>/dev/null || pip list --format=json"
132
+ )
133
+
134
+ if returncode == 0 and stdout.strip():
135
+ import json
136
+ pkgs_data = json.loads(stdout)
137
+ packages = [{"name": p["name"], "version": p["version"]} for p in pkgs_data]
138
+ packages.sort(key=lambda p: p["name"].lower())
139
+ result = {"packages": packages}
140
+ packages_cache[cache_key] = result
141
+ return result
142
+ else:
143
+ print(f"[API/PACKAGES] Remote command failed (code {returncode}): {stderr}", file=sys.stderr, flush=True)
144
+ # Fall through to local packages
145
+ except Exception as e:
146
+ print(f"[API/PACKAGES] Failed to fetch remote packages: {e}", file=sys.stderr, flush=True)
147
+ # Fall through to local packages
148
+
149
+ # Local packages (fallback or when not connected)
92
150
  packages = []
93
151
  for dist in importlib_metadata.distributions():
94
152
  name = dist.metadata.get("Name") or dist.metadata.get("Summary") or dist.metadata.get("name")
@@ -106,7 +164,68 @@ async def list_installed_packages(force_refresh: bool = False):
106
164
 
107
165
  @app.get("/api/metrics")
108
166
  async def get_metrics():
167
+ global pod_manager
109
168
  try:
169
+ # If connected to remote pod, fetch metrics from there
170
+ if pod_manager and pod_manager.pod:
171
+ try:
172
+ # Python script to collect metrics on remote pod
173
+ metrics_script = """
174
+ import json, psutil
175
+ try:
176
+ import pynvml
177
+ pynvml.nvmlInit()
178
+ gpu_count = pynvml.nvmlDeviceGetCount()
179
+ gpus = []
180
+ for i in range(gpu_count):
181
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
182
+ util = pynvml.nvmlDeviceGetUtilizationRates(handle)
183
+ mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
184
+ try:
185
+ temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
186
+ except:
187
+ temp = None
188
+ gpus.append({
189
+ "util_percent": util.gpu,
190
+ "mem_used": mem.used,
191
+ "mem_total": mem.total,
192
+ "temperature_c": temp
193
+ })
194
+ pynvml.nvmlShutdown()
195
+ except:
196
+ gpus = []
197
+ cpu = psutil.cpu_percent(interval=0.1)
198
+ mem = psutil.virtual_memory()
199
+ disk = psutil.disk_usage('/')
200
+ net = psutil.net_io_counters()
201
+ proc = psutil.Process()
202
+ mem_info = proc.memory_info()
203
+ print(json.dumps({
204
+ "cpu": {"percent": cpu, "cores": psutil.cpu_count()},
205
+ "memory": {"percent": mem.percent, "used": mem.used, "total": mem.total},
206
+ "storage": {"percent": disk.percent, "used": disk.used, "total": disk.total},
207
+ "gpu": gpus,
208
+ "network": {"bytes_sent": net.bytes_sent, "bytes_recv": net.bytes_recv},
209
+ "process": {"rss": mem_info.rss, "threads": proc.num_threads()}
210
+ }))
211
+ """
212
+ # Escape single quotes in the script for shell
213
+ escaped_script = metrics_script.replace("'", "'\"'\"'")
214
+ stdout, stderr, returncode = await pod_manager.execute_ssh_command(
215
+ f"python3 -c '{escaped_script}'"
216
+ )
217
+
218
+ if returncode == 0 and stdout.strip():
219
+ import json
220
+ return json.loads(stdout)
221
+ else:
222
+ print(f"[API/METRICS] Remote command failed (code {returncode}): {stderr}", file=sys.stderr, flush=True)
223
+ # Fall through to local metrics
224
+ except Exception as e:
225
+ print(f"[API/METRICS] Failed to fetch remote metrics: {e}", file=sys.stderr, flush=True)
226
+ # Fall through to local metrics
227
+
228
+ # Local metrics (fallback or when not connected)
110
229
  return metrics.get_all_devices()
111
230
  except Exception as exc:
112
231
  raise HTTPException(status_code=500, detail=f"Failed to get metrics: {exc}")
@@ -183,6 +302,80 @@ async def fix_indentation(request: Request):
183
302
  raise HTTPException(status_code=500, detail=f"Failed to fix indentation: {exc}")
184
303
 
185
304
 
305
+ @app.post("/api/lsp/completions")
306
+ async def get_lsp_completions(request: Request):
307
+ """
308
+ Get LSP code completions for Python.
309
+
310
+ Body:
311
+ cell_id: Unique cell identifier
312
+ source: Full source code of the cell
313
+ line: Line number (0-indexed)
314
+ character: Character position in line
315
+
316
+ Returns:
317
+ List of completion items with label, kind, detail, documentation
318
+ """
319
+ if not lsp_service:
320
+ raise HTTPException(status_code=503, detail="LSP service not available")
321
+
322
+ try:
323
+ body = await request.json()
324
+ cell_id = body.get("cell_id", "0")
325
+ source = body.get("source", "")
326
+ line = body.get("line", 0)
327
+ character = body.get("character", 0)
328
+
329
+ completions = await lsp_service.get_completions(
330
+ cell_id=str(cell_id),
331
+ source=source,
332
+ line=line,
333
+ character=character
334
+ )
335
+
336
+ return {"completions": completions}
337
+
338
+ except Exception as exc:
339
+ raise HTTPException(status_code=500, detail=f"LSP completion error: {exc}")
340
+
341
+
342
+ @app.post("/api/lsp/hover")
343
+ async def get_lsp_hover(request: Request):
344
+ """
345
+ Get hover information for Python code.
346
+
347
+ Body:
348
+ cell_id: Unique cell identifier
349
+ source: Full source code of the cell
350
+ line: Line number (0-indexed)
351
+ character: Character position in line
352
+
353
+ Returns:
354
+ Hover information with documentation
355
+ """
356
+ if not lsp_service:
357
+ raise HTTPException(status_code=503, detail="LSP service not available")
358
+
359
+ try:
360
+ body = await request.json()
361
+ cell_id = body.get("cell_id", "0")
362
+ source = body.get("source", "")
363
+ line = body.get("line", 0)
364
+ character = body.get("character", 0)
365
+
366
+ hover_info = await lsp_service.get_hover(
367
+ cell_id=str(cell_id),
368
+ source=source,
369
+ line=line,
370
+ character=character
371
+ )
372
+
373
+ return {"hover": hover_info}
374
+
375
+ except Exception as exc:
376
+ raise HTTPException(status_code=500, detail=f"LSP hover error: {exc}")
377
+
378
+
186
379
  @app.get("/api/file")
187
380
  async def read_file(path: str, max_bytes: int = 256_000):
188
381
  file_path = resolve_path(path)
@@ -234,6 +427,14 @@ class WebSocketManager:
234
427
  "data": updated_data
235
428
  })
236
429
 
430
+ async def broadcast_pod_update(self, message: dict):
431
+ """Broadcast pod status updates to all connected clients."""
432
+ for client in self.clients:
433
+ try:
434
+ await client.send_json(message)
435
+ except Exception:
436
+ pass
437
+
237
438
  async def handle_message_loop(self, websocket: WebSocket):
238
439
  """Main loop to handle incoming WebSocket messages."""
239
440
  while True:
@@ -255,6 +456,7 @@ class WebSocketManager:
255
456
  "add_cell": self._handle_add_cell,
256
457
  "delete_cell": self._handle_delete_cell,
257
458
  "update_cell": self._handle_update_cell,
459
+ "move_cell": self._handle_move_cell,
258
460
  "interrupt_kernel": self._handle_interrupt_kernel,
259
461
  "reset_kernel": self._handle_reset_kernel,
260
462
  "load_notebook": self._handle_load_notebook,
@@ -319,7 +521,15 @@ class WebSocketManager:
319
521
  async def _handle_add_cell(self, websocket: WebSocket, data: dict):
320
522
  index = data.get('index', len(self.notebook.cells))
321
523
  cell_type = data.get('cell_type', 'code')
322
- self.notebook.add_cell(index=index, cell_type=cell_type)
524
+ source = data.get('source', '')
525
+ full_cell = data.get('full_cell')
526
+
527
+ if full_cell:
528
+ # Restore full cell data (for undo functionality)
529
+ self.notebook.add_cell(index=index, cell_type=cell_type, source=source, full_cell=full_cell)
530
+ else:
531
+ # Normal add cell
532
+ self.notebook.add_cell(index=index, cell_type=cell_type, source=source)
323
533
  await self.broadcast_notebook_update()
324
534
 
325
535
  async def _handle_delete_cell(self, websocket: WebSocket, data: dict):
@@ -336,6 +546,17 @@ class WebSocketManager:
336
546
  #self.notebook.save_to_file()
337
547
  #to -do?
338
548
 
549
+ async def _handle_move_cell(self, websocket: WebSocket, data: dict):
550
+ from_index = data.get('from_index')
551
+ to_index = data.get('to_index')
552
+ if from_index is not None and to_index is not None:
553
+ self.notebook.move_cell(from_index, to_index)
554
+ # Save the notebook after moving cells
555
+ try:
556
+ self.notebook.save_to_file()
557
+ except Exception as e:
558
+ print(f"Warning: Failed to save notebook after moving cell: {e}", file=sys.stderr)
559
+ await self.broadcast_notebook_update()
339
560
 
340
561
  async def _handle_load_notebook(self, websocket: WebSocket, data: dict):
341
562
  # In a real app, this would load from a file path in `data`
@@ -420,45 +641,35 @@ async def websocket_endpoint(websocket: WebSocket):
420
641
  await manager.handle_message_loop(websocket)
421
642
 
422
643
 
423
- #gpu connection api
424
- @app.get("/api/gpu/config")
425
- async def get_gpu_config():
644
+ # GPU connection API
645
+ @app.get("/api/gpu/config", response_model=ConfigStatusResponse)
646
+ async def get_gpu_config() -> ConfigStatusResponse:
426
647
  """Check if Prime Intellect API is configured."""
427
- return {"configured": prime_intellect is not None}
428
-
429
-
430
- @app.post("/api/gpu/config")
431
- async def set_gpu_config(request: Request):
432
- """Save Prime Intellect API key to .env file (commonly gitignored) and reinitialize service."""
433
- global prime_intellect
648
+ return ConfigStatusResponse(configured=prime_intellect is not None)
434
649
 
435
- try:
436
- body = await request.json()
437
- api_key = body.get("api_key", "").strip()
438
- if not api_key:
439
- raise HTTPException(status_code=400, detail="API key is required")
440
650
 
441
- env_path = BASE_DIR / ".env"
651
+ @app.post("/api/gpu/config", response_model=ApiKeyResponse)
652
+ async def set_gpu_config(request: ApiKeyRequest) -> ApiKeyResponse:
653
+ """Save Prime Intellect API key to .env file and reinitialize service."""
654
+ global prime_intellect, pod_monitor
442
655
 
443
- # Read existing .env content
444
- existing_lines = []
445
- if env_path.exists():
446
- with env_path.open("r", encoding="utf-8") as f:
447
- existing_lines = f.readlines()
656
+ if not request.api_key.strip():
657
+ raise HTTPException(status_code=400, detail="API key is required")
448
658
 
449
- # Remove any existing PRIME_INTELLECT_API_KEY lines
450
- new_lines = [line for line in existing_lines if not line.strip().startswith("PRIME_INTELLECT_API_KEY=")]
451
- # Add the new API key
452
- new_lines.append(f"PRIME_INTELLECT_API_KEY={api_key}\n")
453
- # Write back to .env
454
- with env_path.open("w", encoding="utf-8") as f:
455
- f.writelines(new_lines)
456
- prime_intellect = PrimeIntellectService(api_key=api_key)
659
+ try:
660
+ save_api_key_to_env("PRIME_INTELLECT_API_KEY", request.api_key, BASE_DIR / ".env")
661
+ prime_intellect = PrimeIntellectService(api_key=request.api_key)
662
+ if prime_intellect:
663
+ pod_monitor = PodMonitor(
664
+ prime_intellect=prime_intellect,
665
+ pod_cache=pod_cache,
666
+ update_callback=lambda msg: manager.broadcast_pod_update(msg)
667
+ )
457
668
 
458
- return {"configured": True, "message": "API key saved successfully"}
669
+ return ApiKeyResponse(configured=True, message="API key saved successfully")
459
670
 
460
- except HTTPException:
461
- raise
671
+ except ValueError as e:
672
+ raise HTTPException(status_code=400, detail=str(e))
462
673
  except Exception as exc:
463
674
  raise HTTPException(status_code=500, detail=f"Failed to save API key: {exc}")
464
675
 
@@ -511,28 +722,34 @@ async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int
511
722
  pod_cache[cache_key] = result
512
723
  return result
513
724
 
725
+
514
726
  @app.post("/api/gpu/pods")
515
727
  async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
516
728
  """Create a new GPU pod."""
517
729
  import sys
518
- print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)
519
730
 
520
- if not prime_intellect:
731
+ if not prime_intellect or not pod_monitor:
521
732
  raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
522
733
 
734
+ print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)
735
+
523
736
  try:
524
737
  result = await prime_intellect.create_pod(pod_request)
525
738
  print(f"[CREATE POD] Success: {result}", file=sys.stderr, flush=True)
739
+
740
+ # Clear cache and start monitoring
526
741
  pod_cache.clear()
742
+ await pod_monitor.start_monitoring(result.id)
527
743
 
528
744
  return result
745
+
529
746
  except HTTPException as e:
530
747
  if e.status_code == 402:
531
748
  raise HTTPException(
532
749
  status_code=402,
533
750
  detail="Insufficient funds in your Prime Intellect wallet. Please add credits at https://app.primeintellect.ai/dashboard/billing"
534
751
  )
535
- elif e.status_code == 401 or e.status_code == 403:
752
+ elif e.status_code in (401, 403):
536
753
  raise HTTPException(
537
754
  status_code=e.status_code,
538
755
  detail="Authentication failed. Please check your Prime Intellect API key."
@@ -569,6 +786,46 @@ async def delete_gpu_pod(pod_id: str):
569
786
  return result
570
787
 
571
788
 
789
+ async def _connect_to_pod_background(pod_id: str):
790
+ """Background task to connect to pod without blocking the HTTP response."""
791
+ global pod_manager
792
+ import sys
793
+
794
+ try:
795
+ print(f"[CONNECT BACKGROUND] Starting connection to pod {pod_id}", file=sys.stderr, flush=True)
796
+
797
+ # Disconnect from any existing pod first
798
+ # TO-DO have to fix this for multi-gpu
799
+ if pod_manager and pod_manager.pod is not None:
800
+ await pod_manager.disconnect()
801
+
802
+ result = await pod_manager.connect_to_pod(pod_id)
803
+
804
+ if result.get("status") == "ok":
805
+ pod_manager.attach_executor(executor)
806
+ addresses = pod_manager.get_executor_addresses()
807
+ reconnect_zmq_sockets(
808
+ executor,
809
+ cmd_addr=addresses["cmd_addr"],
810
+ pub_addr=addresses["pub_addr"]
811
+ )
812
+ print(f"[CONNECT BACKGROUND] Successfully connected to pod {pod_id}", file=sys.stderr, flush=True)
813
+ else:
814
+ # Connection failed - clean up
815
+ print(f"[CONNECT BACKGROUND] Failed to connect: {result}", file=sys.stderr, flush=True)
816
+ if pod_manager and pod_manager.pod:
817
+ await pod_manager.disconnect()
818
+
819
+ except Exception as e:
820
+ print(f"[CONNECT BACKGROUND] Error: {e}", file=sys.stderr, flush=True)
821
+ # Clean up on error
822
+ if pod_manager and pod_manager.pod:
823
+ try:
824
+ await pod_manager.disconnect()
825
+ except Exception as cleanup_err:
826
+ print(f"[CONNECT BACKGROUND] Cleanup error: {cleanup_err}", file=sys.stderr, flush=True)
827
+
828
+
572
829
  @app.post("/api/gpu/pods/{pod_id}/connect")
573
830
  async def connect_to_pod(pod_id: str):
574
831
  """Connect to a GPU pod and establish SSH tunnel for remote execution."""
@@ -576,33 +833,19 @@ async def connect_to_pod(pod_id: str):
576
833
 
577
834
  if not prime_intellect:
578
835
  raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
836
+
579
837
  if pod_manager is None:
580
838
  pod_manager = PodKernelManager(pi_service=prime_intellect)
581
839
 
582
- # Disconnect from any existing pod first, may need to fix later for multi-pod
583
- if pod_manager.pod is not None:
584
- await pod_manager.disconnect()
840
+ # Start the connection in the background
841
+ asyncio.create_task(_connect_to_pod_background(pod_id))
585
842
 
586
- # Connect to the new pod
587
- result = await pod_manager.connect_to_pod(pod_id)
588
-
589
- if result.get("status") == "ok":
590
- pod_manager.attach_executor(executor)
591
- addresses = pod_manager.get_executor_addresses()
592
- executor.cmd_addr = addresses["cmd_addr"]
593
- executor.pub_addr = addresses["pub_addr"]
594
-
595
- # Reconnect executor sockets to tunneled ports
596
- executor.req.close(0) # type: ignore[reportAttributeAccessIssue]
597
- executor.req = executor.ctx.socket(zmq.REQ) # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
598
- executor.req.connect(executor.cmd_addr) # type: ignore[reportAttributeAccessIssue]
599
-
600
- executor.sub.close(0) # type: ignore[reportAttributeAccessIssue]
601
- executor.sub = executor.ctx.socket(zmq.SUB) # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
602
- executor.sub.connect(executor.pub_addr) # type: ignore[reportAttributeAccessIssue]
603
- executor.sub.setsockopt_string(zmq.SUBSCRIBE, '') # type: ignore[reportAttributeAccessIssue]
604
-
605
- return result
843
+ # Return immediately with a "connecting" status
844
+ return {
845
+ "status": "connecting",
846
+ "message": "Connection initiated. Check status endpoint for updates.",
847
+ "pod_id": pod_id
848
+ }
606
849
 
607
850
 
608
851
  @app.post("/api/gpu/pods/disconnect")
@@ -616,26 +859,263 @@ async def disconnect_from_pod():
616
859
  result = await pod_manager.disconnect()
617
860
 
618
861
  # Reset executor to local addresses
619
- executor.cmd_addr = os.getenv('MC_ZMQ_CMD_ADDR', 'tcp://127.0.0.1:5555')
620
- executor.pub_addr = os.getenv('MC_ZMQ_PUB_ADDR', 'tcp://127.0.0.1:5556')
621
-
622
- # Reconnect to local worker
623
- executor.req.close(0) # type: ignore[reportAttributeAccessIssue]
624
- executor.req = executor.ctx.socket(zmq.REQ) # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
625
- executor.req.connect(executor.cmd_addr) # type: ignore[reportAttributeAccessIssue]
626
-
627
- executor.sub.close(0) # type: ignore[reportAttributeAccessIssue]
628
- executor.sub = executor.ctx.socket(zmq.SUB) # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
629
- executor.sub.connect(executor.pub_addr) # type: ignore[reportAttributeAccessIssue]
630
- executor.sub.setsockopt_string(zmq.SUBSCRIBE, '') # type: ignore[reportAttributeAccessIssue]
862
+ reset_to_local_zmq(executor)
631
863
 
632
864
  return result
633
865
 
634
866
 
635
867
  @app.get("/api/gpu/pods/connection/status")
636
868
  async def get_pod_connection_status():
637
- """Get status of current pod connection."""
638
- if pod_manager is None:
639
- return {"connected": False, "pod": None}
869
+ """
870
+ Get status of current pod connection.
871
+
872
+ Returns connection status AND any running pods from Prime Intellect API.
873
+ This ensures we don't lose track of running pods after backend restart.
874
+ """
875
+ # Check local connection state first
876
+ local_status = None
877
+ if pod_manager is not None:
878
+ local_status = await pod_manager.get_status()
879
+ if local_status.get("connected"):
880
+ return local_status
881
+
882
+ # If not locally connected, check Prime Intellect API for any running pods
883
+ if prime_intellect:
884
+ try:
885
+ pods_response = await prime_intellect.get_pods(status=None, limit=100, offset=0)
886
+ pods = pods_response.get("data", [])
887
+
888
+ # Find any ACTIVE pods with SSH connection info
889
+ running_pods = [
890
+ pod for pod in pods
891
+ if pod.get("status") == "ACTIVE" and pod.get("sshConnection")
892
+ ]
893
+
894
+ if running_pods:
895
+ # Return the first running pod as "discovered but not connected"
896
+ first_pod = running_pods[0]
897
+ return {
898
+ "connected": False,
899
+ "discovered_running_pods": running_pods,
900
+ "pod": {
901
+ "id": first_pod.get("id"),
902
+ "name": first_pod.get("name"),
903
+ "status": first_pod.get("status"),
904
+ "gpu_type": first_pod.get("gpuName"),
905
+ "gpu_count": first_pod.get("gpuCount"),
906
+ "price_hr": first_pod.get("priceHr"),
907
+ "ssh_connection": first_pod.get("sshConnection")
908
+ },
909
+ "message": "Found running pod but not connected. Backend may have restarted."
910
+ }
911
+ except Exception as e:
912
+ print(f"[CONNECTION STATUS] Error checking Prime Intellect API: {e}", file=sys.stderr, flush=True)
913
+
914
+ # No connection and no running pods found
915
+ return {"connected": False, "pod": None}
916
+
917
+
918
+ @app.get("/api/gpu/pods/worker-logs")
919
+ async def get_worker_logs():
920
+ """Fetch worker logs from connected pod."""
921
+ import subprocess
922
+
923
+ if not pod_manager or not pod_manager.pod:
924
+ raise HTTPException(status_code=400, detail="Not connected to any pod")
925
+
926
+ ssh_parts = pod_manager.pod.sshConnection.split()
927
+ host_part = next((p for p in ssh_parts if "@" in p), None)
928
+ if not host_part:
929
+ raise HTTPException(status_code=500, detail="Invalid SSH connection")
930
+
931
+ ssh_host = host_part.split("@")[1]
932
+ ssh_port = ssh_parts[ssh_parts.index("-p") + 1] if "-p" in ssh_parts else "22"
933
+
934
+ ssh_key = pod_manager._get_ssh_key()
935
+ cmd = ["ssh", "-p", ssh_port]
936
+ if ssh_key:
937
+ cmd.extend(["-i", ssh_key])
938
+ cmd.extend([
939
+ "-o", "StrictHostKeyChecking=no",
940
+ "-o", "UserKnownHostsFile=/dev/null",
941
+ "-o", "BatchMode=yes",
942
+ f"root@{ssh_host}",
943
+ "cat /tmp/worker.log 2>&1 || echo 'No worker log found'"
944
+ ])
945
+
946
+ try:
947
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
948
+ return {"logs": result.stdout, "stderr": result.stderr, "returncode": result.returncode}
949
+ except Exception as e:
950
+ raise HTTPException(status_code=500, detail=f"Failed to fetch logs: {str(e)}")
951
+
952
+
953
+ # Dataset Management API
954
+ @app.get("/api/datasets/info")
955
+ async def get_dataset_info(name: str, config: str | None = None):
956
+ """
957
+ Get dataset metadata without downloading.
958
+
959
+ Args:
960
+ name: HuggingFace dataset name (e.g., "openai/gsm8k")
961
+ config: Optional dataset configuration
962
+
963
+ Returns:
964
+ Dataset metadata including size, splits, features
965
+ """
966
+ try:
967
+ info = data_manager.get_dataset_info(name, config)
968
+ return {
969
+ "name": info.name,
970
+ "size_gb": info.size_gb,
971
+ "splits": info.splits,
972
+ "features": info.features
973
+ }
974
+ except Exception as exc:
975
+ raise HTTPException(status_code=500, detail=f"Failed to get dataset info: {exc}")
976
+
977
+
978
+ @app.post("/api/datasets/check")
979
+ async def check_dataset_load(request: Request):
980
+ """
981
+ Check if dataset can be loaded and get recommendations.
982
+
983
+ Body:
984
+ name: Dataset name
985
+ config: Optional configuration
986
+ split: Optional split
987
+ auto_stream_threshold_gb: Threshold for auto-streaming (default: 10)
988
+
989
+ Returns:
990
+ Dict with action, recommendation, import_code, alternatives
991
+ """
992
+ try:
993
+ body = await request.json()
994
+ name = body.get("name")
995
+ config = body.get("config")
996
+ split = body.get("split")
997
+ threshold = body.get("auto_stream_threshold_gb", 10.0)
998
+
999
+ if not name:
1000
+ raise HTTPException(status_code=400, detail="Dataset name is required")
1001
+
1002
+ result = await data_manager.load_smart(
1003
+ dataset_name=name,
1004
+ config=config,
1005
+ split=split,
1006
+ auto_stream_threshold_gb=threshold
1007
+ )
1008
+ return result
1009
+ except HTTPException:
1010
+ raise
1011
+ except Exception as exc:
1012
+ raise HTTPException(status_code=500, detail=f"Failed to check dataset: {exc}")
640
1013
 
641
- return await pod_manager.get_status()
1014
+
1015
+ @app.get("/api/datasets/cache")
1016
+ async def list_cached_datasets():
1017
+ """
1018
+ List all cached datasets.
1019
+
1020
+ Returns:
1021
+ List of cached datasets with name, size, path
1022
+ """
1023
+ try:
1024
+ datasets = data_manager.list_cached_datasets()
1025
+ cache_size = data_manager.get_cache_size()
1026
+ return {
1027
+ "datasets": datasets,
1028
+ "total_cache_size_gb": cache_size,
1029
+ "max_cache_size_gb": data_manager.max_cache_size_gb
1030
+ }
1031
+ except Exception as exc:
1032
+ raise HTTPException(status_code=500, detail=f"Failed to list cache: {exc}")
1033
+
1034
+
1035
+ @app.delete("/api/datasets/cache/{dataset_id}")
1036
+ async def clear_dataset_cache(dataset_id: str):
1037
+ """
1038
+ Clear specific dataset from cache.
1039
+
1040
+ Args:
1041
+ dataset_id: Dataset identifier (or "all" to clear everything)
1042
+ """
1043
+ try:
1044
+ if dataset_id == "all":
1045
+ result = data_manager.clear_cache(None)
1046
+ else:
1047
+ result = data_manager.clear_cache(dataset_id)
1048
+ return result
1049
+ except Exception as exc:
1050
+ raise HTTPException(status_code=500, detail=f"Failed to clear cache: {exc}")
1051
+
1052
+
1053
+ @app.post("/api/datasets/disk/create")
1054
+ async def create_dataset_disk(request: Request):
1055
+ """
1056
+ Create disk for large dataset via Prime Intellect.
1057
+
1058
+ Body:
1059
+ pod_id: Pod to attach disk to
1060
+ disk_name: Human-readable name for the disk
1061
+ size_gb: Disk size in GB
1062
+ provider_type: Cloud provider (default: "runpod")
1063
+
1064
+ Returns:
1065
+ Dict with disk_id, mount_path, instructions
1066
+ """
1067
+ if not prime_intellect:
1068
+ raise HTTPException(status_code=503, detail="Prime Intellect API not configured")
1069
+
1070
+ try:
1071
+ body = await request.json()
1072
+ pod_id = body.get("pod_id")
1073
+ disk_name = body.get("disk_name")
1074
+ size_gb = body.get("size_gb")
1075
+ provider_type = body.get("provider_type", "runpod")
1076
+
1077
+ if not pod_id or not disk_name or not size_gb:
1078
+ raise HTTPException(status_code=400, detail="pod_id, disk_name, and size_gb are required")
1079
+
1080
+ result = await data_manager.create_and_attach_disk(
1081
+ pod_id=pod_id,
1082
+ disk_name=disk_name,
1083
+ size_gb=int(size_gb),
1084
+ provider_type=provider_type
1085
+ )
1086
+ return result
1087
+ except HTTPException:
1088
+ raise
1089
+ except Exception as exc:
1090
+ raise HTTPException(status_code=500, detail=f"Failed to create disk: {exc}")
1091
+
1092
+
1093
+ @app.get("/api/datasets/subset")
1094
+ async def get_subset_code(
1095
+ name: str,
1096
+ num_samples: int = 1000,
1097
+ split: str = "train",
1098
+ config: str | None = None
1099
+ ):
1100
+ """
1101
+ Get code to load a dataset subset for testing.
1102
+
1103
+ Args:
1104
+ name: Dataset name
1105
+ num_samples: Number of samples to load (default: 1000)
1106
+ split: Which split to use (default: "train")
1107
+ config: Optional configuration
1108
+
1109
+ Returns:
1110
+ Dict with import_code and recommendation
1111
+ """
1112
+ try:
1113
+ result = data_manager.load_subset(
1114
+ dataset_name=name,
1115
+ num_samples=num_samples,
1116
+ split=split,
1117
+ config=config
1118
+ )
1119
+ return result
1120
+ except Exception as exc:
1121
+ raise HTTPException(status_code=500, detail=f"Failed to generate subset code: {exc}")