more-compute 0.4.3-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontend/app/globals.css +734 -27
- frontend/app/layout.tsx +13 -3
- frontend/components/Notebook.tsx +2 -14
- frontend/components/cell/MonacoCell.tsx +99 -5
- frontend/components/layout/Sidebar.tsx +39 -4
- frontend/components/panels/ClaudePanel.tsx +461 -0
- frontend/components/popups/ComputePopup.tsx +739 -418
- frontend/components/popups/FilterPopup.tsx +305 -189
- frontend/components/popups/MetricsPopup.tsx +20 -1
- frontend/components/popups/ProviderConfigModal.tsx +322 -0
- frontend/components/popups/ProviderDropdown.tsx +398 -0
- frontend/components/popups/SettingsPopup.tsx +1 -1
- frontend/contexts/ClaudeContext.tsx +392 -0
- frontend/contexts/PodWebSocketContext.tsx +16 -21
- frontend/hooks/useInlineDiff.ts +269 -0
- frontend/lib/api.ts +323 -12
- frontend/lib/settings.ts +5 -0
- frontend/lib/websocket-native.ts +4 -8
- frontend/lib/websocket.ts +1 -2
- frontend/package-lock.json +733 -36
- frontend/package.json +2 -0
- frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
- frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
- frontend/public/assets/icons/providers/runpod.svg +9 -0
- frontend/public/assets/icons/providers/vastai.svg +1 -0
- frontend/settings.md +54 -0
- frontend/tsconfig.tsbuildinfo +1 -0
- frontend/types/claude.ts +194 -0
- kernel_run.py +13 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
- morecompute/__init__.py +1 -1
- morecompute/__version__.py +1 -1
- morecompute/execution/executor.py +24 -67
- morecompute/execution/worker.py +6 -72
- morecompute/models/api_models.py +62 -0
- morecompute/notebook.py +11 -0
- morecompute/server.py +641 -133
- morecompute/services/claude_service.py +392 -0
- morecompute/services/pod_manager.py +168 -67
- morecompute/services/pod_monitor.py +67 -39
- morecompute/services/prime_intellect.py +0 -4
- morecompute/services/providers/__init__.py +92 -0
- morecompute/services/providers/base_provider.py +336 -0
- morecompute/services/providers/lambda_labs_provider.py +394 -0
- morecompute/services/providers/provider_factory.py +194 -0
- morecompute/services/providers/runpod_provider.py +504 -0
- morecompute/services/providers/vastai_provider.py +407 -0
- morecompute/utils/cell_magics.py +0 -3
- morecompute/utils/config_util.py +93 -3
- morecompute/utils/special_commands.py +5 -32
- morecompute/utils/version_check.py +117 -0
- frontend/styling_README.md +0 -23
- {more_compute-0.4.3.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
morecompute/server.py CHANGED
@@ -18,19 +18,35 @@ from .utils.system_environment_util import DeviceMetrics
 from .utils.error_utils import ErrorUtils
 from .utils.cache_util import make_cache_key
 from .utils.notebook_util import coerce_cell_source
-from .utils.config_util import load_api_key, save_api_key
+from .utils.config_util import load_api_key, save_api_key, get_active_provider as get_active_provider_name, set_active_provider as set_active_provider_name
 from .utils.zmq_util import reconnect_zmq_sockets, reset_to_local_zmq
-from .services.prime_intellect import PrimeIntellectService
 from .services.pod_manager import PodKernelManager
 from .services.data_manager import DataManager
 from .services.pod_monitor import PodMonitor
 from .services.lsp_service import LSPService
+from .services.claude_service import ClaudeService, ClaudeContext as ClaudeCtx, ProposedEdit
+from .services.providers import (
+    list_providers as list_all_providers,
+    get_provider,
+    configure_provider,
+    get_active_provider,
+    set_active_provider,
+    refresh_provider,
+    BaseGPUProvider,
+)
 from .models.api_models import (
     ApiKeyRequest,
     ApiKeyResponse,
     ConfigStatusResponse,
     CreatePodRequest,
     PodResponse,
+    ProviderInfo,
+    ProviderListResponse,
+    ProviderConfigRequest,
+    SetActiveProviderRequest,
+    GpuAvailabilityResponse,
+    PodListResponse,
+    CreatePodWithProviderRequest,
 )

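The import block above replaces the single hard-wired `PrimeIntellectService` with a provider registry. A minimal sketch of how the registry calls compose, using only the names imported in this hunk; the `"runpod"` provider name and the key value are placeholders, and running this outside the server process is purely illustrative:

```python
# Sketch of the registry API, assuming only the names imported above.
from morecompute.services.providers import (
    configure_provider,
    get_provider,
    list_providers,
)

# Persist an API key for one backend and make it the active provider
# (mirrors what configure_gpu_provider does later in this diff).
configure_provider("runpod", "rp_...", make_active=True)

# Enumerate every registered backend with its configuration status.
for info in list_providers():
    print(info.name, info.configured, info.is_active)

# Fetch the concrete provider object; its async methods
# (get_gpu_availability, get_pods, create_pod, ...) back the new routes.
provider = get_provider("runpod")
```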
@@ -49,7 +65,7 @@ def resolve_path(requested_path: str) -> Path:
     return target


-app = FastAPI()
+app = FastAPI(redirect_slashes=False)
 gpu_cache = TTLCache(maxsize=50, ttl = 60)
 pod_cache = TTLCache(maxsize = 100, ttl = 300)
 packages_cache = TTLCache(maxsize=1, ttl=300)  # 5 minutes cache for packages

@@ -67,21 +83,23 @@ else:
 error_utils = ErrorUtils()
 executor = NextZmqExecutor(error_utils=error_utils)
 metrics = DeviceMetrics()
-prime_api_key = load_api_key("PRIME_INTELLECT_API_KEY")
-prime_intellect = PrimeIntellectService(api_key=prime_api_key) if prime_api_key else None
 pod_manager: PodKernelManager | None = None
-
+pod_connection_error: str | None = None  # Store connection errors for status endpoint
+data_manager = DataManager()
 pod_monitor: PodMonitor | None = None
-if prime_intellect:
-    pod_monitor = PodMonitor(
-        prime_intellect=prime_intellect,
-        pod_cache=pod_cache,
-        update_callback=lambda msg: manager.broadcast_pod_update(msg)
-    )

 # LSP service for Python autocomplete
 lsp_service: LSPService | None = None

+# Claude AI service
+claude_api_key = load_api_key("CLAUDE_API_KEY")
+claude_service: ClaudeService | None = None
+if claude_api_key:
+    try:
+        claude_service = ClaudeService(api_key=claude_api_key)
+    except ImportError:
+        pass  # anthropic package not installed
+

 @app.on_event("startup")
 async def startup_event():

@@ -90,9 +108,7 @@ async def startup_event():
     try:
         lsp_service = LSPService(workspace_root=BASE_DIR)
         await lsp_service.start()
-
-    except Exception as e:
-        print(f"[LSP] Failed to start language server: {e}", file=sys.stderr, flush=True)
+    except Exception:
         lsp_service = None

@@ -104,12 +120,9 @@ async def shutdown_event():
     # Shutdown executor and worker process
     if executor and executor.worker_proc:
        try:
-            print("[EXECUTOR] Shutting down worker process...", file=sys.stderr, flush=True)
            executor.worker_proc.terminate()
            executor.worker_proc.wait(timeout=2)
-
-        except Exception as e:
-            print(f"[EXECUTOR] Error during worker shutdown, forcing kill: {e}", file=sys.stderr, flush=True)
+        except Exception:
            try:
                executor.worker_proc.kill()
            except Exception:

@@ -119,9 +132,8 @@ async def shutdown_event():
     if lsp_service:
         try:
             await lsp_service.shutdown()
-
-        except Exception as e:
-            print(f"[LSP] Error during shutdown: {e}", file=sys.stderr, flush=True)
+        except Exception:
+            pass


 @app.get("/api/packages")

@@ -159,12 +171,9 @@ async def list_installed_packages(force_refresh: bool = False):
            result = {"packages": packages}
            packages_cache[cache_key] = result
            return result
-
-
-
-        except Exception as e:
-            print(f"[API/PACKAGES] Failed to fetch remote packages: {e}", file=sys.stderr, flush=True)
-            # Fall through to local packages
+        # Fall through to local packages on error
+        except Exception:
+            pass  # Fall through to local packages

     # Local packages (fallback or when not connected)
     packages = []

@@ -238,12 +247,9 @@ print(json.dumps({
        if returncode == 0 and stdout.strip():
            import json
            return json.loads(stdout)
-
-
-
-        except Exception as e:
-            print(f"[API/METRICS] Failed to fetch remote metrics: {e}", file=sys.stderr, flush=True)
-            # Fall through to local metrics
+        # Fall through to local metrics on error
+        except Exception:
+            pass  # Fall through to local metrics

     # Local metrics (fallback or when not connected)
     return metrics.get_all_devices()

@@ -463,15 +469,11 @@ class WebSocketManager:
            tasks.discard(task)
            # Check for exceptions in completed tasks
            try:
-
-                if exc:
-                    print(f"[SERVER] Task raised exception: {exc}", file=sys.stderr, flush=True)
-                    import traceback
-                    traceback.print_exception(type(exc), exc, exc.__traceback__)
+                task.exception()
            except asyncio.CancelledError:
                pass
-            except Exception
-
+            except Exception:
+                pass

        while True:
            try:
@@ -503,6 +505,9 @@ class WebSocketManager:
             "reset_kernel": self._handle_reset_kernel,
             "load_notebook": self._handle_load_notebook,
             "save_notebook": self._handle_save_notebook,
+            "claude_message": self._handle_claude_message,
+            "claude_apply_edit": self._handle_claude_apply_edit,
+            "claude_reject_edit": self._handle_claude_reject_edit,
         }

         handler = handlers.get(message_type)
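Three new message types are routed to the Claude handlers defined later in this diff. From the fields those handlers read via `data.get(...)`, the payloads look roughly like this; values are made up, and the outer envelope that carries the message type is handled by the dispatch loop, which this hunk does not show:

```python
# Illustrative payloads for the new message types, inferred from the
# handlers later in this diff. Only the field names come from the
# data.get(...) calls; every value here is invented.
claude_message_data = {
    "message": "Vectorize the loop in cell 2",  # required, must be non-empty
    "history": [],                              # prior chat turns (optional)
    "model": "sonnet",                          # server default when omitted
}

claude_apply_edit_data = {
    "cellIndex": 2,                     # validated against len(notebook.cells)
    "newCode": "import numpy as np\n",  # replacement cell source
    "editId": "edit-3f2a",              # id from a proposedEdits entry
}

claude_reject_edit_data = {"editId": "edit-3f2a"}
```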
@@ -529,8 +534,6 @@ class WebSocketManager:
            result = await self.executor.execute_cell(cell_index, source, websocket)
        except Exception as e:
            error_msg = str(e)
-            print(f"[SERVER ERROR] execute_cell failed: {error_msg}", file=sys.stderr, flush=True)
-
            # Send error to frontend
            result = {
                'status': 'error',

@@ -608,8 +611,8 @@ class WebSocketManager:
        # Save the notebook after moving cells
        try:
            self.notebook.save_to_file()
-        except Exception
-
+        except Exception:
+            pass  # Silently continue if save fails
        await self.broadcast_notebook_update()

    async def _handle_load_notebook(self, websocket: WebSocket, data: dict):

@@ -634,22 +637,16 @@ class WebSocketManager:
        cell_index = None

        import sys
-        print(f"[SERVER] Interrupt request received for cell {cell_index}", file=sys.stderr, flush=True)
-
        # Perform the interrupt (this may take up to 1 second)
        # The execution handler will send the appropriate error and completion messages
        await self.executor.interrupt_kernel(cell_index=cell_index)

-        print(f"[SERVER] Interrupt completed, execution handler will send completion messages", file=sys.stderr, flush=True)
-
        # Note: We don't send completion messages here anymore because:
        # 1. For shell commands: AsyncSpecialCommandHandler._execute_shell_command sends them
        # 2. For Python code: The worker sends them
        # Sending duplicate messages causes the frontend to get confused

    async def _handle_reset_kernel(self, websocket: WebSocket, data: dict):
-        import sys
-        print(f"[SERVER] Resetting kernel", file=sys.stderr, flush=True)
        self.executor.reset_kernel()
        self.notebook.clear_all_outputs()

@@ -663,6 +660,120 @@ class WebSocketManager:
        })
        await self.broadcast_notebook_update()

+    async def _handle_claude_message(self, websocket: WebSocket, data: dict):
+        """Handle a message to Claude and stream the response."""
+        import uuid
+
+        if not claude_service:
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": "Claude API key not configured. Please configure it in the Claude panel."}
+            })
+            return
+
+        message = data.get("message", "")
+        history = data.get("history", [])
+        model = data.get("model", "sonnet")  # Default to sonnet
+
+        if not message.strip():
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": "Message cannot be empty"}
+            })
+            return
+
+        # Build context from notebook state
+        cells = self.notebook.cells
+        context = ClaudeCtx(
+            cells=cells,
+            gpu_info=None,  # Could fetch metrics here if needed
+            metrics=None,
+            packages=None
+        )
+
+        # Generate message ID
+        message_id = str(uuid.uuid4())
+
+        # Send stream start
+        await websocket.send_json({
+            "type": "claude_stream_start",
+            "data": {"messageId": message_id}
+        })
+
+        full_response = []
+        try:
+            async for chunk in claude_service.stream_response(message, context, history, model=model):
+                full_response.append(chunk)
+                await websocket.send_json({
+                    "type": "claude_stream_chunk",
+                    "data": {"messageId": message_id, "chunk": chunk}
+                })
+
+            # Parse edit blocks from full response
+            full_text = "".join(full_response)
+            proposed_edits = ClaudeService.parse_edit_blocks(full_text, cells)
+
+            # Convert edits to serializable format
+            edits_data = [
+                {
+                    "id": str(uuid.uuid4()),
+                    "cellIndex": edit.cell_index,
+                    "originalCode": edit.original_code,
+                    "newCode": edit.new_code,
+                    "explanation": edit.explanation,
+                    "status": "pending"
+                }
+                for edit in proposed_edits
+            ]
+
+            await websocket.send_json({
+                "type": "claude_stream_end",
+                "data": {
+                    "messageId": message_id,
+                    "fullResponse": full_text,
+                    "proposedEdits": edits_data
+                }
+            })
+
+        except Exception as e:
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": f"Error communicating with Claude: {str(e)}"}
+            })
+
+    async def _handle_claude_apply_edit(self, websocket: WebSocket, data: dict):
+        """Apply a proposed edit to a cell."""
+        cell_index = data.get("cellIndex")
+        new_code = data.get("newCode", "")
+        edit_id = data.get("editId", "")
+
+        if cell_index is None or not (0 <= cell_index < len(self.notebook.cells)):
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": f"Invalid cell index: {cell_index}"}
+            })
+            return
+
+        # Update the cell source
+        self.notebook.update_cell(cell_index, new_code)
+
+        # Broadcast the notebook update
+        await self.broadcast_notebook_update()
+
+        await websocket.send_json({
+            "type": "claude_edit_applied",
+            "data": {"editId": edit_id, "cellIndex": cell_index}
+        })
+
+    async def _handle_claude_reject_edit(self, websocket: WebSocket, data: dict):
+        """Reject a proposed edit (just acknowledge, no action needed on notebook)."""
+        edit_id = data.get("editId", "")
+
+        await websocket.send_json({
+            "type": "claude_edit_rejected",
+            "data": {"editId": edit_id}
+        })
+
    async def _send_error(self, websocket: WebSocket, error_message: str):
        await websocket.send_json({"type": "error", "data": {"error": error_message}})
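Taken together, the handlers above define a streaming protocol: `claude_stream_start`, zero or more `claude_stream_chunk` events, then `claude_stream_end` carrying `proposedEdits`, or `claude_error`. A client-side sketch, assuming the notebook WebSocket is served at `ws://localhost:8000/ws` and that messages are flat JSON envelopes with a `type` field plus the handler fields; neither the route nor the exact envelope appears in this hunk:

```python
# Hedged client sketch of the claude_message stream; URL and envelope
# shape are assumptions, the event names come from the handlers above.
import asyncio
import json

import websockets  # pip install websockets


async def ask_claude(prompt: str) -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({
            "type": "claude_message",
            "message": prompt,
            "history": [],
            "model": "sonnet",
        }))
        async for raw in ws:
            event = json.loads(raw)
            kind, data = event.get("type"), event.get("data", {})
            if kind == "claude_stream_chunk":
                print(data["chunk"], end="", flush=True)
            elif kind == "claude_stream_end":
                # proposedEdits carries the parse_edit_blocks() output
                print("\nproposed edits for cells:",
                      [e["cellIndex"] for e in data["proposedEdits"]])
                break
            elif kind == "claude_error":
                raise RuntimeError(data["error"])


asyncio.run(ask_claude("Explain what cell 0 does"))
```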
@@ -676,37 +787,366 @@ async def websocket_endpoint(websocket: WebSocket):
     await manager.handle_message_loop(websocket)


-#
-
-
-    """Check if Prime Intellect API is configured."""
-    return ConfigStatusResponse(configured=prime_intellect is not None)
+# ============================================================================
+# Multi-Provider GPU API
+# ============================================================================

+@app.get("/api/gpu/providers")
+async def list_gpu_providers():
+    """List all available GPU providers with their configuration status."""
+    providers = list_all_providers()
+    active = get_active_provider_name()

-
-
-
-
+    return {
+        "providers": [
+            {
+                "name": p.name,
+                "display_name": p.display_name,
+                "api_key_env_name": p.api_key_env_name,
+                "supports_ssh": p.supports_ssh,
+                "dashboard_url": p.dashboard_url,
+                "configured": p.configured,
+                "is_active": p.is_active
+            }
+            for p in providers
+        ],
+        "active_provider": active
+    }

-    if not request.api_key.strip():
-        raise HTTPException(status_code=400, detail="API key is required")

-
-
-
-
+@app.post("/api/gpu/providers/{provider_name}/config")
+async def configure_gpu_provider(provider_name: str, request: ProviderConfigRequest):
+    """Configure a GPU provider with API key."""
+    global pod_monitor
+
+    # Handle Modal's special case (requires two tokens)
+    if provider_name == "modal" and request.token_secret:
+        # Save token secret separately
+        save_api_key("MODAL_TOKEN_SECRET", request.token_secret)
+
+    success = configure_provider(provider_name, request.api_key, make_active=request.make_active)
+
+    if not success:
+        raise HTTPException(status_code=400, detail=f"Provider '{provider_name}' not found")
+
+    # If this is being set as active, update the pod monitor
+    if request.make_active:
+        provider = get_provider(provider_name)
+        if provider:
             pod_monitor = PodMonitor(
-
+                provider_service=provider,
                 pod_cache=pod_cache,
                 update_callback=lambda msg: manager.broadcast_pod_update(msg)
             )

-
+    return {
+        "configured": True,
+        "provider": provider_name,
+        "is_active": request.make_active
+    }

-
-
-
-
+
+@app.post("/api/gpu/providers/active")
+async def set_active_gpu_provider(request: SetActiveProviderRequest):
+    """Set the active GPU provider."""
+    success = set_active_provider(request.provider)
+    if not success:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Cannot activate provider '{request.provider}'. Make sure it is configured with a valid API key."
+        )
+
+    # Update pod monitor with new provider
+    global pod_monitor
+    provider = get_provider(request.provider)
+    if provider:
+        pod_monitor = PodMonitor(
+            provider_service=provider,
+            pod_cache=pod_cache,
+            update_callback=lambda msg: manager.broadcast_pod_update(msg)
+        )
+
+    return {
+        "active_provider": request.provider,
+        "success": True
+    }
+
+
+@app.get("/api/gpu/providers/{provider_name}/availability")
+async def get_provider_gpu_availability(
+    provider_name: str,
+    regions: list[str] | None = None,
+    gpu_count: int | None = None,
+    gpu_type: str | None = None,
+    # RunPod specific filters
+    secure_cloud: bool | None = None,
+    community_cloud: bool | None = None,
+    # Vast.ai specific filters
+    verified: bool | None = None,
+    min_reliability: float | None = None,
+    min_gpu_ram: float | None = None
+):
+    """Get available GPU resources from a specific provider.
+
+    Args:
+        provider_name: Provider identifier (runpod, lambda_labs, vastai)
+        regions: Filter by region
+        gpu_count: Filter by GPU count
+        gpu_type: Filter by GPU type (partial match)
+        secure_cloud: RunPod - only show Secure Cloud GPUs
+        community_cloud: RunPod - only show Community Cloud GPUs
+        verified: Vast.ai - only show verified hosts
+        min_reliability: Vast.ai - minimum reliability score (0.0-1.0)
+        min_gpu_ram: Vast.ai - minimum GPU RAM in GB
+    """
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    cache_key = make_cache_key(
+        f"gpu_avail_{provider_name}",
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
+
+    if cache_key in gpu_cache:
+        return gpu_cache[cache_key]
+
+    result = await provider.get_gpu_availability(
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
+    gpu_cache[cache_key] = result
+    return result
+
+
+@app.get("/api/gpu/providers/{provider_name}/pods")
+async def get_provider_pods(
+    provider_name: str,
+    status: str | None = None,
+    limit: int = 100,
+    offset: int = 0
+):
+    """Get list of pods from a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    cache_key = make_cache_key(
+        f"gpu_pods_{provider_name}",
+        status=status,
+        limit=limit,
+        offset=offset
+    )
+
+    if cache_key in pod_cache:
+        return pod_cache[cache_key]
+
+    result = await provider.get_pods(status=status, limit=limit, offset=offset)
+    pod_cache[cache_key] = result
+    return result
+
+
+@app.post("/api/gpu/providers/{provider_name}/pods")
+async def create_provider_pod(provider_name: str, pod_request: CreatePodRequest):
+    """Create a new GPU pod with a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    try:
+        result = await provider.create_pod(pod_request)
+
+        # Clear cache and start monitoring
+        pod_cache.clear()
+
+        # Create/update pod monitor for this provider and start monitoring
+        global pod_monitor
+        pod_monitor = PodMonitor(
+            provider_service=provider,
+            pod_cache=pod_cache,
+            update_callback=lambda msg: manager.broadcast_pod_update(msg)
+        )
+        await pod_monitor.start_monitoring(result.id)
+
+        return result
+
+    except HTTPException as e:
+        if e.status_code == 402:
+            raise HTTPException(
+                status_code=402,
+                detail=f"Insufficient funds in your {provider.PROVIDER_DISPLAY_NAME} account."
+            )
+        elif e.status_code in (401, 403):
+            raise HTTPException(
+                status_code=e.status_code,
+                detail=f"Authentication failed. Please check your {provider.PROVIDER_DISPLAY_NAME} API key."
+            )
+        else:
+            raise
+
+
+@app.get("/api/gpu/providers/{provider_name}/pods/{pod_id}")
+async def get_provider_pod(provider_name: str, pod_id: str):
+    """Get details of a specific pod from a provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    return await provider.get_pod(pod_id)
+
+
+@app.delete("/api/gpu/providers/{provider_name}/pods/{pod_id}")
+async def delete_provider_pod(provider_name: str, pod_id: str):
+    """Delete a pod from a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    result = await provider.delete_pod(pod_id)
+    pod_cache.clear()
+    return result
+
+
+@app.get("/api/gpu/providers/{provider_name}/ssh-keys")
+async def get_provider_ssh_keys(provider_name: str):
+    """Get list of SSH keys registered with a provider (Lambda Labs only)."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    # Only Lambda Labs supports listing SSH keys via API
+    if provider_name != "lambda_labs":
+        return {
+            "supported": False,
+            "message": f"{provider.PROVIDER_DISPLAY_NAME} does not support listing SSH keys via API. Please check your provider's dashboard.",
+            "dashboard_url": provider.DASHBOARD_URL
+        }
+
+    try:
+        # Get detailed SSH key info
+        detailed_keys = await provider.get_ssh_keys_detailed()
+        key_names = await provider._get_ssh_key_ids()
+
+        return {
+            "supported": True,
+            "ssh_keys": key_names,
+            "ssh_keys_detailed": detailed_keys,
+            "selected_key": key_names[0] if key_names else None,
+            "note": "ed25519 keys are preferred. The first key shown will be used when creating new instances."
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to fetch SSH keys: {str(e)}")
+
+
+@app.post("/api/gpu/providers/{provider_name}/pods/{pod_id}/connect")
+async def connect_to_provider_pod(provider_name: str, pod_id: str):
+    """Connect to a GPU pod from a specific provider."""
+    global pod_manager
+
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    if not provider.SUPPORTS_SSH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} does not support SSH connections. Use the provider's SDK for code execution."
+        )
+
+    if pod_manager is None:
+        pod_manager = PodKernelManager(provider_service=provider)
+    else:
+        # Update the provider on the pod manager
+        pod_manager.provider_service = provider
+        pod_manager.provider_type = provider_name
+
+    # Start the connection in the background
+    asyncio.create_task(_connect_to_pod_background(pod_id))
+
+    return {
+        "status": "connecting",
+        "message": "Connection initiated. Check status endpoint for updates.",
+        "pod_id": pod_id,
+        "provider": provider_name
+    }
+
+
+# ============================================================================
+# Legacy GPU API (Prime Intellect - for backwards compatibility)
+# ============================================================================
+
+# GPU connection API (legacy endpoints - use provider system)
+@app.get("/api/gpu/config", response_model=ConfigStatusResponse)
+async def get_gpu_config() -> ConfigStatusResponse:
+    """Check if any GPU provider is configured."""
+    active_provider = get_active_provider()
+    if active_provider and active_provider.is_configured:
+        return ConfigStatusResponse(configured=True)
+    return ConfigStatusResponse(configured=False)
+
+
+@app.post("/api/gpu/config", response_model=ApiKeyResponse)
+async def set_gpu_config(request: ApiKeyRequest) -> ApiKeyResponse:
+    """Legacy endpoint - use /api/gpu/providers/{provider}/config instead."""
+    raise HTTPException(
+        status_code=400,
+        detail="Please use /api/gpu/providers/{provider}/config to configure a specific provider"
+    )


 @app.get("/api/gpu/availability")
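The new provider-scoped routes are plain REST, so they can be exercised directly. A hedged `httpx` sketch of the flow; the base URL, provider name, and key value are placeholders, while the paths and JSON fields come from the handlers above:

```python
# Sketch: driving the new multi-provider endpoints over HTTP with httpx.
import asyncio

import httpx

BASE = "http://localhost:8000"  # assumed dev-server address


async def main() -> None:
    async with httpx.AsyncClient(base_url=BASE) as client:
        # 1. Discover registered providers and the active one.
        providers = (await client.get("/api/gpu/providers")).json()
        print(providers["active_provider"])

        # 2. Store a key for one provider and make it active.
        await client.post(
            "/api/gpu/providers/runpod/config",
            json={"api_key": "rp_...", "make_active": True},
        )

        # 3. Query availability with provider-specific filters.
        avail = await client.get(
            "/api/gpu/providers/runpod/availability",
            params={"gpu_type": "A100", "secure_cloud": True},
        )
        print(avail.status_code)


asyncio.run(main())
```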
@@ -714,36 +1154,56 @@ async def get_gpu_availability(
     regions: list[str] | None = None,
     gpu_count: int | None = None,
     gpu_type: str | None = None,
-
+    # RunPod specific filters
+    secure_cloud: bool | None = None,
+    community_cloud: bool | None = None,
+    # Vast.ai specific filters
+    verified: bool | None = None,
+    min_reliability: float | None = None,
+    min_gpu_ram: float | None = None
 ):
-    """Get available GPU resources from
-
-
+    """Get available GPU resources from active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

     cache_key = make_cache_key(
-        "
-        regions
-        gpu_count
-        gpu_type
-
+        f"gpu_avail_{active_provider.PROVIDER_NAME}",
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
     )

     if cache_key in gpu_cache:
         return gpu_cache[cache_key]

-
-
+    result = await active_provider.get_gpu_availability(
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
     gpu_cache[cache_key] = result
     return result

 @app.get("/api/gpu/pods")
 async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int = 0):
-    """Get list of user's GPU pods."""
-
-
+    """Get list of user's GPU pods from active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

     cache_key = make_cache_key(
-        "
+        f"gpu_pod_{active_provider.PROVIDER_NAME}",
         status=status,
         limit=limit,
         offset=offset
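Both the provider-scoped and the legacy endpoints share one memoization idiom: `make_cache_key` folds the endpoint name and filters into a hashable key, and a `cachetools.TTLCache` (60 s for availability, 300 s for pods, per the module setup earlier in this diff) expires it. A standalone sketch of the idiom; the real `make_cache_key` lives in `morecompute.utils.cache_util`, so the version below is a stand-in:

```python
# Standalone sketch of the TTL-cache memoization idiom used above.
from cachetools import TTLCache

gpu_cache = TTLCache(maxsize=50, ttl=60)  # same sizing as the server module


def make_cache_key(prefix: str, **filters) -> tuple:
    # Stand-in: repr() keeps unhashable values such as regions=["us-west"]
    # usable inside a hashable, order-independent key.
    return (prefix, tuple(sorted((k, repr(v)) for k, v in filters.items())))


async def cached_availability(provider, **filters):
    key = make_cache_key(f"gpu_avail_{provider.PROVIDER_NAME}", **filters)
    if key in gpu_cache:                # hit: skip the provider API call
        return gpu_cache[key]
    result = await provider.get_gpu_availability(**filters)
    gpu_cache[key] = result             # evicted automatically after ttl
    return result
```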
@@ -752,29 +1212,25 @@ async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int = 0):
     if cache_key in pod_cache:
         return pod_cache[cache_key]

-
-    result = await prime_intellect.get_pods(status, limit, offset)
+    result = await active_provider.get_pods(status=status, limit=limit, offset=offset)
     pod_cache[cache_key] = result
     return result


 @app.post("/api/gpu/pods")
 async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
-    """Create a new GPU pod."""
-
-
-
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
-
-    print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)
+    """Create a new GPU pod with active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

     try:
-        result = await
-        print(f"[CREATE POD] Success: {result}", file=sys.stderr, flush=True)
+        result = await active_provider.create_pod(pod_request)

         # Clear cache and start monitoring
         pod_cache.clear()
-
+        if pod_monitor:
+            await pod_monitor.start_monitoring(result.id)

         return result

@@ -782,30 +1238,30 @@ async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
         if e.status_code == 402:
             raise HTTPException(
                 status_code=402,
-                detail="Insufficient funds
+                detail=f"Insufficient funds. Please add credits to your {active_provider.PROVIDER_DISPLAY_NAME} account."
             )
         elif e.status_code in (401, 403):
             raise HTTPException(
                 status_code=e.status_code,
-                detail="Authentication failed. Please check your
+                detail=f"Authentication failed. Please check your {active_provider.PROVIDER_DISPLAY_NAME} API key."
             )
         else:
-            print(f"[CREATE POD] Error: {e}", file=sys.stderr, flush=True)
             raise


 @app.get("/api/gpu/pods/{pod_id}")
 async def get_gpu_pod(pod_id: str) -> PodResponse:
     """Get details of a specific GPU pod."""
-
-
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

-    cache_key = make_cache_key("
+    cache_key = make_cache_key(f"gpu_pod_detail_{active_provider.PROVIDER_NAME}", pod_id=pod_id)

     if cache_key in pod_cache:
         return pod_cache[cache_key]

-    result = await
+    result = await active_provider.get_pod(pod_id)
     pod_cache[cache_key] = result
     return result

@@ -813,22 +1269,23 @@ async def get_gpu_pod(pod_id: str) -> PodResponse:
 @app.delete("/api/gpu/pods/{pod_id}")
 async def delete_gpu_pod(pod_id: str):
     """Delete a GPU pod."""
-
-
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

-    result = await
+    result = await active_provider.delete_pod(pod_id)
     pod_cache.clear()
     return result


 async def _connect_to_pod_background(pod_id: str):
     """Background task to connect to pod without blocking the HTTP response."""
-    global pod_manager
-    import sys
+    global pod_manager, pod_connection_error

-
-
+    # Clear any previous error
+    pod_connection_error = None

+    try:
         # Disconnect from any existing pod first
         # TO-DO have to fix this for multi-gpu
         if pod_manager and pod_manager.pod is not None:

@@ -845,21 +1302,21 @@ async def _connect_to_pod_background(pod_id: str):
                 pub_addr=addresses["pub_addr"],
                 is_remote=True  # Critical: Tell executor this is a remote worker
             )
-            print(f"[CONNECT BACKGROUND] Successfully connected to pod {pod_id}, executor.is_remote=True", file=sys.stderr, flush=True)
         else:
-            # Connection failed -
-
+            # Connection failed - store the error message
+            error_msg = result.get("message", "Connection failed")
+            pod_connection_error = error_msg
             if pod_manager and pod_manager.pod:
                 await pod_manager.disconnect()

     except Exception as e:
-
+        pod_connection_error = str(e)
         # Clean up on error
         if pod_manager and pod_manager.pod:
             try:
                 await pod_manager.disconnect()
-            except Exception
-
+            except Exception:
+                pass

 @app.post("/api/gpu/pods/{pod_id}/connect")

@@ -867,11 +1324,12 @@ async def connect_to_pod(pod_id: str):
     """Connect to a GPU pod and establish SSH tunnel for remote execution."""
     global pod_manager

-
-
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

     if pod_manager is None:
-        pod_manager = PodKernelManager(
+        pod_manager = PodKernelManager(provider_service=active_provider)

     # Start the connection in the background
     asyncio.create_task(_connect_to_pod_background(pod_id))

@@ -905,9 +1363,21 @@ async def get_pod_connection_status():
     """
     Get status of current pod connection.

-    Returns connection status AND any running pods from
+    Returns connection status AND any running pods from the active provider's API.
     This ensures we don't lose track of running pods after backend restart.
     """
+    global pod_connection_error
+
+    # Check if there's a connection error to report
+    if pod_connection_error:
+        error_msg = pod_connection_error
+        pod_connection_error = None  # Clear after reporting
+        return {
+            "connected": False,
+            "pod": None,
+            "error": error_msg
+        }
+
     # Check local connection state first
     local_status = None
     if pod_manager is not None:

@@ -915,10 +1385,11 @@ async def get_pod_connection_status():
         if local_status.get("connected"):
             return local_status

-    # If not locally connected, check
-
+    # If not locally connected, check the active provider's API for any running pods
+    active_provider = get_active_provider()
+    if active_provider and active_provider.is_configured:
         try:
-            pods_response = await
+            pods_response = await active_provider.get_pods(status=None, limit=100, offset=0)
             pods = pods_response.get("data", [])

             # Find any ACTIVE pods with SSH connection info

@@ -942,10 +1413,11 @@ async def get_pod_connection_status():
                     "price_hr": first_pod.get("priceHr"),
                     "ssh_connection": first_pod.get("sshConnection")
                 },
+                "provider": active_provider.PROVIDER_NAME,
                 "message": "Found running pod but not connected. Backend may have restarted."
             }
-        except Exception
-
+        except Exception:
+            pass

     # No connection and no running pods found
     return {"connected": False, "pod": None}

@@ -964,7 +1436,8 @@ async def get_worker_logs():
     if not host_part:
         raise HTTPException(status_code=500, detail="Invalid SSH connection")

-
+    # Extract user and host from user@host
+    ssh_user, ssh_host = host_part.split("@")
     ssh_port = ssh_parts[ssh_parts.index("-p") + 1] if "-p" in ssh_parts else "22"

     ssh_key = pod_manager._get_ssh_key()

@@ -975,7 +1448,7 @@ async def get_worker_logs():
         "-o", "StrictHostKeyChecking=no",
         "-o", "UserKnownHostsFile=/dev/null",
         "-o", "BatchMode=yes",
-        f"
+        f"{ssh_user}@{ssh_host}",
         "cat /tmp/worker.log 2>&1 || echo 'No worker log found'"
     ])
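The two worker-log hunks above split a stored SSH command into user, host, and port before shelling out. A standalone illustration of that parsing on a made-up connection string; how `host_part` is derived upstream is not shown in this diff:

```python
# Illustrative parsing of an SSH command string like the ones stored in a
# pod's sshConnection field; the address here is an example value.
ssh_connection = "ssh -p 2222 root@203.0.113.7"
ssh_parts = ssh_connection.split()

host_part = next(p for p in ssh_parts if "@" in p)   # the user@host token
ssh_user, ssh_host = host_part.split("@")            # -> "root", "203.0.113.7"
ssh_port = ssh_parts[ssh_parts.index("-p") + 1] if "-p" in ssh_parts else "22"

print(ssh_user, ssh_host, ssh_port)  # root 203.0.113.7 2222
```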
@@ -1100,15 +1573,16 @@ async def create_dataset_disk(request: Request):
     Returns:
         Dict with disk_id, mount_path, instructions
     """
-
-
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")

     try:
         body = await request.json()
         pod_id = body.get("pod_id")
         disk_name = body.get("disk_name")
         size_gb = body.get("size_gb")
-        provider_type = body.get("provider_type",
+        provider_type = body.get("provider_type", active_provider.PROVIDER_NAME)

         if not pod_id or not disk_name or not size_gb:
             raise HTTPException(status_code=400, detail="pod_id, disk_name, and size_gb are required")

@@ -1155,3 +1629,37 @@ async def get_subset_code(
         return result
     except Exception as exc:
         raise HTTPException(status_code=500, detail=f"Failed to generate subset code: {exc}")
+
+
+# ============================================================================
+# Claude AI API
+# ============================================================================
+
+@app.get("/api/claude/config")
+async def get_claude_config():
+    """Check if Claude API is configured."""
+    return {"configured": claude_service is not None}
+
+
+@app.post("/api/claude/config")
+async def set_claude_config(request: Request):
+    """Save Claude API key to user config and reinitialize service."""
+    global claude_service
+
+    body = await request.json()
+    api_key = body.get("api_key", "").strip()
+
+    if not api_key:
+        raise HTTPException(status_code=400, detail="API key is required")
+
+    try:
+        # Test the API key by creating a service
+        test_service = ClaudeService(api_key=api_key)
+        # If successful, save and use it
+        save_api_key("CLAUDE_API_KEY", api_key)
+        claude_service = test_service
+        return {"configured": True, "message": "Claude API key saved successfully"}
+    except ImportError as e:
+        raise HTTPException(status_code=500, detail=f"anthropic package not installed: {e}")
+    except Exception:
+        raise HTTPException(status_code=400, detail="Invalid API key. Please check your credentials.")