more-compute 0.4.3-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontend/app/globals.css +734 -27
- frontend/app/layout.tsx +13 -3
- frontend/components/Notebook.tsx +2 -14
- frontend/components/cell/MonacoCell.tsx +99 -5
- frontend/components/layout/Sidebar.tsx +39 -4
- frontend/components/panels/ClaudePanel.tsx +461 -0
- frontend/components/popups/ComputePopup.tsx +739 -418
- frontend/components/popups/FilterPopup.tsx +305 -189
- frontend/components/popups/MetricsPopup.tsx +20 -1
- frontend/components/popups/ProviderConfigModal.tsx +322 -0
- frontend/components/popups/ProviderDropdown.tsx +398 -0
- frontend/components/popups/SettingsPopup.tsx +1 -1
- frontend/contexts/ClaudeContext.tsx +392 -0
- frontend/contexts/PodWebSocketContext.tsx +16 -21
- frontend/hooks/useInlineDiff.ts +269 -0
- frontend/lib/api.ts +323 -12
- frontend/lib/settings.ts +5 -0
- frontend/lib/websocket-native.ts +4 -8
- frontend/lib/websocket.ts +1 -2
- frontend/package-lock.json +733 -36
- frontend/package.json +2 -0
- frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
- frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
- frontend/public/assets/icons/providers/runpod.svg +9 -0
- frontend/public/assets/icons/providers/vastai.svg +1 -0
- frontend/settings.md +54 -0
- frontend/tsconfig.tsbuildinfo +1 -0
- frontend/types/claude.ts +194 -0
- kernel_run.py +13 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
- morecompute/__init__.py +1 -1
- morecompute/__version__.py +1 -1
- morecompute/execution/executor.py +24 -67
- morecompute/execution/worker.py +6 -72
- morecompute/models/api_models.py +62 -0
- morecompute/notebook.py +11 -0
- morecompute/server.py +641 -133
- morecompute/services/claude_service.py +392 -0
- morecompute/services/pod_manager.py +168 -67
- morecompute/services/pod_monitor.py +67 -39
- morecompute/services/prime_intellect.py +0 -4
- morecompute/services/providers/__init__.py +92 -0
- morecompute/services/providers/base_provider.py +336 -0
- morecompute/services/providers/lambda_labs_provider.py +394 -0
- morecompute/services/providers/provider_factory.py +194 -0
- morecompute/services/providers/runpod_provider.py +504 -0
- morecompute/services/providers/vastai_provider.py +407 -0
- morecompute/utils/cell_magics.py +0 -3
- morecompute/utils/config_util.py +93 -3
- morecompute/utils/special_commands.py +5 -32
- morecompute/utils/version_check.py +117 -0
- frontend/styling_README.md +0 -23
- {more_compute-0.4.3.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0

@@ -0,0 +1,394 @@ morecompute/services/providers/lambda_labs_provider.py (new file)

```python
"""Lambda Labs GPU cloud provider implementation."""

from typing import Any
from datetime import datetime, timezone

from .base_provider import BaseGPUProvider, NormalizedPod
from .provider_factory import register_provider
from ...models.api_models import PodResponse


@register_provider
class LambdaLabsProvider(BaseGPUProvider):
    """Lambda Labs GPU cloud provider using REST API."""

    PROVIDER_NAME = "lambda_labs"
    PROVIDER_DISPLAY_NAME = "Lambda Labs"
    API_KEY_ENV_NAME = "LAMBDA_LABS_API_KEY"
    SUPPORTS_SSH = True
    DASHBOARD_URL = "https://cloud.lambdalabs.com/api-keys"

    BASE_URL = "https://cloud.lambdalabs.com/api/v1"

    def __init__(self, api_key: str | None = None):
        super().__init__(api_key)

    def _get_auth_headers(self) -> dict[str, str]:
        """Get Lambda Labs authentication headers."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}" if self.api_key else ""
        }

    async def get_gpu_availability(
        self,
        regions: list[str] | None = None,
        gpu_count: int | None = None,
        gpu_type: str | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """Get available GPU instance types from Lambda Labs.

        Lambda Labs returns instance types with availability info.
        """
        response = await self._make_request(
            "GET",
            f"{self.BASE_URL}/instance-types"
        )

        instance_types = response.get("data", {})
        gpus = []

        for instance_type_name, instance_info in instance_types.items():
            instance_type = instance_info.get("instance_type", {})
            regions_available = instance_info.get("regions_with_capacity_available", [])

            # Filter by region if specified
            if regions:
                regions_available = [r for r in regions_available if r.get("name") in regions]

            # Skip if no availability
            if not regions_available:
                continue

            specs = instance_type.get("specs", {})
            gpu_spec = specs.get("gpus", 1)
            gpu_name = instance_type.get("description", instance_type_name)

            # Filter by GPU type if specified
            if gpu_type and gpu_type.lower() not in gpu_name.lower():
                continue

            # Filter by GPU count if specified
            if gpu_count and gpu_spec != gpu_count:
                continue

            for region in regions_available:
                gpus.append({
                    "gpuType": instance_type_name,
                    "gpuName": gpu_name,
                    "gpuCount": gpu_spec,
                    "priceHr": instance_type.get("price_cents_per_hour", 0) / 100,
                    "cloudId": instance_type_name,
                    "socket": instance_type_name,
                    "region": region.get("name"),
                    "regionDescription": region.get("description"),
                    "vcpus": specs.get("vcpus"),
                    "memoryGb": specs.get("memory_gib"),
                    "storageGb": specs.get("storage_gib"),
                    "provider": self.PROVIDER_NAME
                })

        return {
            "data": gpus,
            "total_count": len(gpus),
            "provider": self.PROVIDER_NAME
        }

    async def create_pod(self, request: Any) -> PodResponse:
        """Launch a new Lambda Labs instance.

        Args:
            request: CreatePodRequest with pod configuration

        Returns:
            PodResponse with created instance info
        """
        import sys
        from fastapi import HTTPException

        pod_config = request.pod if hasattr(request, 'pod') else request

        # Get SSH key IDs (Lambda requires exactly one SSH key)
        ssh_keys = await self._get_ssh_key_ids()
        if not ssh_keys:
            raise HTTPException(
                status_code=400,
                detail="No SSH keys found. Please add an SSH key to your Lambda Labs account at https://cloud.lambdalabs.com/ssh-keys"
            )
        # Lambda Labs requires exactly one SSH key - use the first one
        ssh_key = ssh_keys[0]

        instance_type = pod_config.gpuType if hasattr(pod_config, 'gpuType') else pod_config.get("gpuType")
        name = pod_config.name if hasattr(pod_config, 'name') else pod_config.get("name", "morecompute-instance")

        # Try multiple field names for region
        region = None
        for field in ['dataCenterId', 'region', 'regionName', 'region_name']:
            if hasattr(pod_config, field):
                region = getattr(pod_config, field)
            elif isinstance(pod_config, dict) and pod_config.get(field):
                region = pod_config.get(field)
            if region:
                break

        if not region:
            # Get all availability and find region for this specific instance type
            availability = await self.get_gpu_availability()  # Get all, don't filter
            if availability.get("data"):
                # Find the GPU entry matching this instance type
                for gpu in availability["data"]:
                    if gpu.get("cloudId") == instance_type or gpu.get("gpuType") == instance_type:
                        region = gpu.get("region")
                        break

        if not region:
            raise HTTPException(
                status_code=400,
                detail=f"No available regions found for instance type '{instance_type}'. This GPU may be out of stock. Please try a different GPU."
            )

        payload = {
            "instance_type_name": instance_type,
            "ssh_key_names": [ssh_key],  # Lambda requires exactly one SSH key
            "name": name,
            "quantity": 1,
            "region_name": region  # Always required
        }

        response = await self._make_request(
            "POST",
            f"{self.BASE_URL}/instance-operations/launch",
            json_data=payload
        )

        instance_ids = response.get("data", {}).get("instance_ids", [])
        if not instance_ids:
            raise HTTPException(status_code=500, detail="Failed to launch Lambda Labs instance")

        # Get instance details
        instance_id = instance_ids[0]
        return await self.get_pod(instance_id)

    async def _get_ssh_key_ids(self) -> list[str]:
        """Get list of SSH key names registered with Lambda Labs.

        Returns keys sorted to prefer ed25519 keys (more common for modern setups).
        """
        response = await self._make_request(
            "GET",
            f"{self.BASE_URL}/ssh-keys"
        )

        keys = response.get("data", [])

        # Separate ed25519 keys from others (prefer ed25519 as they're more common locally)
        ed25519_keys = []
        other_keys = []

        for key in keys:
            name = key.get("name")
            public_key = key.get("public_key", "")
            if name:
                if public_key.startswith("ssh-ed25519"):
                    ed25519_keys.append(name)
                else:
                    other_keys.append(name)

        # Return ed25519 keys first, then others
        return ed25519_keys + other_keys

    async def get_ssh_keys_detailed(self) -> list[dict[str, Any]]:
        """Get detailed list of SSH keys with their types."""
        response = await self._make_request(
            "GET",
            f"{self.BASE_URL}/ssh-keys"
        )

        keys = response.get("data", [])
        result = []

        for key in keys:
            public_key = key.get("public_key", "")
            key_type = "unknown"
            if public_key.startswith("ssh-ed25519"):
                key_type = "ed25519"
            elif public_key.startswith("ssh-rsa"):
                key_type = "rsa"
            elif public_key.startswith("ecdsa"):
                key_type = "ecdsa"

            result.append({
                "name": key.get("name"),
                "type": key_type,
                "fingerprint": public_key[:50] + "..." if len(public_key) > 50 else public_key
            })

        return result

    async def add_ssh_key(self, name: str, public_key: str) -> dict[str, Any]:
        """Add a new SSH key to Lambda Labs account."""
        response = await self._make_request(
            "POST",
            f"{self.BASE_URL}/ssh-keys",
            json_data={
                "name": name,
                "public_key": public_key
            }
        )
        return response.get("data", {})

    async def get_pods(
        self,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0
    ) -> dict[str, Any]:
        """Get list of all Lambda Labs instances."""
        response = await self._make_request(
            "GET",
            f"{self.BASE_URL}/instances"
        )

        instances = response.get("data", [])

        # Filter by status if specified
        if status:
            status_lower = status.lower()
            instances = [i for i in instances if i.get("status", "").lower() == status_lower]

        # Apply pagination
        instances = instances[offset:offset + limit]

        # Transform to standardized format
        pods = []
        for instance in instances:
            ssh_connection = None
            ip = instance.get("ip")
            if ip:
                ssh_connection = f"ssh ubuntu@{ip}"

            pods.append({
                "id": instance.get("id"),
                "name": instance.get("name"),
                "status": self._normalize_status(instance.get("status", "unknown")),
                "gpuName": instance.get("instance_type", {}).get("description", ""),
                "gpuCount": instance.get("instance_type", {}).get("specs", {}).get("gpus", 1),
                "priceHr": instance.get("instance_type", {}).get("price_cents_per_hour", 0) / 100,
                "sshConnection": ssh_connection,
                "ip": ip,
                "region": instance.get("region", {}).get("name"),
                "createdAt": instance.get("created_at", datetime.now(timezone.utc).isoformat()),
                "updatedAt": instance.get("created_at", datetime.now(timezone.utc).isoformat()),
                "provider": self.PROVIDER_NAME
            })

        return {
            "data": pods,
            "total_count": len(pods),
            "offset": offset,
            "limit": limit,
            "provider": self.PROVIDER_NAME
        }

    async def get_pod(self, pod_id: str) -> PodResponse:
        """Get details for a specific Lambda Labs instance."""
        response = await self._make_request(
            "GET",
            f"{self.BASE_URL}/instances/{pod_id}"
        )

        instance = response.get("data", {})
        if not instance:
            from fastapi import HTTPException
            raise HTTPException(status_code=404, detail=f"Instance {pod_id} not found")

        ssh_connection = None
        ip = instance.get("ip")
        if ip:
            ssh_connection = f"ssh ubuntu@{ip}"

        now = datetime.now(timezone.utc)
        instance_type = instance.get("instance_type", {})

        return PodResponse(
            id=instance.get("id", ""),
            userId="",
            teamId=None,
            name=instance.get("name", ""),
            status=self._normalize_status(instance.get("status", "unknown")),
            gpuName=instance_type.get("description", ""),
            gpuCount=instance_type.get("specs", {}).get("gpus", 1),
            priceHr=instance_type.get("price_cents_per_hour", 0) / 100,
            sshConnection=ssh_connection,
            ip=ip,
            createdAt=now,
            updatedAt=now
        )

    async def delete_pod(self, pod_id: str) -> dict[str, Any]:
        """Terminate a Lambda Labs instance."""
        response = await self._make_request(
            "POST",
            f"{self.BASE_URL}/instance-operations/terminate",
            json_data={
                "instance_ids": [pod_id]
            }
        )

        terminated = response.get("data", {}).get("terminated_instances", [])
        return {
            "success": pod_id in [t.get("id") for t in terminated],
            "pod_id": pod_id,
            "provider": self.PROVIDER_NAME
        }

    async def restart_pod(self, pod_id: str) -> dict[str, Any]:
        """Restart a Lambda Labs instance."""
        response = await self._make_request(
            "POST",
            f"{self.BASE_URL}/instance-operations/restart",
            json_data={
                "instance_ids": [pod_id]
            }
        )

        restarted = response.get("data", {}).get("restarted_instances", [])
        return {
            "success": pod_id in [r.get("id") for r in restarted],
            "pod_id": pod_id,
            "action": "restarted"
        }

    def _normalize_status(self, lambda_status: str) -> str:
        """Convert Lambda Labs status to normalized status."""
        status_map = {
            "active": "ACTIVE",
            "booting": "STARTING",
            "unhealthy": "ERROR",
            "terminated": "TERMINATED"
        }
        return status_map.get(lambda_status.lower(), lambda_status.upper())

    def normalize_pod(self, pod_data: dict[str, Any]) -> NormalizedPod:
        """Convert Lambda Labs instance data to normalized format."""
        ssh_connection = None
        ip = pod_data.get("ip")
        if ip:
            ssh_connection = f"ssh ubuntu@{ip}"

        instance_type = pod_data.get("instance_type", {})

        return NormalizedPod(
            id=pod_data.get("id", ""),
            name=pod_data.get("name", ""),
            status=self._normalize_status(pod_data.get("status", "unknown")),
            gpu_name=instance_type.get("description", ""),
            gpu_count=instance_type.get("specs", {}).get("gpus", 1),
            price_hr=instance_type.get("price_cents_per_hour", 0) / 100,
            ssh_connection=ssh_connection,
            ip=ip,
            provider=self.PROVIDER_NAME,
            created_at=pod_data.get("created_at", datetime.now(timezone.utc).isoformat()),
            updated_at=pod_data.get("created_at", datetime.now(timezone.utc).isoformat())
        )
```
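For context, a minimal sketch of how this provider might be exercised directly. This is illustrative only and not part of the diff: it assumes a valid `LAMBDA_LABS_API_KEY` in the environment and relies on `BaseGPUProvider._make_request` (defined in `base_provider.py`, whose body is not shown in this excerpt) to perform the authenticated HTTP calls.

```python
# Illustrative sketch -- not part of the published diff.
import asyncio
import os

from morecompute.services.providers.lambda_labs_provider import LambdaLabsProvider


async def main() -> None:
    # Assumes LAMBDA_LABS_API_KEY is set; _make_request comes from the base class.
    provider = LambdaLabsProvider(api_key=os.environ["LAMBDA_LABS_API_KEY"])

    # List GPU types that currently have capacity, restricted to single-GPU nodes.
    availability = await provider.get_gpu_availability(gpu_count=1)
    for gpu in availability["data"]:
        print(gpu["gpuType"], gpu["region"], f'${gpu["priceHr"]:.2f}/hr')

    # Running instances, already transformed into the shared pod schema.
    pods = await provider.get_pods(status="active")
    print(pods["total_count"], "active instance(s)")


asyncio.run(main())
```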
@@ -0,0 +1,194 @@ morecompute/services/providers/provider_factory.py (new file)

```python
"""Factory and registry for GPU cloud providers."""

from typing import Type
from .base_provider import BaseGPUProvider, ProviderInfo, ProviderType
from ...utils.config_util import load_api_key, save_api_key, _load_config, _save_config


# Registry of provider classes
_PROVIDER_REGISTRY: dict[str, Type[BaseGPUProvider]] = {}

# Cached provider instances
_PROVIDER_INSTANCES: dict[str, BaseGPUProvider] = {}


def register_provider(provider_class: Type[BaseGPUProvider]) -> Type[BaseGPUProvider]:
    """Decorator to register a provider class.

    Usage:
        @register_provider
        class MyProvider(BaseGPUProvider):
            PROVIDER_NAME = "my_provider"
            ...
    """
    name = provider_class.PROVIDER_NAME
    if not name:
        raise ValueError(f"Provider class {provider_class.__name__} must define PROVIDER_NAME")
    _PROVIDER_REGISTRY[name] = provider_class
    return provider_class


def get_provider_class(provider_name: str) -> Type[BaseGPUProvider] | None:
    """Get a registered provider class by name.

    Args:
        provider_name: The provider identifier (e.g., "runpod")

    Returns:
        Provider class or None if not found
    """
    return _PROVIDER_REGISTRY.get(provider_name)


def get_provider(provider_name: str, force_new: bool = False) -> BaseGPUProvider | None:
    """Get a provider instance, creating it if necessary.

    Args:
        provider_name: The provider identifier (e.g., "runpod")
        force_new: If True, create a new instance even if one is cached

    Returns:
        Provider instance or None if provider not found
    """
    if not force_new and provider_name in _PROVIDER_INSTANCES:
        return _PROVIDER_INSTANCES[provider_name]

    provider_class = get_provider_class(provider_name)
    if not provider_class:
        return None

    # Load API key for this provider
    api_key = load_api_key(provider_class.API_KEY_ENV_NAME)

    # Create instance
    instance = provider_class(api_key=api_key)
    _PROVIDER_INSTANCES[provider_name] = instance
    return instance


def refresh_provider(provider_name: str) -> BaseGPUProvider | None:
    """Refresh a provider instance (e.g., after API key update).

    Args:
        provider_name: The provider identifier

    Returns:
        New provider instance or None if provider not found
    """
    # Clear cached instance
    if provider_name in _PROVIDER_INSTANCES:
        del _PROVIDER_INSTANCES[provider_name]
    return get_provider(provider_name, force_new=True)


def list_providers() -> list[ProviderInfo]:
    """List all registered providers with their configuration status.

    Returns:
        List of ProviderInfo for all registered providers
    """
    active_provider = get_active_provider_name()
    providers = []

    for name, provider_class in _PROVIDER_REGISTRY.items():
        api_key = load_api_key(provider_class.API_KEY_ENV_NAME)
        info = ProviderInfo(
            name=name,
            display_name=provider_class.PROVIDER_DISPLAY_NAME,
            api_key_env_name=provider_class.API_KEY_ENV_NAME,
            supports_ssh=provider_class.SUPPORTS_SSH,
            dashboard_url=provider_class.DASHBOARD_URL,
            configured=api_key is not None and len(api_key.strip()) > 0,
            is_active=(name == active_provider)
        )
        providers.append(info)

    # Sort by display name for consistent ordering
    providers.sort(key=lambda p: p.display_name)
    return providers


def get_configured_providers() -> list[ProviderInfo]:
    """Get list of providers that have API keys configured.

    Returns:
        List of ProviderInfo for configured providers only
    """
    return [p for p in list_providers() if p.configured]


def get_active_provider_name() -> str | None:
    """Get the name of the currently active provider.

    Returns:
        Provider name or None if not set
    """
    config = _load_config()
    return config.get("active_provider")


def set_active_provider(provider_name: str) -> bool:
    """Set the active provider.

    Args:
        provider_name: The provider to make active

    Returns:
        True if successful, False if provider not found or not configured
    """
    if provider_name not in _PROVIDER_REGISTRY:
        return False

    provider_class = _PROVIDER_REGISTRY[provider_name]
    api_key = load_api_key(provider_class.API_KEY_ENV_NAME)

    if not api_key:
        return False

    config = _load_config()
    config["active_provider"] = provider_name
    _save_config(config)
    return True


def get_active_provider() -> BaseGPUProvider | None:
    """Get the currently active provider instance.

    Returns:
        Active provider instance or None if not set
    """
    active_name = get_active_provider_name()
    if not active_name:
        return None
    return get_provider(active_name)


def configure_provider(provider_name: str, api_key: str, make_active: bool = False) -> bool:
    """Configure a provider with an API key.

    Args:
        provider_name: The provider to configure
        api_key: The API key to save
        make_active: If True, also make this the active provider

    Returns:
        True if successful, False if provider not found
    """
    if provider_name not in _PROVIDER_REGISTRY:
        return False

    provider_class = _PROVIDER_REGISTRY[provider_name]
    save_api_key(provider_class.API_KEY_ENV_NAME, api_key)

    # Refresh the provider instance
    refresh_provider(provider_name)

    if make_active:
        set_active_provider(provider_name)

    return True


def clear_all_providers() -> None:
    """Clear all cached provider instances. Useful for testing."""
    _PROVIDER_INSTANCES.clear()
```
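A hedged sketch of how these factory functions could fit together from calling code. It is not part of the diff: the placeholder API key is illustrative, and it assumes the config helpers in `config_util.py` persist keys as the code above implies.

```python
# Illustrative sketch -- not part of the published diff.
from morecompute.services.providers import provider_factory
# Importing the provider module ensures its @register_provider decorator has run
# (the package __init__, not shown here, may already handle this).
from morecompute.services.providers import lambda_labs_provider  # noqa: F401

# Save an API key for Lambda Labs and make it the active provider.
ok = provider_factory.configure_provider(
    "lambda_labs", api_key="<your-lambda-labs-key>", make_active=True
)

# Callers can then resolve the active provider without caring which backend it is.
provider = provider_factory.get_active_provider()
print("active:", provider.PROVIDER_DISPLAY_NAME if provider else None)

# And enumerate every registered provider with its configuration status.
for info in provider_factory.list_providers():
    marker = "*" if info.is_active else " "
    print(f"{marker} {info.display_name}: configured={info.configured}")
```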