more-compute 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontend/app/globals.css +734 -27
- frontend/app/layout.tsx +13 -3
- frontend/components/Notebook.tsx +2 -14
- frontend/components/cell/MonacoCell.tsx +99 -5
- frontend/components/layout/Sidebar.tsx +39 -4
- frontend/components/panels/ClaudePanel.tsx +461 -0
- frontend/components/popups/ComputePopup.tsx +738 -447
- frontend/components/popups/FilterPopup.tsx +305 -189
- frontend/components/popups/MetricsPopup.tsx +20 -1
- frontend/components/popups/ProviderConfigModal.tsx +322 -0
- frontend/components/popups/ProviderDropdown.tsx +398 -0
- frontend/components/popups/SettingsPopup.tsx +1 -1
- frontend/contexts/ClaudeContext.tsx +392 -0
- frontend/contexts/PodWebSocketContext.tsx +16 -21
- frontend/hooks/useInlineDiff.ts +269 -0
- frontend/lib/api.ts +323 -12
- frontend/lib/settings.ts +5 -0
- frontend/lib/websocket-native.ts +4 -8
- frontend/lib/websocket.ts +1 -2
- frontend/package-lock.json +733 -36
- frontend/package.json +2 -0
- frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
- frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
- frontend/public/assets/icons/providers/runpod.svg +9 -0
- frontend/public/assets/icons/providers/vastai.svg +1 -0
- frontend/settings.md +54 -0
- frontend/tsconfig.tsbuildinfo +1 -0
- frontend/types/claude.ts +194 -0
- kernel_run.py +13 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
- morecompute/__init__.py +1 -1
- morecompute/__version__.py +1 -1
- morecompute/execution/executor.py +24 -67
- morecompute/execution/worker.py +6 -72
- morecompute/models/api_models.py +62 -0
- morecompute/notebook.py +11 -0
- morecompute/server.py +641 -133
- morecompute/services/claude_service.py +392 -0
- morecompute/services/pod_manager.py +168 -67
- morecompute/services/pod_monitor.py +67 -39
- morecompute/services/prime_intellect.py +0 -4
- morecompute/services/providers/__init__.py +92 -0
- morecompute/services/providers/base_provider.py +336 -0
- morecompute/services/providers/lambda_labs_provider.py +394 -0
- morecompute/services/providers/provider_factory.py +194 -0
- morecompute/services/providers/runpod_provider.py +504 -0
- morecompute/services/providers/vastai_provider.py +407 -0
- morecompute/utils/cell_magics.py +0 -3
- morecompute/utils/config_util.py +93 -3
- morecompute/utils/special_commands.py +5 -32
- morecompute/utils/version_check.py +117 -0
- frontend/styling_README.md +0 -23
- {more_compute-0.4.4.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
- {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""RunPod GPU cloud provider implementation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
import httpx
|
|
6
|
+
from fastapi import HTTPException
|
|
7
|
+
|
|
8
|
+
from .base_provider import BaseGPUProvider, NormalizedPod
|
|
9
|
+
from .provider_factory import register_provider
|
|
10
|
+
from ...models.api_models import PodResponse
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@register_provider
class RunPodProvider(BaseGPUProvider):
    """RunPod GPU cloud provider using GraphQL API."""

    # Identity keys used by the provider registry/factory lookup.
    PROVIDER_NAME = "runpod"
    PROVIDER_DISPLAY_NAME = "RunPod"
    # Environment variable consulted for the API key when none is passed in.
    API_KEY_ENV_NAME = "RUNPOD_API_KEY"
    # Pods expose SSH via public port-22 mappings in the runtime info.
    SUPPORTS_SSH = True
    # Console page where users can create/manage their API key.
    DASHBOARD_URL = "https://www.runpod.io/console/user/settings"

    # Single GraphQL endpoint used for every RunPod API operation.
    BASE_URL = "https://api.runpod.io/graphql"

    def __init__(self, api_key: str | None = None):
        # Key storage/validation is handled entirely by the base provider.
        super().__init__(api_key)
28
|
+
def _get_auth_headers(self) -> dict[str, str]:
|
|
29
|
+
"""Get RunPod authentication headers."""
|
|
30
|
+
return {
|
|
31
|
+
"Content-Type": "application/json",
|
|
32
|
+
"Authorization": f"Bearer {self.api_key}" if self.api_key else ""
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
async def _graphql_request(
|
|
36
|
+
self,
|
|
37
|
+
query: str,
|
|
38
|
+
variables: dict[str, Any] | None = None
|
|
39
|
+
) -> dict[str, Any]:
|
|
40
|
+
"""Make a GraphQL request to RunPod API.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
query: GraphQL query string
|
|
44
|
+
variables: Query variables
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
Response data
|
|
48
|
+
|
|
49
|
+
Raises:
|
|
50
|
+
HTTPException: On API errors
|
|
51
|
+
"""
|
|
52
|
+
async with httpx.AsyncClient() as client:
|
|
53
|
+
try:
|
|
54
|
+
response = await client.post(
|
|
55
|
+
self.BASE_URL,
|
|
56
|
+
headers=self._get_auth_headers(),
|
|
57
|
+
json={
|
|
58
|
+
"query": query,
|
|
59
|
+
"variables": variables or {}
|
|
60
|
+
},
|
|
61
|
+
timeout=30.0
|
|
62
|
+
)
|
|
63
|
+
response.raise_for_status()
|
|
64
|
+
result = response.json()
|
|
65
|
+
|
|
66
|
+
if "errors" in result:
|
|
67
|
+
error_msg = result["errors"][0].get("message", "Unknown error")
|
|
68
|
+
raise HTTPException(
|
|
69
|
+
status_code=400,
|
|
70
|
+
detail=f"RunPod API error: {error_msg}"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
return result.get("data", {})
|
|
74
|
+
except httpx.HTTPStatusError as e:
|
|
75
|
+
raise HTTPException(
|
|
76
|
+
status_code=e.response.status_code,
|
|
77
|
+
detail=f"RunPod API error: {e.response.text}"
|
|
78
|
+
)
|
|
79
|
+
except httpx.RequestError as e:
|
|
80
|
+
raise HTTPException(
|
|
81
|
+
status_code=503,
|
|
82
|
+
detail=f"RunPod connection error: {str(e)}"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
async def get_gpu_availability(
|
|
86
|
+
self,
|
|
87
|
+
regions: list[str] | None = None,
|
|
88
|
+
gpu_count: int | None = None,
|
|
89
|
+
gpu_type: str | None = None,
|
|
90
|
+
secure_cloud: bool | None = None,
|
|
91
|
+
community_cloud: bool | None = None,
|
|
92
|
+
**kwargs: Any
|
|
93
|
+
) -> dict[str, Any]:
|
|
94
|
+
"""Get available GPU types from RunPod.
|
|
95
|
+
|
|
96
|
+
Returns dict with available GPUs and their pricing.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
regions: Filter by region (not supported by RunPod GPU types query)
|
|
100
|
+
gpu_count: Number of GPUs to request pricing for
|
|
101
|
+
gpu_type: Filter by GPU type name (partial match)
|
|
102
|
+
secure_cloud: If True, only show GPUs available in Secure Cloud
|
|
103
|
+
community_cloud: If True, only show GPUs available in Community Cloud
|
|
104
|
+
"""
|
|
105
|
+
query = """
|
|
106
|
+
query GpuTypes {
|
|
107
|
+
gpuTypes {
|
|
108
|
+
id
|
|
109
|
+
displayName
|
|
110
|
+
memoryInGb
|
|
111
|
+
secureCloud
|
|
112
|
+
communityCloud
|
|
113
|
+
lowestPrice(input: {gpuCount: 1}) {
|
|
114
|
+
minimumBidPrice
|
|
115
|
+
uninterruptablePrice
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
}
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
data = await self._graphql_request(query)
|
|
122
|
+
gpu_types = data.get("gpuTypes", [])
|
|
123
|
+
|
|
124
|
+
# Filter by GPU type if specified
|
|
125
|
+
if gpu_type:
|
|
126
|
+
gpu_type_lower = gpu_type.lower()
|
|
127
|
+
gpu_types = [
|
|
128
|
+
g for g in gpu_types
|
|
129
|
+
if gpu_type_lower in g.get("displayName", "").lower()
|
|
130
|
+
or gpu_type_lower in g.get("id", "").lower()
|
|
131
|
+
]
|
|
132
|
+
|
|
133
|
+
# Filter by cloud type
|
|
134
|
+
if secure_cloud is True:
|
|
135
|
+
gpu_types = [g for g in gpu_types if g.get("secureCloud")]
|
|
136
|
+
if community_cloud is True:
|
|
137
|
+
gpu_types = [g for g in gpu_types if g.get("communityCloud")]
|
|
138
|
+
|
|
139
|
+
# Transform to normalized format
|
|
140
|
+
gpus = []
|
|
141
|
+
for gpu in gpu_types:
|
|
142
|
+
lowest_price = gpu.get("lowestPrice", {})
|
|
143
|
+
price = lowest_price.get("uninterruptablePrice") or lowest_price.get("minimumBidPrice") or 0
|
|
144
|
+
|
|
145
|
+
gpus.append({
|
|
146
|
+
"gpuType": gpu.get("id"),
|
|
147
|
+
"gpuName": gpu.get("displayName"),
|
|
148
|
+
"gpuCount": gpu_count or 1,
|
|
149
|
+
"priceHr": price,
|
|
150
|
+
"cloudId": gpu.get("id"),
|
|
151
|
+
"socket": gpu.get("id"),
|
|
152
|
+
"memoryGb": gpu.get("memoryInGb"),
|
|
153
|
+
"secureCloud": gpu.get("secureCloud"),
|
|
154
|
+
"communityCloud": gpu.get("communityCloud"),
|
|
155
|
+
"provider": self.PROVIDER_NAME
|
|
156
|
+
})
|
|
157
|
+
|
|
158
|
+
return {
|
|
159
|
+
"data": gpus,
|
|
160
|
+
"total_count": len(gpus),
|
|
161
|
+
"provider": self.PROVIDER_NAME
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
async def create_pod(self, request: Any) -> PodResponse:
|
|
165
|
+
"""Create a new RunPod pod.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
request: CreatePodRequest with pod configuration
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
PodResponse with created pod info
|
|
172
|
+
"""
|
|
173
|
+
pod_config = request.pod if hasattr(request, 'pod') else request
|
|
174
|
+
|
|
175
|
+
mutation = """
|
|
176
|
+
mutation CreatePod($input: PodFindAndDeployOnDemandInput!) {
|
|
177
|
+
podFindAndDeployOnDemand(input: $input) {
|
|
178
|
+
id
|
|
179
|
+
name
|
|
180
|
+
desiredStatus
|
|
181
|
+
imageName
|
|
182
|
+
gpuCount
|
|
183
|
+
machineId
|
|
184
|
+
machine {
|
|
185
|
+
gpuDisplayName
|
|
186
|
+
}
|
|
187
|
+
runtime {
|
|
188
|
+
uptimeInSeconds
|
|
189
|
+
ports {
|
|
190
|
+
ip
|
|
191
|
+
isIpPublic
|
|
192
|
+
privatePort
|
|
193
|
+
publicPort
|
|
194
|
+
type
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
"""
|
|
200
|
+
|
|
201
|
+
variables = {
|
|
202
|
+
"input": {
|
|
203
|
+
"name": pod_config.name if hasattr(pod_config, 'name') else pod_config.get("name"),
|
|
204
|
+
"gpuTypeId": pod_config.gpuType if hasattr(pod_config, 'gpuType') else pod_config.get("gpuType"),
|
|
205
|
+
"gpuCount": pod_config.gpuCount if hasattr(pod_config, 'gpuCount') else pod_config.get("gpuCount", 1),
|
|
206
|
+
"volumeInGb": pod_config.diskSize if hasattr(pod_config, 'diskSize') else pod_config.get("diskSize", 20),
|
|
207
|
+
"containerDiskInGb": 20,
|
|
208
|
+
"dockerArgs": "",
|
|
209
|
+
"deployCost": pod_config.maxPrice if hasattr(pod_config, 'maxPrice') else pod_config.get("maxPrice"),
|
|
210
|
+
"startSsh": True,
|
|
211
|
+
"imageName": pod_config.image if hasattr(pod_config, 'image') else pod_config.get("image", "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"),
|
|
212
|
+
}
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
# Remove None values
|
|
216
|
+
variables["input"] = {k: v for k, v in variables["input"].items() if v is not None}
|
|
217
|
+
|
|
218
|
+
data = await self._graphql_request(mutation, variables)
|
|
219
|
+
pod_data = data.get("podFindAndDeployOnDemand", {})
|
|
220
|
+
|
|
221
|
+
# Get SSH connection info
|
|
222
|
+
ssh_connection = await self._get_ssh_connection(pod_data.get("id"))
|
|
223
|
+
|
|
224
|
+
now = datetime.utcnow()
|
|
225
|
+
return PodResponse(
|
|
226
|
+
id=pod_data.get("id", ""),
|
|
227
|
+
userId="",
|
|
228
|
+
teamId=None,
|
|
229
|
+
name=pod_data.get("name", ""),
|
|
230
|
+
status=self._normalize_status(pod_data.get("desiredStatus", "PENDING")),
|
|
231
|
+
gpuName=pod_data.get("machine", {}).get("gpuDisplayName", ""),
|
|
232
|
+
gpuCount=pod_data.get("gpuCount", 1),
|
|
233
|
+
priceHr=0.0, # Will be fetched separately
|
|
234
|
+
sshConnection=ssh_connection,
|
|
235
|
+
ip=None,
|
|
236
|
+
createdAt=now,
|
|
237
|
+
updatedAt=now
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
async def _get_ssh_connection(self, pod_id: str) -> str | None:
|
|
241
|
+
"""Get SSH connection string for a pod."""
|
|
242
|
+
if not pod_id:
|
|
243
|
+
return None
|
|
244
|
+
|
|
245
|
+
query = """
|
|
246
|
+
query Pod($podId: String!) {
|
|
247
|
+
pod(input: {podId: $podId}) {
|
|
248
|
+
id
|
|
249
|
+
runtime {
|
|
250
|
+
ports {
|
|
251
|
+
ip
|
|
252
|
+
isIpPublic
|
|
253
|
+
privatePort
|
|
254
|
+
publicPort
|
|
255
|
+
type
|
|
256
|
+
}
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
"""
|
|
261
|
+
|
|
262
|
+
try:
|
|
263
|
+
data = await self._graphql_request(query, {"podId": pod_id})
|
|
264
|
+
pod = data.get("pod", {})
|
|
265
|
+
runtime = pod.get("runtime", {})
|
|
266
|
+
ports = runtime.get("ports", [])
|
|
267
|
+
|
|
268
|
+
for port in ports:
|
|
269
|
+
if port.get("privatePort") == 22 and port.get("isIpPublic"):
|
|
270
|
+
ip = port.get("ip")
|
|
271
|
+
public_port = port.get("publicPort")
|
|
272
|
+
return f"ssh root@{ip} -p {public_port}"
|
|
273
|
+
except Exception:
|
|
274
|
+
pass
|
|
275
|
+
|
|
276
|
+
return None
|
|
277
|
+
|
|
278
|
+
async def get_pods(
|
|
279
|
+
self,
|
|
280
|
+
status: str | None = None,
|
|
281
|
+
limit: int = 100,
|
|
282
|
+
offset: int = 0
|
|
283
|
+
) -> dict[str, Any]:
|
|
284
|
+
"""Get list of all RunPod pods."""
|
|
285
|
+
query = """
|
|
286
|
+
query Pods {
|
|
287
|
+
myself {
|
|
288
|
+
pods {
|
|
289
|
+
id
|
|
290
|
+
name
|
|
291
|
+
desiredStatus
|
|
292
|
+
imageName
|
|
293
|
+
gpuCount
|
|
294
|
+
costPerHr
|
|
295
|
+
machineId
|
|
296
|
+
machine {
|
|
297
|
+
gpuDisplayName
|
|
298
|
+
}
|
|
299
|
+
runtime {
|
|
300
|
+
uptimeInSeconds
|
|
301
|
+
ports {
|
|
302
|
+
ip
|
|
303
|
+
isIpPublic
|
|
304
|
+
privatePort
|
|
305
|
+
publicPort
|
|
306
|
+
type
|
|
307
|
+
}
|
|
308
|
+
}
|
|
309
|
+
}
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
"""
|
|
313
|
+
|
|
314
|
+
data = await self._graphql_request(query)
|
|
315
|
+
pods_raw = data.get("myself", {}).get("pods", [])
|
|
316
|
+
|
|
317
|
+
# Filter by status if specified
|
|
318
|
+
if status:
|
|
319
|
+
status_upper = status.upper()
|
|
320
|
+
pods_raw = [p for p in pods_raw if p.get("desiredStatus", "").upper() == status_upper]
|
|
321
|
+
|
|
322
|
+
# Apply pagination
|
|
323
|
+
pods_raw = pods_raw[offset:offset + limit]
|
|
324
|
+
|
|
325
|
+
# Transform to standardized format
|
|
326
|
+
pods = []
|
|
327
|
+
for pod in pods_raw:
|
|
328
|
+
ssh_connection = None
|
|
329
|
+
runtime = pod.get("runtime", {})
|
|
330
|
+
if runtime:
|
|
331
|
+
ports = runtime.get("ports", [])
|
|
332
|
+
for port in ports:
|
|
333
|
+
if port.get("privatePort") == 22 and port.get("isIpPublic"):
|
|
334
|
+
ip = port.get("ip")
|
|
335
|
+
public_port = port.get("publicPort")
|
|
336
|
+
ssh_connection = f"ssh root@{ip} -p {public_port}"
|
|
337
|
+
break
|
|
338
|
+
|
|
339
|
+
pods.append({
|
|
340
|
+
"id": pod.get("id"),
|
|
341
|
+
"name": pod.get("name"),
|
|
342
|
+
"status": self._normalize_status(pod.get("desiredStatus", "PENDING")),
|
|
343
|
+
"gpuName": pod.get("machine", {}).get("gpuDisplayName", ""),
|
|
344
|
+
"gpuCount": pod.get("gpuCount", 1),
|
|
345
|
+
"priceHr": pod.get("costPerHr", 0),
|
|
346
|
+
"sshConnection": ssh_connection,
|
|
347
|
+
"ip": None,
|
|
348
|
+
"createdAt": datetime.utcnow().isoformat(),
|
|
349
|
+
"updatedAt": datetime.utcnow().isoformat(),
|
|
350
|
+
"provider": self.PROVIDER_NAME
|
|
351
|
+
})
|
|
352
|
+
|
|
353
|
+
return {
|
|
354
|
+
"data": pods,
|
|
355
|
+
"total_count": len(pods),
|
|
356
|
+
"offset": offset,
|
|
357
|
+
"limit": limit,
|
|
358
|
+
"provider": self.PROVIDER_NAME
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
async def get_pod(self, pod_id: str) -> PodResponse:
|
|
362
|
+
"""Get details for a specific RunPod pod."""
|
|
363
|
+
query = """
|
|
364
|
+
query Pod($podId: String!) {
|
|
365
|
+
pod(input: {podId: $podId}) {
|
|
366
|
+
id
|
|
367
|
+
name
|
|
368
|
+
desiredStatus
|
|
369
|
+
imageName
|
|
370
|
+
gpuCount
|
|
371
|
+
costPerHr
|
|
372
|
+
machineId
|
|
373
|
+
machine {
|
|
374
|
+
gpuDisplayName
|
|
375
|
+
}
|
|
376
|
+
runtime {
|
|
377
|
+
uptimeInSeconds
|
|
378
|
+
ports {
|
|
379
|
+
ip
|
|
380
|
+
isIpPublic
|
|
381
|
+
privatePort
|
|
382
|
+
publicPort
|
|
383
|
+
type
|
|
384
|
+
}
|
|
385
|
+
}
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
"""
|
|
389
|
+
|
|
390
|
+
data = await self._graphql_request(query, {"podId": pod_id})
|
|
391
|
+
pod = data.get("pod")
|
|
392
|
+
|
|
393
|
+
if not pod:
|
|
394
|
+
raise HTTPException(status_code=404, detail=f"Pod {pod_id} not found")
|
|
395
|
+
|
|
396
|
+
# Extract SSH connection
|
|
397
|
+
ssh_connection = None
|
|
398
|
+
runtime = pod.get("runtime", {})
|
|
399
|
+
if runtime:
|
|
400
|
+
ports = runtime.get("ports", [])
|
|
401
|
+
for port in ports:
|
|
402
|
+
if port.get("privatePort") == 22 and port.get("isIpPublic"):
|
|
403
|
+
ip = port.get("ip")
|
|
404
|
+
public_port = port.get("publicPort")
|
|
405
|
+
ssh_connection = f"ssh root@{ip} -p {public_port}"
|
|
406
|
+
break
|
|
407
|
+
|
|
408
|
+
now = datetime.utcnow()
|
|
409
|
+
return PodResponse(
|
|
410
|
+
id=pod.get("id", ""),
|
|
411
|
+
userId="",
|
|
412
|
+
teamId=None,
|
|
413
|
+
name=pod.get("name", ""),
|
|
414
|
+
status=self._normalize_status(pod.get("desiredStatus", "PENDING")),
|
|
415
|
+
gpuName=pod.get("machine", {}).get("gpuDisplayName", ""),
|
|
416
|
+
gpuCount=pod.get("gpuCount", 1),
|
|
417
|
+
priceHr=pod.get("costPerHr", 0),
|
|
418
|
+
sshConnection=ssh_connection,
|
|
419
|
+
ip=None,
|
|
420
|
+
createdAt=now,
|
|
421
|
+
updatedAt=now
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
async def delete_pod(self, pod_id: str) -> dict[str, Any]:
|
|
425
|
+
"""Delete/terminate a RunPod pod."""
|
|
426
|
+
mutation = """
|
|
427
|
+
mutation TerminatePod($podId: String!) {
|
|
428
|
+
podTerminate(input: {podId: $podId})
|
|
429
|
+
}
|
|
430
|
+
"""
|
|
431
|
+
|
|
432
|
+
await self._graphql_request(mutation, {"podId": pod_id})
|
|
433
|
+
return {"success": True, "pod_id": pod_id, "provider": self.PROVIDER_NAME}
|
|
434
|
+
|
|
435
|
+
async def stop_pod(self, pod_id: str) -> dict[str, Any]:
|
|
436
|
+
"""Stop a RunPod pod (without deleting)."""
|
|
437
|
+
mutation = """
|
|
438
|
+
mutation StopPod($podId: String!) {
|
|
439
|
+
podStop(input: {podId: $podId})
|
|
440
|
+
}
|
|
441
|
+
"""
|
|
442
|
+
|
|
443
|
+
await self._graphql_request(mutation, {"podId": pod_id})
|
|
444
|
+
return {"success": True, "pod_id": pod_id, "action": "stopped"}
|
|
445
|
+
|
|
446
|
+
async def resume_pod(self, pod_id: str) -> dict[str, Any]:
|
|
447
|
+
"""Resume a stopped RunPod pod."""
|
|
448
|
+
mutation = """
|
|
449
|
+
mutation ResumePod($podId: String!) {
|
|
450
|
+
podResume(input: {podId: $podId}) {
|
|
451
|
+
id
|
|
452
|
+
desiredStatus
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
"""
|
|
456
|
+
|
|
457
|
+
data = await self._graphql_request(mutation, {"podId": pod_id})
|
|
458
|
+
return {
|
|
459
|
+
"success": True,
|
|
460
|
+
"pod_id": pod_id,
|
|
461
|
+
"status": data.get("podResume", {}).get("desiredStatus", "RUNNING")
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
def _normalize_status(self, runpod_status: str) -> str:
|
|
465
|
+
"""Convert RunPod status to normalized status."""
|
|
466
|
+
status_map = {
|
|
467
|
+
"RUNNING": "ACTIVE",
|
|
468
|
+
"PENDING": "PENDING",
|
|
469
|
+
"EXITED": "TERMINATED",
|
|
470
|
+
"STOPPED": "STOPPED",
|
|
471
|
+
"STOPPING": "STOPPING",
|
|
472
|
+
"STARTING": "STARTING",
|
|
473
|
+
"TERMINATING": "TERMINATING",
|
|
474
|
+
"TERMINATED": "TERMINATED",
|
|
475
|
+
"ERROR": "ERROR"
|
|
476
|
+
}
|
|
477
|
+
return status_map.get(runpod_status.upper(), runpod_status)
|
|
478
|
+
|
|
479
|
+
def normalize_pod(self, pod_data: dict[str, Any]) -> NormalizedPod:
|
|
480
|
+
"""Convert RunPod pod data to normalized format."""
|
|
481
|
+
ssh_connection = None
|
|
482
|
+
runtime = pod_data.get("runtime", {})
|
|
483
|
+
if runtime:
|
|
484
|
+
ports = runtime.get("ports", [])
|
|
485
|
+
for port in ports:
|
|
486
|
+
if port.get("privatePort") == 22 and port.get("isIpPublic"):
|
|
487
|
+
ip = port.get("ip")
|
|
488
|
+
public_port = port.get("publicPort")
|
|
489
|
+
ssh_connection = f"ssh root@{ip} -p {public_port}"
|
|
490
|
+
break
|
|
491
|
+
|
|
492
|
+
return NormalizedPod(
|
|
493
|
+
id=pod_data.get("id", ""),
|
|
494
|
+
name=pod_data.get("name", ""),
|
|
495
|
+
status=self._normalize_status(pod_data.get("desiredStatus", "PENDING")),
|
|
496
|
+
gpu_name=pod_data.get("machine", {}).get("gpuDisplayName", ""),
|
|
497
|
+
gpu_count=pod_data.get("gpuCount", 1),
|
|
498
|
+
price_hr=pod_data.get("costPerHr", 0.0),
|
|
499
|
+
ssh_connection=ssh_connection,
|
|
500
|
+
ip=None,
|
|
501
|
+
provider=self.PROVIDER_NAME,
|
|
502
|
+
created_at=datetime.utcnow().isoformat(),
|
|
503
|
+
updated_at=datetime.utcnow().isoformat()
|
|
504
|
+
)
|