more-compute 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontend/app/globals.css +734 -27
- frontend/app/layout.tsx +13 -3
- frontend/components/Notebook.tsx +2 -14
- frontend/components/cell/MonacoCell.tsx +99 -5
- frontend/components/layout/Sidebar.tsx +39 -4
- frontend/components/panels/ClaudePanel.tsx +461 -0
- frontend/components/popups/ComputePopup.tsx +739 -418
- frontend/components/popups/FilterPopup.tsx +305 -189
- frontend/components/popups/MetricsPopup.tsx +20 -1
- frontend/components/popups/ProviderConfigModal.tsx +322 -0
- frontend/components/popups/ProviderDropdown.tsx +398 -0
- frontend/components/popups/SettingsPopup.tsx +1 -1
- frontend/contexts/ClaudeContext.tsx +392 -0
- frontend/contexts/PodWebSocketContext.tsx +16 -21
- frontend/hooks/useInlineDiff.ts +269 -0
- frontend/lib/api.ts +323 -12
- frontend/lib/settings.ts +5 -0
- frontend/lib/websocket-native.ts +4 -8
- frontend/lib/websocket.ts +1 -2
- frontend/package-lock.json +733 -36
- frontend/package.json +2 -0
- frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
- frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
- frontend/public/assets/icons/providers/runpod.svg +9 -0
- frontend/public/assets/icons/providers/vastai.svg +1 -0
- frontend/settings.md +54 -0
- frontend/tsconfig.tsbuildinfo +1 -0
- frontend/types/claude.ts +194 -0
- kernel_run.py +13 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
- morecompute/__init__.py +1 -1
- morecompute/__version__.py +1 -1
- morecompute/execution/executor.py +24 -67
- morecompute/execution/worker.py +6 -72
- morecompute/models/api_models.py +62 -0
- morecompute/notebook.py +11 -0
- morecompute/server.py +641 -133
- morecompute/services/claude_service.py +392 -0
- morecompute/services/pod_manager.py +168 -67
- morecompute/services/pod_monitor.py +67 -39
- morecompute/services/prime_intellect.py +0 -4
- morecompute/services/providers/__init__.py +92 -0
- morecompute/services/providers/base_provider.py +336 -0
- morecompute/services/providers/lambda_labs_provider.py +394 -0
- morecompute/services/providers/provider_factory.py +194 -0
- morecompute/services/providers/runpod_provider.py +504 -0
- morecompute/services/providers/vastai_provider.py +407 -0
- morecompute/utils/cell_magics.py +0 -3
- morecompute/utils/config_util.py +93 -3
- morecompute/utils/special_commands.py +5 -32
- morecompute/utils/version_check.py +117 -0
- frontend/styling_README.md +0 -23
- {more_compute-0.4.3.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
- {more_compute-0.4.3.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
"""Vast.ai GPU cloud provider implementation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
|
|
7
|
+
from .base_provider import BaseGPUProvider, NormalizedPod
|
|
8
|
+
from .provider_factory import register_provider
|
|
9
|
+
from ...models.api_models import PodResponse
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@register_provider
class VastAIProvider(BaseGPUProvider):
    """Vast.ai GPU cloud provider using REST API.

    Vast.ai provides community GPUs at competitive prices.
    """

    PROVIDER_NAME = "vastai"
    PROVIDER_DISPLAY_NAME = "Vast.ai"
    API_KEY_ENV_NAME = "VASTAI_API_KEY"
    SUPPORTS_SSH = True
    DASHBOARD_URL = "https://cloud.vast.ai/"

    BASE_URL = "https://console.vast.ai/api/v0"

    def __init__(self, api_key: str | None = None):
        super().__init__(api_key)

    def _get_auth_headers(self) -> dict[str, str]:
        """Get Vast.ai authentication headers.

        Note: Vast.ai authenticates via the ``api_key`` query parameter
        (see ``_make_vast_request``), so no auth header is included here.
        """
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def _make_vast_request(
        self,
        method: str,
        endpoint: str,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Make an authenticated request to Vast.ai API.

        Vast.ai uses api_key as a query parameter.

        Args:
            method: HTTP verb (e.g. "GET", "PUT", "DELETE").
            endpoint: Path appended to ``BASE_URL`` (leading slash included).
            params: Optional query parameters; the API key is injected here.
            json_data: Optional JSON body.

        Returns:
            Parsed JSON response, or ``{}`` for empty/204 responses.

        Raises:
            HTTPException: On HTTP error status or connection failure.
        """
        import httpx
        from fastapi import HTTPException

        url = f"{self.BASE_URL}{endpoint}"

        # Vast.ai authentication: API key travels as a query parameter.
        if params is None:
            params = {}
        params["api_key"] = self.api_key

        async with httpx.AsyncClient(follow_redirects=True) as client:
            try:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self._get_auth_headers(),
                    params=params,
                    json=json_data,
                    timeout=30.0
                )
                response.raise_for_status()

                # Some endpoints return no body (e.g. 204 No Content).
                if response.status_code == 204 or not response.content:
                    return {}

                return response.json()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=e.response.status_code,
                    detail=f"Vast.ai API error: {e.response.text}"
                )
            except httpx.RequestError as e:
                raise HTTPException(
                    status_code=503,
                    detail=f"Vast.ai connection error: {str(e)}"
                )

    async def get_gpu_availability(
        self,
        regions: list[str] | None = None,
        gpu_count: int | None = None,
        gpu_type: str | None = None,
        verified: bool | None = None,
        min_reliability: float | None = None,
        min_gpu_ram: float | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """Get available GPU offers from Vast.ai marketplace.

        Vast.ai has a marketplace model where users list their GPUs.

        Args:
            regions: Filter by region/geolocation
            gpu_count: Minimum number of GPUs
            gpu_type: Filter by GPU name (partial, case-insensitive match)
            verified: If True, only show verified hosts
            min_reliability: Minimum reliability score (0.0-1.0)
            min_gpu_ram: Minimum GPU RAM in GB
        """
        # Build query for offers
        query = {
            "rentable": {"eq": True},
            "rented": {"eq": False},
            "order": [["dph_total", "asc"]],  # Sort by price
            "type": "on-demand"
        }

        # Filter by verified status (default to True if not specified)
        if verified is True or verified is None:
            query["verified"] = {"eq": True}

        # NOTE: gpu_type is filtered client-side below — Vast.ai's query
        # language only supports exact matches, while we want partial match.

        # Filter by GPU count
        if gpu_count:
            query["num_gpus"] = {"gte": gpu_count}

        # Filter by reliability
        if min_reliability is not None:
            query["reliability2"] = {"gte": min_reliability}

        # Filter by GPU RAM (Vast.ai reports gpu_ram in MB)
        if min_gpu_ram is not None:
            query["gpu_ram"] = {"gte": min_gpu_ram * 1024}  # Convert GB to MB

        response = await self._make_vast_request(
            "GET",
            "/bundles",
            params={"q": json.dumps(query)}
        )

        offers = response.get("offers", [])

        # Transform to standardized format
        gpus = []
        for offer in offers:
            # Filter by region if specified (first component of "City,Country")
            if regions and offer.get("geolocation", "").split(",")[0] not in regions:
                continue

            # Client-side filter by GPU type (partial match)
            if gpu_type:
                gpu_name = offer.get("gpu_name", "").lower()
                if gpu_type.lower() not in gpu_name:
                    continue

            gpus.append({
                "gpuType": offer.get("gpu_name", ""),
                "gpuName": offer.get("gpu_name", ""),
                "gpuCount": offer.get("num_gpus", 1),
                "priceHr": offer.get("dph_total", 0),
                "cloudId": str(offer.get("id")),
                "socket": str(offer.get("id")),
                "region": offer.get("geolocation", "").split(",")[0] if offer.get("geolocation") else None,
                "geolocation": offer.get("geolocation"),
                "reliabilityScore": offer.get("reliability2", offer.get("reliability", 0)),
                "dlPerf": offer.get("dlperf", 0),
                "memoryGb": offer.get("gpu_ram", 0) / 1024,  # Convert MB to GB
                "storageGb": offer.get("disk_space", 0),
                "cpuCores": offer.get("cpu_cores_effective"),
                "cpuRam": offer.get("cpu_ram", 0) / 1024,  # Convert MB to GB
                "verified": offer.get("verified", False),
                "provider": self.PROVIDER_NAME
            })

        return {
            "data": gpus,
            "total_count": len(gpus),
            "provider": self.PROVIDER_NAME
        }

    async def create_pod(self, request: Any) -> PodResponse:
        """Create a new Vast.ai instance.

        Args:
            request: CreatePodRequest with pod configuration

        Returns:
            PodResponse with created instance info

        Raises:
            HTTPException: 402 on insufficient account balance, 400 on a
                Vast.ai-reported failure, 500 when no contract ID comes back.
        """
        from fastapi import HTTPException

        pod_config = request.pod if hasattr(request, 'pod') else request

        # Accept either a pydantic-style object or a plain dict.
        offer_id = pod_config.cloudId if hasattr(pod_config, 'cloudId') else pod_config.get("cloudId")
        image = pod_config.image if hasattr(pod_config, 'image') else pod_config.get("image", "nvidia/cuda:12.1.0-devel-ubuntu22.04")
        disk_size = pod_config.diskSize if hasattr(pod_config, 'diskSize') else pod_config.get("diskSize", 20)
        name = pod_config.name if hasattr(pod_config, 'name') else pod_config.get("name", "morecompute-instance")

        # Create the instance - Vast.ai API format
        payload = {
            "image": image,
            "disk": float(disk_size),
            "label": name,
            "runtype": "ssh",
        }

        # Add environment variables if specified
        env_vars = pod_config.envVars if hasattr(pod_config, 'envVars') else pod_config.get("envVars")
        if env_vars:
            env_dict = {e.key: e.value for e in env_vars} if hasattr(env_vars[0], 'key') else env_vars
            payload["env"] = env_dict

        try:
            response = await self._make_vast_request(
                "PUT",
                f"/asks/{offer_id}/",
                json_data=payload
            )
        except HTTPException as e:
            # Check for specific error cases
            error_detail = str(e.detail) if hasattr(e, 'detail') else str(e)

            if "402" in error_detail or "insufficient" in error_detail.lower() or "balance" in error_detail.lower():
                raise HTTPException(
                    status_code=402,
                    detail="Insufficient funds in your Vast.ai account. Please add credits at https://cloud.vast.ai/"
                )
            raise

        instance_id = response.get("new_contract")
        if not instance_id:
            # Check if response indicates an error
            if response.get("success") is False:
                error_msg = response.get("error", response.get("msg", "Unknown error"))
                raise HTTPException(status_code=400, detail=f"Vast.ai error: {error_msg}")
            raise HTTPException(status_code=500, detail="Failed to create Vast.ai instance - no contract ID returned")

        # Get instance details
        return await self.get_pod(str(instance_id))

    async def get_pods(
        self,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0
    ) -> dict[str, Any]:
        """Get list of all Vast.ai instances.

        Args:
            status: Optional normalized status to filter by (case-insensitive).
            limit: Maximum number of instances to return.
            offset: Number of instances to skip (pagination).
        """
        response = await self._make_vast_request(
            "GET",
            "/instances"
        )

        instances = response.get("instances", [])

        # Filter by status if specified
        if status:
            status_lower = status.lower()
            instances = [i for i in instances if self._normalize_status(i.get("actual_status", "")).lower() == status_lower]

        # Apply pagination
        instances = instances[offset:offset + limit]

        # Transform to standardized format
        pods = []
        for instance in instances:
            ssh_connection = self._build_ssh_connection(instance)

            pods.append({
                "id": str(instance.get("id")),
                "name": instance.get("label", f"vast-{instance.get('id')}"),
                "status": self._normalize_status(instance.get("actual_status", "loading")),
                "gpuName": instance.get("gpu_name", ""),
                "gpuCount": instance.get("num_gpus", 1),
                "priceHr": instance.get("dph_total", 0),
                "sshConnection": ssh_connection,
                "ip": instance.get("public_ipaddr"),
                "region": instance.get("geolocation", "").split(",")[0] if instance.get("geolocation") else None,
                "createdAt": instance.get("start_date", datetime.now(timezone.utc).isoformat()),
                "updatedAt": datetime.now(timezone.utc).isoformat(),
                "provider": self.PROVIDER_NAME
            })

        return {
            "data": pods,
            "total_count": len(pods),
            "offset": offset,
            "limit": limit,
            "provider": self.PROVIDER_NAME
        }

    async def get_pod(self, pod_id: str) -> PodResponse:
        """Get details for a specific Vast.ai instance.

        Raises:
            HTTPException: 404 when no instance with ``pod_id`` exists.
        """
        from fastapi import HTTPException

        # Vast.ai has no single-instance endpoint here; list and search.
        response = await self._make_vast_request(
            "GET",
            "/instances",
            params={"owner": "me"}
        )

        instances = response.get("instances", [])
        instance = next((i for i in instances if str(i.get("id")) == pod_id), None)

        if not instance:
            raise HTTPException(status_code=404, detail=f"Instance {pod_id} not found")

        ssh_connection = self._build_ssh_connection(instance)

        now = datetime.now(timezone.utc)
        return PodResponse(
            id=str(instance.get("id", "")),
            userId="",
            teamId=None,
            name=instance.get("label", f"vast-{instance.get('id')}"),
            status=self._normalize_status(instance.get("actual_status", "loading")),
            gpuName=instance.get("gpu_name", ""),
            gpuCount=instance.get("num_gpus", 1),
            priceHr=instance.get("dph_total", 0),
            sshConnection=ssh_connection,
            ip=instance.get("public_ipaddr"),
            createdAt=now,
            updatedAt=now
        )

    def _build_ssh_connection(self, instance: dict[str, Any]) -> str | None:
        """Build SSH connection string from Vast.ai instance data.

        Returns:
            An ``ssh root@<ip> -p <port>`` command, or None when the
            instance has no reachable address yet.
        """
        ip = instance.get("public_ipaddr") or instance.get("ssh_host")
        port = instance.get("ssh_port", 22)

        if not ip:
            return None

        return f"ssh root@{ip} -p {port}"

    async def delete_pod(self, pod_id: str) -> dict[str, Any]:
        """Destroy a Vast.ai instance."""
        response = await self._make_vast_request(
            "DELETE",
            f"/instances/{pod_id}/"
        )

        return {
            "success": response.get("success", True),
            "pod_id": pod_id,
            "provider": self.PROVIDER_NAME
        }

    async def stop_pod(self, pod_id: str) -> dict[str, Any]:
        """Stop a Vast.ai instance (without destroying)."""
        # Errors surface as HTTPException from _make_vast_request;
        # the response body carries no extra information we need.
        await self._make_vast_request(
            "PUT",
            f"/instances/{pod_id}/",
            json_data={"state": "stopped"}
        )

        return {
            "success": True,
            "pod_id": pod_id,
            "action": "stopped"
        }

    async def start_pod(self, pod_id: str) -> dict[str, Any]:
        """Start a stopped Vast.ai instance."""
        # Errors surface as HTTPException from _make_vast_request.
        await self._make_vast_request(
            "PUT",
            f"/instances/{pod_id}/",
            json_data={"state": "running"}
        )

        return {
            "success": True,
            "pod_id": pod_id,
            "action": "started"
        }

    def _normalize_status(self, vast_status: str) -> str:
        """Convert Vast.ai status to normalized status.

        Unknown statuses are passed through upper-cased.
        """
        status_map = {
            "running": "ACTIVE",
            "loading": "STARTING",
            "created": "PENDING",
            "exited": "STOPPED",
            "offline": "STOPPED",
            "error": "ERROR",
            "destroying": "TERMINATING"
        }
        return status_map.get(vast_status.lower(), vast_status.upper())

    def normalize_pod(self, pod_data: dict[str, Any]) -> NormalizedPod:
        """Convert Vast.ai instance data to normalized format."""
        ssh_connection = self._build_ssh_connection(pod_data)

        return NormalizedPod(
            id=str(pod_data.get("id", "")),
            name=pod_data.get("label", f"vast-{pod_data.get('id')}"),
            status=self._normalize_status(pod_data.get("actual_status", "loading")),
            gpu_name=pod_data.get("gpu_name", ""),
            gpu_count=pod_data.get("num_gpus", 1),
            price_hr=pod_data.get("dph_total", 0),
            ssh_connection=ssh_connection,
            ip=pod_data.get("public_ipaddr"),
            provider=self.PROVIDER_NAME,
            created_at=pod_data.get("start_date", datetime.now(timezone.utc).isoformat()),
            updated_at=datetime.now(timezone.utc).isoformat()
        )
|
morecompute/utils/cell_magics.py
CHANGED
|
@@ -251,7 +251,6 @@ class CellMagicHandlers:
|
|
|
251
251
|
# Track process for interrupt handling
|
|
252
252
|
if hasattr(cell_magic_handler, 'special_handler'):
|
|
253
253
|
cell_magic_handler.special_handler.current_process_sync = process
|
|
254
|
-
print(f"[CELL_MAGIC] Tracking sync subprocess PID={process.pid}", file=sys.stderr, flush=True)
|
|
255
254
|
|
|
256
255
|
# Read and print output line by line (real-time streaming)
|
|
257
256
|
def read_stream(stream, output_type):
|
|
@@ -292,7 +291,6 @@ class CellMagicHandlers:
|
|
|
292
291
|
if hasattr(cell_magic_handler, 'special_handler'):
|
|
293
292
|
if cell_magic_handler.special_handler.sync_interrupted:
|
|
294
293
|
# Process was killed by interrupt handler
|
|
295
|
-
print(f"[CELL_MAGIC] Process was interrupted, raising KeyboardInterrupt", file=sys.stderr, flush=True)
|
|
296
294
|
raise KeyboardInterrupt("Execution interrupted by user")
|
|
297
295
|
|
|
298
296
|
return_code = process.returncode
|
|
@@ -312,7 +310,6 @@ class CellMagicHandlers:
|
|
|
312
310
|
# Clear process reference
|
|
313
311
|
if hasattr(cell_magic_handler, 'special_handler'):
|
|
314
312
|
cell_magic_handler.special_handler.current_process_sync = None
|
|
315
|
-
print(f"[CELL_MAGIC] Cleared sync subprocess reference", file=sys.stderr, flush=True)
|
|
316
313
|
|
|
317
314
|
return return_code
|
|
318
315
|
|
morecompute/utils/config_util.py
CHANGED
|
@@ -12,8 +12,8 @@ CONFIG_FILE = CONFIG_DIR / "config.json"
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def _ensure_config_dir() -> None:
    """Create the config directory if needed, restricted to the owner.

    ``mode=0o700`` applies only when the directory is newly created (and is
    still subject to the process umask); an existing directory keeps its
    current permissions.
    """
    CONFIG_DIR.mkdir(mode=0o700, parents=True, exist_ok=True)
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def _load_config() -> dict:
|
|
@@ -28,10 +28,12 @@ def _load_config() -> dict:
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
def _save_config(config: dict) -> None:
    """Save config to JSON file with secure permissions.

    The file's mode is restricted to 0o600 (owner read/write) BEFORE any
    secret content is written. Tightening permissions only after the write
    would leave a umask-dependent window in which a freshly created file
    containing API keys could be readable by other users.
    """
    _ensure_config_dir()
    # Create the file (mode applies on creation only) and force 0o600 in
    # case it already existed with looser permissions.
    CONFIG_FILE.touch(mode=0o600, exist_ok=True)
    CONFIG_FILE.chmod(0o600)
    with CONFIG_FILE.open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
def load_api_key(key_name: str) -> Optional[str]:
|
|
@@ -73,3 +75,91 @@ def save_api_key(key_name: str, api_key: str) -> None:
|
|
|
73
75
|
config = _load_config()
|
|
74
76
|
config[key_name] = api_key
|
|
75
77
|
_save_config(config)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def delete_api_key(key_name: str) -> bool:
    """
    Delete an API key from config.

    Args:
        key_name: Key name to delete

    Returns:
        True if key was deleted, False if it didn't exist
    """
    config = _load_config()
    if key_name not in config:
        return False
    config.pop(key_name)
    _save_config(config)
    return True
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_active_provider() -> Optional[str]:
    """
    Get the currently active provider name.

    Returns:
        Provider name or None if not set
    """
    return _load_config().get("active_provider")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def set_active_provider(provider_name: str) -> None:
    """
    Set the active provider.

    Args:
        provider_name: The provider to make active
    """
    cfg = _load_config()
    cfg["active_provider"] = provider_name
    _save_config(cfg)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def get_all_configured_keys() -> dict[str, bool]:
    """
    Get a mapping of all API key names to whether they are configured.

    A key counts as configured when it is present either in the process
    environment or in the saved config file.

    Returns:
        Dict mapping key names to True/False
    """
    config = _load_config()

    # Known provider API key names (SSH-based providers only)
    known_keys = (
        "RUNPOD_API_KEY",
        "LAMBDA_LABS_API_KEY",
        "VASTAI_API_KEY",
    )

    # Environment takes precedence; config file is the fallback.
    return {
        name: bool(os.getenv(name) or config.get(name))
        for name in known_keys
    }
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_provider_api_keys(provider_name: str) -> dict[str, Optional[str]]:
    """
    Get all API keys needed for a specific provider.

    Args:
        provider_name: Provider name (e.g., "runpod", "modal")

    Returns:
        Dict mapping key names to their values (or None if not set)
    """
    # Provider to key name mappings (SSH-based providers only)
    provider_keys = {
        "runpod": ["RUNPOD_API_KEY"],
        "lambda_labs": ["LAMBDA_LABS_API_KEY"],
        "vastai": ["VASTAI_API_KEY"],
    }

    names = provider_keys.get(provider_name, [])
    return {name: load_api_key(name) for name in names}
|
|
@@ -1,13 +1,9 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import io
|
|
3
|
-
import sys
|
|
4
2
|
import asyncio
|
|
5
3
|
import subprocess
|
|
6
4
|
import time
|
|
7
5
|
import shlex
|
|
8
|
-
import
|
|
9
|
-
from contextlib import redirect_stdout, redirect_stderr
|
|
10
|
-
from typing import Dict, Any, Optional, Tuple, Union
|
|
6
|
+
from typing import Dict, Any, Optional, Union
|
|
11
7
|
from fastapi import WebSocket
|
|
12
8
|
|
|
13
9
|
from .cell_magics import CellMagicHandlers
|
|
@@ -124,7 +120,6 @@ class AsyncSpecialCommandHandler:
|
|
|
124
120
|
|
|
125
121
|
# Track process for interrupt handling
|
|
126
122
|
self.current_process = process
|
|
127
|
-
print(f"[SPECIAL_CMD] Started subprocess PID={process.pid}", file=sys.stderr, flush=True)
|
|
128
123
|
|
|
129
124
|
try:
|
|
130
125
|
# Stream output concurrently
|
|
@@ -138,22 +133,16 @@ class AsyncSpecialCommandHandler:
|
|
|
138
133
|
# Track tasks for interruption
|
|
139
134
|
self.stream_tasks = [stdout_task, stderr_task]
|
|
140
135
|
|
|
141
|
-
print(f"[SPECIAL_CMD] Waiting for stream tasks to complete...", file=sys.stderr, flush=True)
|
|
142
136
|
# Wait for both streams to complete
|
|
143
137
|
await asyncio.gather(stdout_task, stderr_task, return_exceptions=True)
|
|
144
138
|
|
|
145
|
-
print(f"[SPECIAL_CMD] Streams complete, waiting for process to exit...", file=sys.stderr, flush=True)
|
|
146
139
|
# Wait for process completion
|
|
147
140
|
return_code = await process.wait()
|
|
148
|
-
print(f"[SPECIAL_CMD] Process exited with code {return_code}", file=sys.stderr, flush=True)
|
|
149
141
|
except asyncio.CancelledError:
|
|
150
142
|
# Task was cancelled - treat as interrupt
|
|
151
|
-
print(f"[SPECIAL_CMD] Task cancelled, treating as interrupt", file=sys.stderr, flush=True)
|
|
152
143
|
return_code = -15 # SIGTERM
|
|
153
144
|
except Exception as e:
|
|
154
|
-
print(f"[SPECIAL_CMD] Exception during execution: {e}", file=sys.stderr, flush=True)
|
|
155
145
|
import traceback
|
|
156
|
-
traceback.print_exc()
|
|
157
146
|
# Set error result
|
|
158
147
|
result["status"] = "error"
|
|
159
148
|
result["error"] = {
|
|
@@ -167,11 +156,9 @@ class AsyncSpecialCommandHandler:
|
|
|
167
156
|
# Clear process reference when done
|
|
168
157
|
self.current_process = None
|
|
169
158
|
self.stream_tasks = []
|
|
170
|
-
print(f"[SPECIAL_CMD] Cleared process reference", file=sys.stderr, flush=True)
|
|
171
159
|
|
|
172
160
|
# Check if process was interrupted (negative return code means killed by signal)
|
|
173
161
|
if return_code < 0:
|
|
174
|
-
print(f"[SPECIAL_CMD] Process was interrupted (return_code={return_code}), setting KeyboardInterrupt error", file=sys.stderr, flush=True)
|
|
175
162
|
result["status"] = "error"
|
|
176
163
|
result["error"] = {
|
|
177
164
|
"output_type": "error",
|
|
@@ -189,8 +176,6 @@ class AsyncSpecialCommandHandler:
|
|
|
189
176
|
"traceback": [f"Shell command '{command}' failed"]
|
|
190
177
|
}
|
|
191
178
|
|
|
192
|
-
print(f"[SPECIAL_CMD] Returning result: status={result['status']}, return_code={return_code}", file=sys.stderr, flush=True)
|
|
193
|
-
|
|
194
179
|
# If pip install/uninstall occurred, notify clients to refresh packages
|
|
195
180
|
try:
|
|
196
181
|
if websocket and return_code == 0 and (command.startswith('pip install') or command.startswith('pip uninstall') or 'pip install' in command or 'pip uninstall' in command):
|
|
@@ -228,52 +213,40 @@ class AsyncSpecialCommandHandler:
|
|
|
228
213
|
# Cancel stream tasks first
|
|
229
214
|
for task in self.stream_tasks:
|
|
230
215
|
if not task.done():
|
|
231
|
-
print(f"[SPECIAL_CMD] Cancelling stream task", file=sys.stderr, flush=True)
|
|
232
216
|
task.cancel()
|
|
233
217
|
|
|
234
218
|
# Interrupt async subprocess
|
|
235
219
|
if self.current_process:
|
|
236
220
|
try:
|
|
237
|
-
print(f"[SPECIAL_CMD] Interrupting async subprocess PID={self.current_process.pid}", file=sys.stderr, flush=True)
|
|
238
221
|
self.current_process.terminate()
|
|
239
222
|
|
|
240
223
|
# Give it a moment to terminate gracefully
|
|
241
224
|
try:
|
|
242
225
|
await asyncio.wait_for(self.current_process.wait(), timeout=1.0)
|
|
243
|
-
print(f"[SPECIAL_CMD] Async subprocess terminated gracefully", file=sys.stderr, flush=True)
|
|
244
226
|
except asyncio.TimeoutError:
|
|
245
227
|
# Force kill if it doesn't terminate
|
|
246
|
-
print(f"[SPECIAL_CMD] Async subprocess didn't terminate, force killing", file=sys.stderr, flush=True)
|
|
247
228
|
self.current_process.kill()
|
|
248
229
|
await self.current_process.wait()
|
|
249
|
-
print(f"[SPECIAL_CMD] Async subprocess killed", file=sys.stderr, flush=True)
|
|
250
230
|
|
|
251
|
-
except Exception
|
|
252
|
-
|
|
231
|
+
except Exception:
|
|
232
|
+
pass
|
|
253
233
|
|
|
254
234
|
# Interrupt sync subprocess
|
|
255
235
|
if self.current_process_sync:
|
|
256
236
|
try:
|
|
257
|
-
print(f"[SPECIAL_CMD] Interrupting sync subprocess PID={self.current_process_sync.pid}", file=sys.stderr, flush=True)
|
|
258
237
|
self.sync_interrupted = True # Set flag so shell commands know to stop
|
|
259
238
|
self.current_process_sync.terminate()
|
|
260
239
|
|
|
261
240
|
# Give it a moment to terminate gracefully
|
|
262
241
|
try:
|
|
263
242
|
self.current_process_sync.wait(timeout=1.0)
|
|
264
|
-
print(f"[SPECIAL_CMD] Sync subprocess terminated gracefully", file=sys.stderr, flush=True)
|
|
265
243
|
except subprocess.TimeoutExpired:
|
|
266
244
|
# Force kill if it doesn't terminate
|
|
267
|
-
print(f"[SPECIAL_CMD] Sync subprocess didn't terminate, force killing", file=sys.stderr, flush=True)
|
|
268
245
|
self.current_process_sync.kill()
|
|
269
246
|
self.current_process_sync.wait()
|
|
270
|
-
print(f"[SPECIAL_CMD] Sync subprocess killed", file=sys.stderr, flush=True)
|
|
271
247
|
|
|
272
|
-
except Exception
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
if not self.current_process and not self.current_process_sync:
|
|
276
|
-
print(f"[SPECIAL_CMD] No subprocess to interrupt", file=sys.stderr, flush=True)
|
|
248
|
+
except Exception:
|
|
249
|
+
pass
|
|
277
250
|
|
|
278
251
|
async def _stream_output(self, stream, stream_type: str, result: Dict[str, Any],
|
|
279
252
|
websocket: Optional[WebSocket] = None,
|