more-compute 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. frontend/app/globals.css +734 -27
  2. frontend/app/layout.tsx +13 -3
  3. frontend/components/Notebook.tsx +2 -14
  4. frontend/components/cell/MonacoCell.tsx +99 -5
  5. frontend/components/layout/Sidebar.tsx +39 -4
  6. frontend/components/panels/ClaudePanel.tsx +461 -0
  7. frontend/components/popups/ComputePopup.tsx +739 -418
  8. frontend/components/popups/FilterPopup.tsx +305 -189
  9. frontend/components/popups/MetricsPopup.tsx +20 -1
  10. frontend/components/popups/ProviderConfigModal.tsx +322 -0
  11. frontend/components/popups/ProviderDropdown.tsx +398 -0
  12. frontend/components/popups/SettingsPopup.tsx +1 -1
  13. frontend/contexts/ClaudeContext.tsx +392 -0
  14. frontend/contexts/PodWebSocketContext.tsx +16 -21
  15. frontend/hooks/useInlineDiff.ts +269 -0
  16. frontend/lib/api.ts +323 -12
  17. frontend/lib/settings.ts +5 -0
  18. frontend/lib/websocket-native.ts +4 -8
  19. frontend/lib/websocket.ts +1 -2
  20. frontend/package-lock.json +733 -36
  21. frontend/package.json +2 -0
  22. frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
  23. frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
  24. frontend/public/assets/icons/providers/runpod.svg +9 -0
  25. frontend/public/assets/icons/providers/vastai.svg +1 -0
  26. frontend/settings.md +54 -0
  27. frontend/tsconfig.tsbuildinfo +1 -0
  28. frontend/types/claude.ts +194 -0
  29. kernel_run.py +13 -0
  30. {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
  31. {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
  32. {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
  33. morecompute/__init__.py +1 -1
  34. morecompute/__version__.py +1 -1
  35. morecompute/execution/executor.py +24 -67
  36. morecompute/execution/worker.py +6 -72
  37. morecompute/models/api_models.py +62 -0
  38. morecompute/notebook.py +11 -0
  39. morecompute/server.py +641 -133
  40. morecompute/services/claude_service.py +392 -0
  41. morecompute/services/pod_manager.py +168 -67
  42. morecompute/services/pod_monitor.py +67 -39
  43. morecompute/services/prime_intellect.py +0 -4
  44. morecompute/services/providers/__init__.py +92 -0
  45. morecompute/services/providers/base_provider.py +336 -0
  46. morecompute/services/providers/lambda_labs_provider.py +394 -0
  47. morecompute/services/providers/provider_factory.py +194 -0
  48. morecompute/services/providers/runpod_provider.py +504 -0
  49. morecompute/services/providers/vastai_provider.py +407 -0
  50. morecompute/utils/cell_magics.py +0 -3
  51. morecompute/utils/config_util.py +93 -3
  52. morecompute/utils/special_commands.py +5 -32
  53. morecompute/utils/version_check.py +117 -0
  54. frontend/styling_README.md +0 -23
  55. {more_compute-0.4.3.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
  56. {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
  57. {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
morecompute/server.py CHANGED
@@ -18,19 +18,35 @@ from .utils.system_environment_util import DeviceMetrics
 from .utils.error_utils import ErrorUtils
 from .utils.cache_util import make_cache_key
 from .utils.notebook_util import coerce_cell_source
-from .utils.config_util import load_api_key, save_api_key
+from .utils.config_util import load_api_key, save_api_key, get_active_provider as get_active_provider_name, set_active_provider as set_active_provider_name
 from .utils.zmq_util import reconnect_zmq_sockets, reset_to_local_zmq
-from .services.prime_intellect import PrimeIntellectService
 from .services.pod_manager import PodKernelManager
 from .services.data_manager import DataManager
 from .services.pod_monitor import PodMonitor
 from .services.lsp_service import LSPService
+from .services.claude_service import ClaudeService, ClaudeContext as ClaudeCtx, ProposedEdit
+from .services.providers import (
+    list_providers as list_all_providers,
+    get_provider,
+    configure_provider,
+    get_active_provider,
+    set_active_provider,
+    refresh_provider,
+    BaseGPUProvider,
+)
 from .models.api_models import (
     ApiKeyRequest,
     ApiKeyResponse,
     ConfigStatusResponse,
     CreatePodRequest,
     PodResponse,
+    ProviderInfo,
+    ProviderListResponse,
+    ProviderConfigRequest,
+    SetActiveProviderRequest,
+    GpuAvailabilityResponse,
+    PodListResponse,
+    CreatePodWithProviderRequest,
 )
 
 
@@ -49,7 +65,7 @@ def resolve_path(requested_path: str) -> Path:
     return target
 
 
-app = FastAPI()
+app = FastAPI(redirect_slashes=False)
 gpu_cache = TTLCache(maxsize=50, ttl = 60)
 pod_cache = TTLCache(maxsize = 100, ttl = 300)
 packages_cache = TTLCache(maxsize=1, ttl=300) # 5 minutes cache for packages
@@ -67,21 +83,23 @@ else:
 error_utils = ErrorUtils()
 executor = NextZmqExecutor(error_utils=error_utils)
 metrics = DeviceMetrics()
-prime_api_key = load_api_key("PRIME_INTELLECT_API_KEY")
-prime_intellect = PrimeIntellectService(api_key=prime_api_key) if prime_api_key else None
 pod_manager: PodKernelManager | None = None
-data_manager = DataManager(prime_intellect=prime_intellect)
+pod_connection_error: str | None = None  # Store connection errors for status endpoint
+data_manager = DataManager()
 pod_monitor: PodMonitor | None = None
-if prime_intellect:
-    pod_monitor = PodMonitor(
-        prime_intellect=prime_intellect,
-        pod_cache=pod_cache,
-        update_callback=lambda msg: manager.broadcast_pod_update(msg)
-    )
 
 # LSP service for Python autocomplete
 lsp_service: LSPService | None = None
 
+# Claude AI service
+claude_api_key = load_api_key("CLAUDE_API_KEY")
+claude_service: ClaudeService | None = None
+if claude_api_key:
+    try:
+        claude_service = ClaudeService(api_key=claude_api_key)
+    except ImportError:
+        pass  # anthropic package not installed
+
 
 @app.on_event("startup")
 async def startup_event():
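The `try`/`except ImportError` around construction documents the contract this wiring relies on: constructing `ClaudeService` imports the `anthropic` SDK lazily, so the package stays an optional dependency. A hypothetical sketch of that contract — the constructor internals are assumptions, only the raised `ImportError` is visible in this diff:

```python
# Hypothetical sketch: the server only relies on ClaudeService raising
# ImportError at construction when anthropic is missing. Internals assumed.
class ClaudeService:
    def __init__(self, api_key: str):
        import anthropic  # deferred import: raises ImportError if absent
        self._client = anthropic.Anthropic(api_key=api_key)
```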
@@ -90,9 +108,7 @@ async def startup_event():
     try:
         lsp_service = LSPService(workspace_root=BASE_DIR)
         await lsp_service.start()
-        print("[LSP] Pyright language server started successfully", file=sys.stderr, flush=True)
-    except Exception as e:
-        print(f"[LSP] Failed to start language server: {e}", file=sys.stderr, flush=True)
+    except Exception:
         lsp_service = None
 
 
@@ -104,12 +120,9 @@ async def shutdown_event():
     # Shutdown executor and worker process
     if executor and executor.worker_proc:
         try:
-            print("[EXECUTOR] Shutting down worker process...", file=sys.stderr, flush=True)
             executor.worker_proc.terminate()
             executor.worker_proc.wait(timeout=2)
-            print("[EXECUTOR] Worker process shutdown complete", file=sys.stderr, flush=True)
-        except Exception as e:
-            print(f"[EXECUTOR] Error during worker shutdown, forcing kill: {e}", file=sys.stderr, flush=True)
+        except Exception:
             try:
                 executor.worker_proc.kill()
             except Exception:
@@ -119,9 +132,8 @@ async def shutdown_event():
     if lsp_service:
         try:
             await lsp_service.shutdown()
-            print("[LSP] Language server shutdown complete", file=sys.stderr, flush=True)
-        except Exception as e:
-            print(f"[LSP] Error during shutdown: {e}", file=sys.stderr, flush=True)
+        except Exception:
+            pass
 
 
 @app.get("/api/packages")
@@ -159,12 +171,9 @@ async def list_installed_packages(force_refresh: bool = False):
                 result = {"packages": packages}
                 packages_cache[cache_key] = result
                 return result
-            else:
-                print(f"[API/PACKAGES] Remote command failed (code {returncode}): {stderr}", file=sys.stderr, flush=True)
-                # Fall through to local packages
-        except Exception as e:
-            print(f"[API/PACKAGES] Failed to fetch remote packages: {e}", file=sys.stderr, flush=True)
-            # Fall through to local packages
+            # Fall through to local packages on error
+        except Exception:
+            pass  # Fall through to local packages
 
     # Local packages (fallback or when not connected)
     packages = []
@@ -238,12 +247,9 @@ print(json.dumps({
             if returncode == 0 and stdout.strip():
                 import json
                 return json.loads(stdout)
-            else:
-                print(f"[API/METRICS] Remote command failed (code {returncode}): {stderr}", file=sys.stderr, flush=True)
-                # Fall through to local metrics
-        except Exception as e:
-            print(f"[API/METRICS] Failed to fetch remote metrics: {e}", file=sys.stderr, flush=True)
-            # Fall through to local metrics
+            # Fall through to local metrics on error
+        except Exception:
+            pass  # Fall through to local metrics
 
     # Local metrics (fallback or when not connected)
     return metrics.get_all_devices()
@@ -463,15 +469,11 @@ class WebSocketManager:
             tasks.discard(task)
             # Check for exceptions in completed tasks
             try:
-                exc = task.exception()
-                if exc:
-                    print(f"[SERVER] Task raised exception: {exc}", file=sys.stderr, flush=True)
-                    import traceback
-                    traceback.print_exception(type(exc), exc, exc.__traceback__)
+                task.exception()
             except asyncio.CancelledError:
                 pass
-            except Exception as e:
-                print(f"[SERVER] Error in task_done_callback: {e}", file=sys.stderr, flush=True)
+            except Exception:
+                pass
 
         while True:
             try:
@@ -503,6 +505,9 @@ class WebSocketManager:
             "reset_kernel": self._handle_reset_kernel,
             "load_notebook": self._handle_load_notebook,
             "save_notebook": self._handle_save_notebook,
+            "claude_message": self._handle_claude_message,
+            "claude_apply_edit": self._handle_claude_apply_edit,
+            "claude_reject_edit": self._handle_claude_reject_edit,
         }
 
         handler = handlers.get(message_type)
@@ -529,8 +534,6 @@ class WebSocketManager:
             result = await self.executor.execute_cell(cell_index, source, websocket)
         except Exception as e:
             error_msg = str(e)
-            print(f"[SERVER ERROR] execute_cell failed: {error_msg}", file=sys.stderr, flush=True)
-
             # Send error to frontend
             result = {
                 'status': 'error',
@@ -608,8 +611,8 @@ class WebSocketManager:
         # Save the notebook after moving cells
         try:
             self.notebook.save_to_file()
-        except Exception as e:
-            print(f"Warning: Failed to save notebook after moving cell: {e}", file=sys.stderr)
+        except Exception:
+            pass  # Silently continue if save fails
         await self.broadcast_notebook_update()
 
     async def _handle_load_notebook(self, websocket: WebSocket, data: dict):
@@ -634,22 +637,16 @@ class WebSocketManager:
             cell_index = None
 
         import sys
-        print(f"[SERVER] Interrupt request received for cell {cell_index}", file=sys.stderr, flush=True)
-
         # Perform the interrupt (this may take up to 1 second)
         # The execution handler will send the appropriate error and completion messages
         await self.executor.interrupt_kernel(cell_index=cell_index)
 
-        print(f"[SERVER] Interrupt completed, execution handler will send completion messages", file=sys.stderr, flush=True)
-
         # Note: We don't send completion messages here anymore because:
         # 1. For shell commands: AsyncSpecialCommandHandler._execute_shell_command sends them
         # 2. For Python code: The worker sends them
         # Sending duplicate messages causes the frontend to get confused
 
     async def _handle_reset_kernel(self, websocket: WebSocket, data: dict):
-        import sys
-        print(f"[SERVER] Resetting kernel", file=sys.stderr, flush=True)
         self.executor.reset_kernel()
         self.notebook.clear_all_outputs()
 
@@ -663,6 +660,120 @@ class WebSocketManager:
         })
         await self.broadcast_notebook_update()
 
+    async def _handle_claude_message(self, websocket: WebSocket, data: dict):
+        """Handle a message to Claude and stream the response."""
+        import uuid
+
+        if not claude_service:
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": "Claude API key not configured. Please configure it in the Claude panel."}
+            })
+            return
+
+        message = data.get("message", "")
+        history = data.get("history", [])
+        model = data.get("model", "sonnet") # Default to sonnet
+
+        if not message.strip():
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": "Message cannot be empty"}
+            })
+            return
+
+        # Build context from notebook state
+        cells = self.notebook.cells
+        context = ClaudeCtx(
+            cells=cells,
+            gpu_info=None, # Could fetch metrics here if needed
+            metrics=None,
+            packages=None
+        )
+
+        # Generate message ID
+        message_id = str(uuid.uuid4())
+
+        # Send stream start
+        await websocket.send_json({
+            "type": "claude_stream_start",
+            "data": {"messageId": message_id}
+        })
+
+        full_response = []
+        try:
+            async for chunk in claude_service.stream_response(message, context, history, model=model):
+                full_response.append(chunk)
+                await websocket.send_json({
+                    "type": "claude_stream_chunk",
+                    "data": {"messageId": message_id, "chunk": chunk}
+                })
+
+            # Parse edit blocks from full response
+            full_text = "".join(full_response)
+            proposed_edits = ClaudeService.parse_edit_blocks(full_text, cells)
+
+            # Convert edits to serializable format
+            edits_data = [
+                {
+                    "id": str(uuid.uuid4()),
+                    "cellIndex": edit.cell_index,
+                    "originalCode": edit.original_code,
+                    "newCode": edit.new_code,
+                    "explanation": edit.explanation,
+                    "status": "pending"
+                }
+                for edit in proposed_edits
+            ]
+
+            await websocket.send_json({
+                "type": "claude_stream_end",
+                "data": {
+                    "messageId": message_id,
+                    "fullResponse": full_text,
+                    "proposedEdits": edits_data
+                }
+            })
+
+        except Exception as e:
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": f"Error communicating with Claude: {str(e)}"}
+            })
+
+    async def _handle_claude_apply_edit(self, websocket: WebSocket, data: dict):
+        """Apply a proposed edit to a cell."""
+        cell_index = data.get("cellIndex")
+        new_code = data.get("newCode", "")
+        edit_id = data.get("editId", "")
+
+        if cell_index is None or not (0 <= cell_index < len(self.notebook.cells)):
+            await websocket.send_json({
+                "type": "claude_error",
+                "data": {"error": f"Invalid cell index: {cell_index}"}
+            })
+            return
+
+        # Update the cell source
+        self.notebook.update_cell(cell_index, new_code)
+
+        # Broadcast the notebook update
+        await self.broadcast_notebook_update()
+
+        await websocket.send_json({
+            "type": "claude_edit_applied",
+            "data": {"editId": edit_id, "cellIndex": cell_index}
+        })
+
+    async def _handle_claude_reject_edit(self, websocket: WebSocket, data: dict):
+        """Reject a proposed edit (just acknowledge, no action needed on notebook)."""
+        edit_id = data.get("editId", "")
+
+        await websocket.send_json({
+            "type": "claude_edit_rejected",
+            "data": {"editId": edit_id}
+        })
+
 
     async def _send_error(self, websocket: WebSocket, error_message: str):
         await websocket.send_json({"type": "error", "data": {"error": error_message}})
 
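Taken together, the three handlers define a small client protocol: send `claude_message`, consume `claude_stream_start` / `claude_stream_chunk` / `claude_stream_end` (or `claude_error`), then answer each proposed edit with `claude_apply_edit` or `claude_reject_edit`. A minimal client sketch using the `websockets` package — the `ws://localhost:8000/ws` URL and the `{"type": ..., "data": ...}` inbound envelope (mirroring the server's outbound messages) are assumptions, since the route and dispatch code sit outside this hunk:

```python
# Sketch of a client for the claude_* WebSocket protocol above.
# Assumed: server at ws://localhost:8000/ws, inbound envelope {"type", "data"}.
import asyncio
import json
import websockets


async def ask_claude(prompt: str) -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(json.dumps({
            "type": "claude_message",
            "data": {"message": prompt, "history": [], "model": "sonnet"},
        }))
        while True:
            msg = json.loads(await ws.recv())
            if msg["type"] == "claude_stream_chunk":
                print(msg["data"]["chunk"], end="", flush=True)
            elif msg["type"] == "claude_stream_end":
                # Apply every proposed edit; send claude_reject_edit instead
                # if the diff should be discarded.
                for edit in msg["data"]["proposedEdits"]:
                    await ws.send(json.dumps({
                        "type": "claude_apply_edit",
                        "data": {
                            "editId": edit["id"],
                            "cellIndex": edit["cellIndex"],
                            "newCode": edit["newCode"],
                        },
                    }))
                break
            elif msg["type"] == "claude_error":
                raise RuntimeError(msg["data"]["error"])


asyncio.run(ask_claude("Vectorize the loop in cell 2"))
```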
@@ -676,37 +787,366 @@ async def websocket_endpoint(websocket: WebSocket):
     await manager.handle_message_loop(websocket)
 
 
-# GPU connection API
-@app.get("/api/gpu/config", response_model=ConfigStatusResponse)
-async def get_gpu_config() -> ConfigStatusResponse:
-    """Check if Prime Intellect API is configured."""
-    return ConfigStatusResponse(configured=prime_intellect is not None)
+# ============================================================================
+# Multi-Provider GPU API
+# ============================================================================
 
+@app.get("/api/gpu/providers")
+async def list_gpu_providers():
+    """List all available GPU providers with their configuration status."""
+    providers = list_all_providers()
+    active = get_active_provider_name()
 
-@app.post("/api/gpu/config", response_model=ApiKeyResponse)
-async def set_gpu_config(request: ApiKeyRequest) -> ApiKeyResponse:
-    """Save Prime Intellect API key to user config (~/.morecompute/config.json) and reinitialize service."""
-    global prime_intellect, pod_monitor
+    return {
+        "providers": [
+            {
+                "name": p.name,
+                "display_name": p.display_name,
+                "api_key_env_name": p.api_key_env_name,
+                "supports_ssh": p.supports_ssh,
+                "dashboard_url": p.dashboard_url,
+                "configured": p.configured,
+                "is_active": p.is_active
+            }
+            for p in providers
+        ],
+        "active_provider": active
+    }
 
-    if not request.api_key.strip():
-        raise HTTPException(status_code=400, detail="API key is required")
 
-    try:
-        save_api_key("PRIME_INTELLECT_API_KEY", request.api_key)
-        prime_intellect = PrimeIntellectService(api_key=request.api_key)
-        if prime_intellect:
+@app.post("/api/gpu/providers/{provider_name}/config")
+async def configure_gpu_provider(provider_name: str, request: ProviderConfigRequest):
+    """Configure a GPU provider with API key."""
+    global pod_monitor
+
+    # Handle Modal's special case (requires two tokens)
+    if provider_name == "modal" and request.token_secret:
+        # Save token secret separately
+        save_api_key("MODAL_TOKEN_SECRET", request.token_secret)
+
+    success = configure_provider(provider_name, request.api_key, make_active=request.make_active)
+
+    if not success:
+        raise HTTPException(status_code=400, detail=f"Provider '{provider_name}' not found")
+
+    # If this is being set as active, update the pod monitor
+    if request.make_active:
+        provider = get_provider(provider_name)
+        if provider:
             pod_monitor = PodMonitor(
-                prime_intellect=prime_intellect,
+                provider_service=provider,
                 pod_cache=pod_cache,
                 update_callback=lambda msg: manager.broadcast_pod_update(msg)
             )
 
-        return ApiKeyResponse(configured=True, message="API key saved successfully")
+    return {
+        "configured": True,
+        "provider": provider_name,
+        "is_active": request.make_active
+    }
 
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
-    except Exception as exc:
-        raise HTTPException(status_code=500, detail=f"Failed to save API key: {exc}")
+
+@app.post("/api/gpu/providers/active")
+async def set_active_gpu_provider(request: SetActiveProviderRequest):
+    """Set the active GPU provider."""
+    success = set_active_provider(request.provider)
+    if not success:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Cannot activate provider '{request.provider}'. Make sure it is configured with a valid API key."
+        )
+
+    # Update pod monitor with new provider
+    global pod_monitor
+    provider = get_provider(request.provider)
+    if provider:
+        pod_monitor = PodMonitor(
+            provider_service=provider,
+            pod_cache=pod_cache,
+            update_callback=lambda msg: manager.broadcast_pod_update(msg)
+        )
+
+    return {
+        "active_provider": request.provider,
+        "success": True
+    }
+
+
+@app.get("/api/gpu/providers/{provider_name}/availability")
+async def get_provider_gpu_availability(
+    provider_name: str,
+    regions: list[str] | None = None,
+    gpu_count: int | None = None,
+    gpu_type: str | None = None,
+    # RunPod specific filters
+    secure_cloud: bool | None = None,
+    community_cloud: bool | None = None,
+    # Vast.ai specific filters
+    verified: bool | None = None,
+    min_reliability: float | None = None,
+    min_gpu_ram: float | None = None
+):
+    """Get available GPU resources from a specific provider.
+
+    Args:
+        provider_name: Provider identifier (runpod, lambda_labs, vastai)
+        regions: Filter by region
+        gpu_count: Filter by GPU count
+        gpu_type: Filter by GPU type (partial match)
+        secure_cloud: RunPod - only show Secure Cloud GPUs
+        community_cloud: RunPod - only show Community Cloud GPUs
+        verified: Vast.ai - only show verified hosts
+        min_reliability: Vast.ai - minimum reliability score (0.0-1.0)
+        min_gpu_ram: Vast.ai - minimum GPU RAM in GB
+    """
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    cache_key = make_cache_key(
+        f"gpu_avail_{provider_name}",
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
+
+    if cache_key in gpu_cache:
+        return gpu_cache[cache_key]
+
+    result = await provider.get_gpu_availability(
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
+    gpu_cache[cache_key] = result
+    return result
+
+
+@app.get("/api/gpu/providers/{provider_name}/pods")
+async def get_provider_pods(
+    provider_name: str,
+    status: str | None = None,
+    limit: int = 100,
+    offset: int = 0
+):
+    """Get list of pods from a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    cache_key = make_cache_key(
+        f"gpu_pods_{provider_name}",
+        status=status,
+        limit=limit,
+        offset=offset
+    )
+
+    if cache_key in pod_cache:
+        return pod_cache[cache_key]
+
+    result = await provider.get_pods(status=status, limit=limit, offset=offset)
+    pod_cache[cache_key] = result
+    return result
+
+
+@app.post("/api/gpu/providers/{provider_name}/pods")
+async def create_provider_pod(provider_name: str, pod_request: CreatePodRequest):
+    """Create a new GPU pod with a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    try:
+        result = await provider.create_pod(pod_request)
+
+        # Clear cache and start monitoring
+        pod_cache.clear()
+
+        # Create/update pod monitor for this provider and start monitoring
+        global pod_monitor
+        pod_monitor = PodMonitor(
+            provider_service=provider,
+            pod_cache=pod_cache,
+            update_callback=lambda msg: manager.broadcast_pod_update(msg)
+        )
+        await pod_monitor.start_monitoring(result.id)
+
+        return result
+
+    except HTTPException as e:
+        if e.status_code == 402:
+            raise HTTPException(
+                status_code=402,
+                detail=f"Insufficient funds in your {provider.PROVIDER_DISPLAY_NAME} account."
+            )
+        elif e.status_code in (401, 403):
+            raise HTTPException(
+                status_code=e.status_code,
+                detail=f"Authentication failed. Please check your {provider.PROVIDER_DISPLAY_NAME} API key."
+            )
+        else:
+            raise
+
+
+@app.get("/api/gpu/providers/{provider_name}/pods/{pod_id}")
+async def get_provider_pod(provider_name: str, pod_id: str):
+    """Get details of a specific pod from a provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    return await provider.get_pod(pod_id)
+
+
+@app.delete("/api/gpu/providers/{provider_name}/pods/{pod_id}")
+async def delete_provider_pod(provider_name: str, pod_id: str):
+    """Delete a pod from a specific provider."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    result = await provider.delete_pod(pod_id)
+    pod_cache.clear()
+    return result
+
+
+@app.get("/api/gpu/providers/{provider_name}/ssh-keys")
+async def get_provider_ssh_keys(provider_name: str):
+    """Get list of SSH keys registered with a provider (Lambda Labs only)."""
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    # Only Lambda Labs supports listing SSH keys via API
+    if provider_name != "lambda_labs":
+        return {
+            "supported": False,
+            "message": f"{provider.PROVIDER_DISPLAY_NAME} does not support listing SSH keys via API. Please check your provider's dashboard.",
+            "dashboard_url": provider.DASHBOARD_URL
+        }
+
+    try:
+        # Get detailed SSH key info
+        detailed_keys = await provider.get_ssh_keys_detailed()
+        key_names = await provider._get_ssh_key_ids()
+
+        return {
+            "supported": True,
+            "ssh_keys": key_names,
+            "ssh_keys_detailed": detailed_keys,
+            "selected_key": key_names[0] if key_names else None,
+            "note": "ed25519 keys are preferred. The first key shown will be used when creating new instances."
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to fetch SSH keys: {str(e)}")
+
+
+@app.post("/api/gpu/providers/{provider_name}/pods/{pod_id}/connect")
+async def connect_to_provider_pod(provider_name: str, pod_id: str):
+    """Connect to a GPU pod from a specific provider."""
+    global pod_manager
+
+    provider = get_provider(provider_name)
+    if not provider:
+        raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
+
+    if not provider.is_configured:
+        raise HTTPException(
+            status_code=503,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} API key not configured"
+        )
+
+    if not provider.SUPPORTS_SSH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"{provider.PROVIDER_DISPLAY_NAME} does not support SSH connections. Use the provider's SDK for code execution."
+        )
+
+    if pod_manager is None:
+        pod_manager = PodKernelManager(provider_service=provider)
+    else:
+        # Update the provider on the pod manager
+        pod_manager.provider_service = provider
+        pod_manager.provider_type = provider_name
+
+    # Start the connection in the background
+    asyncio.create_task(_connect_to_pod_background(pod_id))
+
+    return {
+        "status": "connecting",
+        "message": "Connection initiated. Check status endpoint for updates.",
+        "pod_id": pod_id,
+        "provider": provider_name
+    }
+
+
+# ============================================================================
+# Legacy GPU API (Prime Intellect - for backwards compatibility)
+# ============================================================================
+
+# GPU connection API (legacy endpoints - use provider system)
+@app.get("/api/gpu/config", response_model=ConfigStatusResponse)
+async def get_gpu_config() -> ConfigStatusResponse:
+    """Check if any GPU provider is configured."""
+    active_provider = get_active_provider()
+    if active_provider and active_provider.is_configured:
+        return ConfigStatusResponse(configured=True)
+    return ConfigStatusResponse(configured=False)
+
+
+@app.post("/api/gpu/config", response_model=ApiKeyResponse)
+async def set_gpu_config(request: ApiKeyRequest) -> ApiKeyResponse:
+    """Legacy endpoint - use /api/gpu/providers/{provider}/config instead."""
+    raise HTTPException(
+        status_code=400,
+        detail="Please use /api/gpu/providers/{provider}/config to configure a specific provider"
+    )
 
 
 @app.get("/api/gpu/availability")
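The new routes compose into a simple flow: configure a provider, make it active, then query availability and manage pods either through the provider-scoped routes or the legacy ones, which now delegate to the active provider. A sketch with httpx — the base URL is assumed, the key is a placeholder, and the snake_case JSON field names are inferred from the `ProviderConfigRequest` attributes used above:

```python
# Sketch of the multi-provider flow against the routes added above.
# Assumed: server at http://localhost:8000; "RUNPOD_KEY" is a placeholder.
import httpx

with httpx.Client(base_url="http://localhost:8000") as client:
    # Configure RunPod and make it the active provider in one call
    client.post(
        "/api/gpu/providers/runpod/config",
        json={"api_key": "RUNPOD_KEY", "make_active": True},
    ).raise_for_status()

    # Provider-scoped availability query with RunPod-specific filters
    gpus = client.get(
        "/api/gpu/providers/runpod/availability",
        params={"gpu_type": "H100", "secure_cloud": True},
    ).json()

    # The legacy endpoint now transparently hits the active provider
    pods = client.get("/api/gpu/pods").json()
```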
@@ -714,36 +1154,56 @@ async def get_gpu_availability(
     regions: list[str] | None = None,
     gpu_count: int | None = None,
     gpu_type: str | None = None,
-    security: str | None = None
+    # RunPod specific filters
+    secure_cloud: bool | None = None,
+    community_cloud: bool | None = None,
+    # Vast.ai specific filters
+    verified: bool | None = None,
+    min_reliability: float | None = None,
+    min_gpu_ram: float | None = None
 ):
-    """Get available GPU resources from Prime Intellect."""
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+    """Get available GPU resources from active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
     cache_key = make_cache_key(
-        "gpu_avail",
-        regions = regions,
-        gpu_count = gpu_count,
-        gpu_type = gpu_type,
-        security=security
+        f"gpu_avail_{active_provider.PROVIDER_NAME}",
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
     )
 
     if cache_key in gpu_cache:
         return gpu_cache[cache_key]
 
-    #cache miss
-    result = await prime_intellect.get_gpu_availability(regions, gpu_count, gpu_type, security)
+    result = await active_provider.get_gpu_availability(
+        regions=regions,
+        gpu_count=gpu_count,
+        gpu_type=gpu_type,
+        secure_cloud=secure_cloud,
+        community_cloud=community_cloud,
+        verified=verified,
+        min_reliability=min_reliability,
+        min_gpu_ram=min_gpu_ram
+    )
     gpu_cache[cache_key] = result
     return result
 
 @app.get("/api/gpu/pods")
 async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int = 0):
-    """Get list of user's GPU pods."""
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+    """Get list of user's GPU pods from active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
     cache_key = make_cache_key(
-        "gpu_pod",
+        f"gpu_pod_{active_provider.PROVIDER_NAME}",
         status=status,
         limit=limit,
         offset=offset
@@ -752,29 +1212,25 @@ async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int
     if cache_key in pod_cache:
         return pod_cache[cache_key]
 
-    # Cache miss: fetch from API
-    result = await prime_intellect.get_pods(status, limit, offset)
+    result = await active_provider.get_pods(status=status, limit=limit, offset=offset)
     pod_cache[cache_key] = result
     return result
 
 
 @app.post("/api/gpu/pods")
 async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
-    """Create a new GPU pod."""
-    import sys
-
-    if not prime_intellect or not pod_monitor:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
-
-    print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)
+    """Create a new GPU pod with active provider."""
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
     try:
-        result = await prime_intellect.create_pod(pod_request)
-        print(f"[CREATE POD] Success: {result}", file=sys.stderr, flush=True)
+        result = await active_provider.create_pod(pod_request)
 
         # Clear cache and start monitoring
         pod_cache.clear()
-        await pod_monitor.start_monitoring(result.id)
+        if pod_monitor:
+            await pod_monitor.start_monitoring(result.id)
 
         return result
 
@@ -782,30 +1238,30 @@ async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
         if e.status_code == 402:
             raise HTTPException(
                 status_code=402,
-                detail="Insufficient funds in your Prime Intellect wallet. Please add credits at https://app.primeintellect.ai/dashboard/billing"
+                detail=f"Insufficient funds. Please add credits to your {active_provider.PROVIDER_DISPLAY_NAME} account."
             )
         elif e.status_code in (401, 403):
             raise HTTPException(
                 status_code=e.status_code,
-                detail="Authentication failed. Please check your Prime Intellect API key."
+                detail=f"Authentication failed. Please check your {active_provider.PROVIDER_DISPLAY_NAME} API key."
             )
         else:
-            print(f"[CREATE POD] Error: {e}", file=sys.stderr, flush=True)
             raise
 
 
 @app.get("/api/gpu/pods/{pod_id}")
 async def get_gpu_pod(pod_id: str) -> PodResponse:
     """Get details of a specific GPU pod."""
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
-    cache_key = make_cache_key("gpu_pod_detail", pod_id=pod_id)
+    cache_key = make_cache_key(f"gpu_pod_detail_{active_provider.PROVIDER_NAME}", pod_id=pod_id)
 
     if cache_key in pod_cache:
         return pod_cache[cache_key]
 
-    result = await prime_intellect.get_pod(pod_id)
+    result = await active_provider.get_pod(pod_id)
     pod_cache[cache_key] = result
     return result
 
@@ -813,22 +1269,23 @@ async def get_gpu_pod(pod_id: str) -> PodResponse:
 @app.delete("/api/gpu/pods/{pod_id}")
 async def delete_gpu_pod(pod_id: str):
     """Delete a GPU pod."""
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
-    result = await prime_intellect.delete_pod(pod_id)
+    result = await active_provider.delete_pod(pod_id)
     pod_cache.clear()
     return result
 
 
 async def _connect_to_pod_background(pod_id: str):
     """Background task to connect to pod without blocking the HTTP response."""
-    global pod_manager
-    import sys
+    global pod_manager, pod_connection_error
 
-    try:
-        print(f"[CONNECT BACKGROUND] Starting connection to pod {pod_id}", file=sys.stderr, flush=True)
+    # Clear any previous error
+    pod_connection_error = None
 
+    try:
         # Disconnect from any existing pod first
         # TO-DO have to fix this for multi-gpu
         if pod_manager and pod_manager.pod is not None:
@@ -845,21 +1302,21 @@ async def _connect_to_pod_background(pod_id: str):
                 pub_addr=addresses["pub_addr"],
                 is_remote=True # Critical: Tell executor this is a remote worker
             )
-            print(f"[CONNECT BACKGROUND] Successfully connected to pod {pod_id}, executor.is_remote=True", file=sys.stderr, flush=True)
         else:
-            # Connection failed - clean up
-            print(f"[CONNECT BACKGROUND] Failed to connect: {result}", file=sys.stderr, flush=True)
+            # Connection failed - store the error message
+            error_msg = result.get("message", "Connection failed")
+            pod_connection_error = error_msg
             if pod_manager and pod_manager.pod:
                 await pod_manager.disconnect()
 
     except Exception as e:
-        print(f"[CONNECT BACKGROUND] Error: {e}", file=sys.stderr, flush=True)
+        pod_connection_error = str(e)
         # Clean up on error
         if pod_manager and pod_manager.pod:
             try:
                 await pod_manager.disconnect()
-            except Exception as cleanup_err:
-                print(f"[CONNECT BACKGROUND] Cleanup error: {cleanup_err}", file=sys.stderr, flush=True)
+            except Exception:
+                pass
 
 
 @app.post("/api/gpu/pods/{pod_id}/connect")
@@ -867,11 +1324,12 @@ async def connect_to_pod(pod_id: str):
     """Connect to a GPU pod and establish SSH tunnel for remote execution."""
     global pod_manager
 
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
     if pod_manager is None:
-        pod_manager = PodKernelManager(pi_service=prime_intellect)
+        pod_manager = PodKernelManager(provider_service=active_provider)
 
     # Start the connection in the background
     asyncio.create_task(_connect_to_pod_background(pod_id))
@@ -905,9 +1363,21 @@ async def get_pod_connection_status():
     """
     Get status of current pod connection.
 
-    Returns connection status AND any running pods from Prime Intellect API.
+    Returns connection status AND any running pods from the active provider's API.
     This ensures we don't lose track of running pods after backend restart.
     """
+    global pod_connection_error
+
+    # Check if there's a connection error to report
+    if pod_connection_error:
+        error_msg = pod_connection_error
+        pod_connection_error = None # Clear after reporting
+        return {
+            "connected": False,
+            "pod": None,
+            "error": error_msg
+        }
+
     # Check local connection state first
     local_status = None
     if pod_manager is not None:
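Because the background connect task no longer logs failures, the error is parked in `pod_connection_error` and handed to the next status poll, then cleared: a read-once mailbox, so a client sees each failure exactly once. A polling sketch — the `/api/gpu/pods/connection-status` path is an assumption, since the decorator for `get_pod_connection_status` is outside these hunks:

```python
# Sketch: polling the status endpoint surfaces a connect failure once.
# Assumed: server at http://localhost:8000 and a connection-status route.
import time
import httpx

while True:
    status = httpx.get(
        "http://localhost:8000/api/gpu/pods/connection-status"
    ).json()
    if status.get("error"):  # reported once, then cleared server-side
        print("connect failed:", status["error"])
        break
    if status.get("connected"):
        break
    time.sleep(2)
```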
@@ -915,10 +1385,11 @@ async def get_pod_connection_status():
         if local_status.get("connected"):
             return local_status
 
-    # If not locally connected, check Prime Intellect API for any running pods
-    if prime_intellect:
+    # If not locally connected, check the active provider's API for any running pods
+    active_provider = get_active_provider()
+    if active_provider and active_provider.is_configured:
         try:
-            pods_response = await prime_intellect.get_pods(status=None, limit=100, offset=0)
+            pods_response = await active_provider.get_pods(status=None, limit=100, offset=0)
             pods = pods_response.get("data", [])
 
             # Find any ACTIVE pods with SSH connection info
@@ -942,10 +1413,11 @@ async def get_pod_connection_status():
                     "price_hr": first_pod.get("priceHr"),
                     "ssh_connection": first_pod.get("sshConnection")
                 },
+                "provider": active_provider.PROVIDER_NAME,
                 "message": "Found running pod but not connected. Backend may have restarted."
            }
-        except Exception as e:
-            print(f"[CONNECTION STATUS] Error checking Prime Intellect API: {e}", file=sys.stderr, flush=True)
+        except Exception:
+            pass
 
     # No connection and no running pods found
     return {"connected": False, "pod": None}
@@ -964,7 +1436,8 @@ async def get_worker_logs():
     if not host_part:
         raise HTTPException(status_code=500, detail="Invalid SSH connection")
 
-    ssh_host = host_part.split("@")[1]
+    # Extract user and host from user@host
+    ssh_user, ssh_host = host_part.split("@")
     ssh_port = ssh_parts[ssh_parts.index("-p") + 1] if "-p" in ssh_parts else "22"
 
     ssh_key = pod_manager._get_ssh_key()
@@ -975,7 +1448,7 @@ async def get_worker_logs():
         "-o", "StrictHostKeyChecking=no",
         "-o", "UserKnownHostsFile=/dev/null",
         "-o", "BatchMode=yes",
-        f"root@{ssh_host}",
+        f"{ssh_user}@{ssh_host}",
         "cat /tmp/worker.log 2>&1 || echo 'No worker log found'"
     ])
 
@@ -1100,15 +1573,16 @@ async def create_dataset_disk(request: Request):
     Returns:
         Dict with disk_id, mount_path, instructions
     """
-    if not prime_intellect:
-        raise HTTPException(status_code=503, detail="Prime Intellect API not configured")
+    active_provider = get_active_provider()
+    if not active_provider or not active_provider.is_configured:
+        raise HTTPException(status_code=503, detail="No GPU provider configured. Please select and configure a provider.")
 
     try:
         body = await request.json()
         pod_id = body.get("pod_id")
         disk_name = body.get("disk_name")
         size_gb = body.get("size_gb")
-        provider_type = body.get("provider_type", "runpod")
+        provider_type = body.get("provider_type", active_provider.PROVIDER_NAME)
 
         if not pod_id or not disk_name or not size_gb:
             raise HTTPException(status_code=400, detail="pod_id, disk_name, and size_gb are required")
@@ -1155,3 +1629,37 @@ async def get_subset_code(
         return result
     except Exception as exc:
         raise HTTPException(status_code=500, detail=f"Failed to generate subset code: {exc}")
+
+
+# ============================================================================
+# Claude AI API
+# ============================================================================
+
+@app.get("/api/claude/config")
+async def get_claude_config():
+    """Check if Claude API is configured."""
+    return {"configured": claude_service is not None}
+
+
+@app.post("/api/claude/config")
+async def set_claude_config(request: Request):
+    """Save Claude API key to user config and reinitialize service."""
+    global claude_service
+
+    body = await request.json()
+    api_key = body.get("api_key", "").strip()
+
+    if not api_key:
+        raise HTTPException(status_code=400, detail="API key is required")
+
+    try:
+        # Test the API key by creating a service
+        test_service = ClaudeService(api_key=api_key)
+        # If successful, save and use it
+        save_api_key("CLAUDE_API_KEY", api_key)
+        claude_service = test_service
+        return {"configured": True, "message": "Claude API key saved successfully"}
+    except ImportError as e:
+        raise HTTPException(status_code=500, detail=f"anthropic package not installed: {e}")
+    except Exception:
+        raise HTTPException(status_code=400, detail="Invalid API key. Please check your credentials.")