more-compute 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. frontend/app/globals.css +734 -27
  2. frontend/app/layout.tsx +13 -3
  3. frontend/components/Notebook.tsx +2 -14
  4. frontend/components/cell/MonacoCell.tsx +99 -5
  5. frontend/components/layout/Sidebar.tsx +39 -4
  6. frontend/components/panels/ClaudePanel.tsx +461 -0
  7. frontend/components/popups/ComputePopup.tsx +738 -447
  8. frontend/components/popups/FilterPopup.tsx +305 -189
  9. frontend/components/popups/MetricsPopup.tsx +20 -1
  10. frontend/components/popups/ProviderConfigModal.tsx +322 -0
  11. frontend/components/popups/ProviderDropdown.tsx +398 -0
  12. frontend/components/popups/SettingsPopup.tsx +1 -1
  13. frontend/contexts/ClaudeContext.tsx +392 -0
  14. frontend/contexts/PodWebSocketContext.tsx +16 -21
  15. frontend/hooks/useInlineDiff.ts +269 -0
  16. frontend/lib/api.ts +323 -12
  17. frontend/lib/settings.ts +5 -0
  18. frontend/lib/websocket-native.ts +4 -8
  19. frontend/lib/websocket.ts +1 -2
  20. frontend/package-lock.json +733 -36
  21. frontend/package.json +2 -0
  22. frontend/public/assets/icons/providers/lambda_labs.svg +22 -0
  23. frontend/public/assets/icons/providers/prime_intellect.svg +18 -0
  24. frontend/public/assets/icons/providers/runpod.svg +9 -0
  25. frontend/public/assets/icons/providers/vastai.svg +1 -0
  26. frontend/settings.md +54 -0
  27. frontend/tsconfig.tsbuildinfo +1 -0
  28. frontend/types/claude.ts +194 -0
  29. kernel_run.py +13 -0
  30. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/METADATA +53 -11
  31. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/RECORD +56 -37
  32. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/WHEEL +1 -1
  33. morecompute/__init__.py +1 -1
  34. morecompute/__version__.py +1 -1
  35. morecompute/execution/executor.py +24 -67
  36. morecompute/execution/worker.py +6 -72
  37. morecompute/models/api_models.py +62 -0
  38. morecompute/notebook.py +11 -0
  39. morecompute/server.py +641 -133
  40. morecompute/services/claude_service.py +392 -0
  41. morecompute/services/pod_manager.py +168 -67
  42. morecompute/services/pod_monitor.py +67 -39
  43. morecompute/services/prime_intellect.py +0 -4
  44. morecompute/services/providers/__init__.py +92 -0
  45. morecompute/services/providers/base_provider.py +336 -0
  46. morecompute/services/providers/lambda_labs_provider.py +394 -0
  47. morecompute/services/providers/provider_factory.py +194 -0
  48. morecompute/services/providers/runpod_provider.py +504 -0
  49. morecompute/services/providers/vastai_provider.py +407 -0
  50. morecompute/utils/cell_magics.py +0 -3
  51. morecompute/utils/config_util.py +93 -3
  52. morecompute/utils/special_commands.py +5 -32
  53. morecompute/utils/version_check.py +117 -0
  54. frontend/styling_README.md +0 -23
  55. {more_compute-0.4.4.dist-info/licenses → more_compute-0.5.0.dist-info}/LICENSE +0 -0
  56. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/entry_points.txt +0 -0
  57. {more_compute-0.4.4.dist-info → more_compute-0.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,407 @@
1
+ """Vast.ai GPU cloud provider implementation."""
2
+
3
+ import json
4
+ from typing import Any
5
+ from datetime import datetime, timezone
6
+
7
+ from .base_provider import BaseGPUProvider, NormalizedPod
8
+ from .provider_factory import register_provider
9
+ from ...models.api_models import PodResponse
10
+
11
+
12
@register_provider
class VastAIProvider(BaseGPUProvider):
    """Vast.ai GPU cloud provider using REST API.

    Vast.ai provides community GPUs at competitive prices.
    """

    PROVIDER_NAME = "vastai"
    PROVIDER_DISPLAY_NAME = "Vast.ai"
    API_KEY_ENV_NAME = "VASTAI_API_KEY"
    SUPPORTS_SSH = True
    DASHBOARD_URL = "https://cloud.vast.ai/"

    BASE_URL = "https://console.vast.ai/api/v0"

    def __init__(self, api_key: str | None = None):
        super().__init__(api_key)

    def _get_auth_headers(self) -> dict[str, str]:
        """Return common request headers.

        Note: these carry no credentials — Vast.ai authenticates via an
        ``api_key`` query parameter, added in ``_make_vast_request``.
        """
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def _make_vast_request(
        self,
        method: str,
        endpoint: str,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Make an authenticated request to Vast.ai API.

        Vast.ai uses api_key as a query parameter.

        Args:
            method: HTTP verb ("GET", "PUT", "DELETE", ...).
            endpoint: API path appended to BASE_URL (e.g. "/instances").
            params: Optional query parameters; the API key is merged into
                a copy, so the caller's dict is never mutated.
            json_data: Optional JSON request body.

        Returns:
            Parsed JSON response, or {} for empty / 204 responses.

        Raises:
            HTTPException: with the upstream status on HTTP errors, or
                503 on connection failures.
        """
        import httpx
        from fastapi import HTTPException

        url = f"{self.BASE_URL}{endpoint}"

        # Copy before injecting the API key so we don't mutate the
        # caller-supplied params dict.
        params = dict(params or {})
        params["api_key"] = self.api_key

        async with httpx.AsyncClient(follow_redirects=True) as client:
            try:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self._get_auth_headers(),
                    params=params,
                    json=json_data,
                    timeout=30.0
                )
                response.raise_for_status()

                # Some endpoints respond with no body (e.g. 204 No Content).
                if response.status_code == 204 or not response.content:
                    return {}

                return response.json()
            except httpx.HTTPStatusError as e:
                raise HTTPException(
                    status_code=e.response.status_code,
                    detail=f"Vast.ai API error: {e.response.text}"
                )
            except httpx.RequestError as e:
                raise HTTPException(
                    status_code=503,
                    detail=f"Vast.ai connection error: {str(e)}"
                )

    async def get_gpu_availability(
        self,
        regions: list[str] | None = None,
        gpu_count: int | None = None,
        gpu_type: str | None = None,
        verified: bool | None = None,
        min_reliability: float | None = None,
        min_gpu_ram: float | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """Get available GPU offers from Vast.ai marketplace.

        Vast.ai has a marketplace model where users list their GPUs.

        Args:
            regions: Filter by region/geolocation
            gpu_count: Minimum number of GPUs
            gpu_type: Filter by GPU name (case-insensitive substring match,
                applied client-side because the API only supports exact match)
            verified: If True or unspecified, only show verified hosts
            min_reliability: Minimum reliability score (0.0-1.0)
            min_gpu_ram: Minimum GPU RAM in GB

        Returns:
            Dict with "data" (list of normalized offer dicts),
            "total_count" and "provider".
        """
        # Build the marketplace search query; cheapest offers first.
        query: dict[str, Any] = {
            "rentable": {"eq": True},
            "rented": {"eq": False},
            "order": [["dph_total", "asc"]],  # Sort by price
            "type": "on-demand"
        }

        # Default to verified hosts unless the caller explicitly opts out.
        if verified is not False:
            query["verified"] = {"eq": True}

        if gpu_count:
            query["num_gpus"] = {"gte": gpu_count}

        if min_reliability is not None:
            query["reliability2"] = {"gte": min_reliability}

        if min_gpu_ram is not None:
            # Vast.ai expresses GPU RAM in MB.
            query["gpu_ram"] = {"gte": min_gpu_ram * 1024}

        response = await self._make_vast_request(
            "GET",
            "/bundles",
            params={"q": json.dumps(query)}
        )

        wanted_gpu = gpu_type.lower() if gpu_type else None

        # Transform offers to the standardized format, applying the
        # client-side filters the query language cannot express.
        gpus = []
        for offer in response.get("offers", []):
            # "geolocation" may be missing or None; guard before split().
            geolocation = offer.get("geolocation") or ""
            region = geolocation.split(",")[0] if geolocation else None

            # Region filter matches the first component of "City, Country".
            if regions and (region or "") not in regions:
                continue

            # Client-side partial match on GPU name.
            if wanted_gpu and wanted_gpu not in offer.get("gpu_name", "").lower():
                continue

            gpus.append({
                "gpuType": offer.get("gpu_name", ""),
                "gpuName": offer.get("gpu_name", ""),
                "gpuCount": offer.get("num_gpus", 1),
                "priceHr": offer.get("dph_total", 0),
                "cloudId": str(offer.get("id")),
                "socket": str(offer.get("id")),
                "region": region,
                "geolocation": offer.get("geolocation"),
                "reliabilityScore": offer.get("reliability2", offer.get("reliability", 0)),
                "dlPerf": offer.get("dlperf", 0),
                "memoryGb": offer.get("gpu_ram", 0) / 1024,  # Convert MB to GB
                "storageGb": offer.get("disk_space", 0),
                "cpuCores": offer.get("cpu_cores_effective"),
                "cpuRam": offer.get("cpu_ram", 0) / 1024,  # Convert MB to GB
                "verified": offer.get("verified", False),
                "provider": self.PROVIDER_NAME
            })

        return {
            "data": gpus,
            "total_count": len(gpus),
            "provider": self.PROVIDER_NAME
        }

    async def create_pod(self, request: Any) -> PodResponse:
        """Create a new Vast.ai instance.

        Args:
            request: CreatePodRequest with pod configuration, or a plain
                dict using the same camelCase field names.

        Returns:
            PodResponse with created instance info

        Raises:
            HTTPException: 402 on insufficient balance, 400/500 on
                API-reported creation failures.
        """
        from fastapi import HTTPException

        pod_config = request.pod if hasattr(request, 'pod') else request

        def _field(attr: str, default: Any = None) -> Any:
            # Request models expose attributes; raw dicts use .get().
            if hasattr(pod_config, attr):
                return getattr(pod_config, attr)
            return pod_config.get(attr, default)

        offer_id = _field("cloudId")
        image = _field("image", "nvidia/cuda:12.1.0-devel-ubuntu22.04")
        disk_size = _field("diskSize", 20)
        name = _field("name", "morecompute-instance")

        # Create the instance - Vast.ai API format
        payload: dict[str, Any] = {
            "image": image,
            "disk": float(disk_size),
            "label": name,
            "runtype": "ssh",
        }

        # Optional environment variables: accept a plain dict, a list of
        # EnvVar-like objects (.key/.value), or a pre-built structure.
        # (Indexing a dict with [0] would raise KeyError, so check dict first.)
        env_vars = _field("envVars")
        if env_vars:
            if isinstance(env_vars, dict):
                payload["env"] = env_vars
            elif hasattr(env_vars[0], 'key'):
                payload["env"] = {e.key: e.value for e in env_vars}
            else:
                payload["env"] = env_vars

        try:
            response = await self._make_vast_request(
                "PUT",
                f"/asks/{offer_id}/",
                json_data=payload
            )
        except HTTPException as e:
            # Surface insufficient-balance failures with a clear message.
            error_detail = str(e.detail) if hasattr(e, 'detail') else str(e)

            if "402" in error_detail or "insufficient" in error_detail.lower() or "balance" in error_detail.lower():
                raise HTTPException(
                    status_code=402,
                    detail="Insufficient funds in your Vast.ai account. Please add credits at https://cloud.vast.ai/"
                )
            raise

        instance_id = response.get("new_contract")
        if not instance_id:
            # Check if response indicates an error
            if response.get("success") is False:
                error_msg = response.get("error", response.get("msg", "Unknown error"))
                raise HTTPException(status_code=400, detail=f"Vast.ai error: {error_msg}")
            raise HTTPException(status_code=500, detail="Failed to create Vast.ai instance - no contract ID returned")

        # Get instance details
        return await self.get_pod(str(instance_id))

    async def get_pods(
        self,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0
    ) -> dict[str, Any]:
        """Get list of all Vast.ai instances.

        Args:
            status: Optional normalized-status filter (e.g. "ACTIVE"),
                compared case-insensitively.
            limit: Maximum number of instances to return.
            offset: Number of instances to skip.

        Returns:
            Dict with "data", "total_count", "offset", "limit", "provider".
        """
        # "owner=me" scopes the listing to the caller's own instances,
        # matching the query used by get_pod.
        response = await self._make_vast_request(
            "GET",
            "/instances",
            params={"owner": "me"}
        )

        instances = response.get("instances", [])

        # Filter by status if specified
        if status:
            status_lower = status.lower()
            instances = [
                i for i in instances
                if self._normalize_status(i.get("actual_status", "")).lower() == status_lower
            ]

        # Apply pagination
        instances = instances[offset:offset + limit]

        # Transform to standardized format
        pods = []
        for instance in instances:
            ssh_connection = self._build_ssh_connection(instance)
            geolocation = instance.get("geolocation") or ""

            pods.append({
                "id": str(instance.get("id")),
                "name": instance.get("label", f"vast-{instance.get('id')}"),
                "status": self._normalize_status(instance.get("actual_status", "loading")),
                "gpuName": instance.get("gpu_name", ""),
                "gpuCount": instance.get("num_gpus", 1),
                "priceHr": instance.get("dph_total", 0),
                "sshConnection": ssh_connection,
                "ip": instance.get("public_ipaddr"),
                "region": geolocation.split(",")[0] if geolocation else None,
                "createdAt": instance.get("start_date", datetime.now(timezone.utc).isoformat()),
                "updatedAt": datetime.now(timezone.utc).isoformat(),
                "provider": self.PROVIDER_NAME
            })

        return {
            "data": pods,
            "total_count": len(pods),
            "offset": offset,
            "limit": limit,
            "provider": self.PROVIDER_NAME
        }

    async def get_pod(self, pod_id: str) -> PodResponse:
        """Get details for a specific Vast.ai instance.

        Vast.ai has no single-instance GET in this API version, so we list
        the caller's instances and pick the match.

        Raises:
            HTTPException: 404 if no instance with pod_id exists.
        """
        from fastapi import HTTPException

        response = await self._make_vast_request(
            "GET",
            "/instances",
            params={"owner": "me"}
        )

        instances = response.get("instances", [])
        instance = next((i for i in instances if str(i.get("id")) == pod_id), None)

        if not instance:
            raise HTTPException(status_code=404, detail=f"Instance {pod_id} not found")

        ssh_connection = self._build_ssh_connection(instance)

        now = datetime.now(timezone.utc)
        return PodResponse(
            id=str(instance.get("id", "")),
            userId="",
            teamId=None,
            name=instance.get("label", f"vast-{instance.get('id')}"),
            status=self._normalize_status(instance.get("actual_status", "loading")),
            gpuName=instance.get("gpu_name", ""),
            gpuCount=instance.get("num_gpus", 1),
            priceHr=instance.get("dph_total", 0),
            sshConnection=ssh_connection,
            ip=instance.get("public_ipaddr"),
            createdAt=now,
            updatedAt=now
        )

    def _build_ssh_connection(self, instance: dict[str, Any]) -> str | None:
        """Build SSH connection string from Vast.ai instance data.

        Returns None when the instance has no reachable address yet.
        """
        ip = instance.get("public_ipaddr") or instance.get("ssh_host")
        port = instance.get("ssh_port", 22)

        if not ip:
            return None

        return f"ssh root@{ip} -p {port}"

    async def delete_pod(self, pod_id: str) -> dict[str, Any]:
        """Destroy a Vast.ai instance."""
        response = await self._make_vast_request(
            "DELETE",
            f"/instances/{pod_id}/"
        )

        return {
            "success": response.get("success", True),
            "pod_id": pod_id,
            "provider": self.PROVIDER_NAME
        }

    async def stop_pod(self, pod_id: str) -> dict[str, Any]:
        """Stop a Vast.ai instance (without destroying)."""
        await self._make_vast_request(
            "PUT",
            f"/instances/{pod_id}/",
            json_data={"state": "stopped"}
        )

        return {
            "success": True,
            "pod_id": pod_id,
            "action": "stopped"
        }

    async def start_pod(self, pod_id: str) -> dict[str, Any]:
        """Start a stopped Vast.ai instance."""
        await self._make_vast_request(
            "PUT",
            f"/instances/{pod_id}/",
            json_data={"state": "running"}
        )

        return {
            "success": True,
            "pod_id": pod_id,
            "action": "started"
        }

    def _normalize_status(self, vast_status: str) -> str:
        """Convert Vast.ai status to normalized status.

        Unknown statuses pass through uppercased.
        """
        status_map = {
            "running": "ACTIVE",
            "loading": "STARTING",
            "created": "PENDING",
            "exited": "STOPPED",
            "offline": "STOPPED",
            "error": "ERROR",
            "destroying": "TERMINATING"
        }
        return status_map.get(vast_status.lower(), vast_status.upper())

    def normalize_pod(self, pod_data: dict[str, Any]) -> NormalizedPod:
        """Convert Vast.ai instance data to normalized format."""
        ssh_connection = self._build_ssh_connection(pod_data)

        return NormalizedPod(
            id=str(pod_data.get("id", "")),
            name=pod_data.get("label", f"vast-{pod_data.get('id')}"),
            status=self._normalize_status(pod_data.get("actual_status", "loading")),
            gpu_name=pod_data.get("gpu_name", ""),
            gpu_count=pod_data.get("num_gpus", 1),
            price_hr=pod_data.get("dph_total", 0),
            ssh_connection=ssh_connection,
            ip=pod_data.get("public_ipaddr"),
            provider=self.PROVIDER_NAME,
            created_at=pod_data.get("start_date", datetime.now(timezone.utc).isoformat()),
            updated_at=datetime.now(timezone.utc).isoformat()
        )
@@ -251,7 +251,6 @@ class CellMagicHandlers:
251
251
  # Track process for interrupt handling
252
252
  if hasattr(cell_magic_handler, 'special_handler'):
253
253
  cell_magic_handler.special_handler.current_process_sync = process
254
- print(f"[CELL_MAGIC] Tracking sync subprocess PID={process.pid}", file=sys.stderr, flush=True)
255
254
 
256
255
  # Read and print output line by line (real-time streaming)
257
256
  def read_stream(stream, output_type):
@@ -292,7 +291,6 @@ class CellMagicHandlers:
292
291
  if hasattr(cell_magic_handler, 'special_handler'):
293
292
  if cell_magic_handler.special_handler.sync_interrupted:
294
293
  # Process was killed by interrupt handler
295
- print(f"[CELL_MAGIC] Process was interrupted, raising KeyboardInterrupt", file=sys.stderr, flush=True)
296
294
  raise KeyboardInterrupt("Execution interrupted by user")
297
295
 
298
296
  return_code = process.returncode
@@ -312,7 +310,6 @@ class CellMagicHandlers:
312
310
  # Clear process reference
313
311
  if hasattr(cell_magic_handler, 'special_handler'):
314
312
  cell_magic_handler.special_handler.current_process_sync = None
315
- print(f"[CELL_MAGIC] Cleared sync subprocess reference", file=sys.stderr, flush=True)
316
313
 
317
314
  return return_code
318
315
 
@@ -12,8 +12,8 @@ CONFIG_FILE = CONFIG_DIR / "config.json"
12
12
 
13
13
 
14
14
  def _ensure_config_dir() -> None:
15
- """Ensure the config directory exists."""
16
- CONFIG_DIR.mkdir(parents=True, exist_ok=True)
15
+ """Ensure the config directory exists with secure permissions."""
16
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True, mode=0o700)
17
17
 
18
18
 
19
19
  def _load_config() -> dict:
@@ -28,10 +28,12 @@ def _load_config() -> dict:
28
28
 
29
29
 
30
30
  def _save_config(config: dict) -> None:
31
- """Save config to JSON file."""
31
+ """Save config to JSON file with secure permissions."""
32
32
  _ensure_config_dir()
33
33
  with CONFIG_FILE.open("w", encoding="utf-8") as f:
34
34
  json.dump(config, f, indent=2)
35
+ # Set secure file permissions (owner read/write only)
36
+ CONFIG_FILE.chmod(0o600)
35
37
 
36
38
 
37
39
  def load_api_key(key_name: str) -> Optional[str]:
@@ -73,3 +75,91 @@ def save_api_key(key_name: str, api_key: str) -> None:
73
75
  config = _load_config()
74
76
  config[key_name] = api_key
75
77
  _save_config(config)
78
+
79
+
80
def delete_api_key(key_name: str) -> bool:
    """
    Delete an API key from config.

    Args:
        key_name: Key name to delete

    Returns:
        True if key was deleted, False if it didn't exist
    """
    config = _load_config()
    # Guard clause: nothing to do when the key is absent.
    if key_name not in config:
        return False
    del config[key_name]
    _save_config(config)
    return True
96
+
97
+
98
def get_active_provider() -> Optional[str]:
    """
    Get the currently active provider name.

    Returns:
        Provider name or None if not set
    """
    return _load_config().get("active_provider")
107
+
108
+
109
def set_active_provider(provider_name: str) -> None:
    """
    Set the active provider.

    Args:
        provider_name: The provider to make active
    """
    cfg = _load_config()
    cfg["active_provider"] = provider_name
    _save_config(cfg)
119
+
120
+
121
def get_all_configured_keys() -> dict[str, bool]:
    """
    Get a mapping of all API key names to whether they are configured.

    A key counts as configured when it is set either in the process
    environment or in the saved config file.

    Returns:
        Dict mapping key names to True/False
    """
    config = _load_config()

    # Known provider API key names (SSH-based providers only)
    key_names = (
        "RUNPOD_API_KEY",
        "LAMBDA_LABS_API_KEY",
        "VASTAI_API_KEY",
    )

    # Environment takes precedence; either source marks the key configured.
    return {
        name: bool(os.getenv(name) or config.get(name))
        for name in key_names
    }
145
+
146
+
147
def get_provider_api_keys(provider_name: str) -> dict[str, Optional[str]]:
    """
    Get all API keys needed for a specific provider.

    Args:
        provider_name: Provider name (e.g., "runpod", "vastai")

    Returns:
        Dict mapping key names to their values (or None if not set);
        empty dict for an unknown provider name.
    """
    # Provider to key name mappings (SSH-based providers only)
    provider_keys = {
        "runpod": ["RUNPOD_API_KEY"],
        "lambda_labs": ["LAMBDA_LABS_API_KEY"],
        "vastai": ["VASTAI_API_KEY"],
    }

    key_names = provider_keys.get(provider_name, [])
    return {key: load_api_key(key) for key in key_names}
@@ -1,13 +1,9 @@
1
1
  import os
2
- import io
3
- import sys
4
2
  import asyncio
5
3
  import subprocess
6
4
  import time
7
5
  import shlex
8
- import platform
9
- from contextlib import redirect_stdout, redirect_stderr
10
- from typing import Dict, Any, Optional, Tuple, Union
6
+ from typing import Dict, Any, Optional, Union
11
7
  from fastapi import WebSocket
12
8
 
13
9
  from .cell_magics import CellMagicHandlers
@@ -124,7 +120,6 @@ class AsyncSpecialCommandHandler:
124
120
 
125
121
  # Track process for interrupt handling
126
122
  self.current_process = process
127
- print(f"[SPECIAL_CMD] Started subprocess PID={process.pid}", file=sys.stderr, flush=True)
128
123
 
129
124
  try:
130
125
  # Stream output concurrently
@@ -138,22 +133,16 @@ class AsyncSpecialCommandHandler:
138
133
  # Track tasks for interruption
139
134
  self.stream_tasks = [stdout_task, stderr_task]
140
135
 
141
- print(f"[SPECIAL_CMD] Waiting for stream tasks to complete...", file=sys.stderr, flush=True)
142
136
  # Wait for both streams to complete
143
137
  await asyncio.gather(stdout_task, stderr_task, return_exceptions=True)
144
138
 
145
- print(f"[SPECIAL_CMD] Streams complete, waiting for process to exit...", file=sys.stderr, flush=True)
146
139
  # Wait for process completion
147
140
  return_code = await process.wait()
148
- print(f"[SPECIAL_CMD] Process exited with code {return_code}", file=sys.stderr, flush=True)
149
141
  except asyncio.CancelledError:
150
142
  # Task was cancelled - treat as interrupt
151
- print(f"[SPECIAL_CMD] Task cancelled, treating as interrupt", file=sys.stderr, flush=True)
152
143
  return_code = -15 # SIGTERM
153
144
  except Exception as e:
154
- print(f"[SPECIAL_CMD] Exception during execution: {e}", file=sys.stderr, flush=True)
155
145
  import traceback
156
- traceback.print_exc()
157
146
  # Set error result
158
147
  result["status"] = "error"
159
148
  result["error"] = {
@@ -167,11 +156,9 @@ class AsyncSpecialCommandHandler:
167
156
  # Clear process reference when done
168
157
  self.current_process = None
169
158
  self.stream_tasks = []
170
- print(f"[SPECIAL_CMD] Cleared process reference", file=sys.stderr, flush=True)
171
159
 
172
160
  # Check if process was interrupted (negative return code means killed by signal)
173
161
  if return_code < 0:
174
- print(f"[SPECIAL_CMD] Process was interrupted (return_code={return_code}), setting KeyboardInterrupt error", file=sys.stderr, flush=True)
175
162
  result["status"] = "error"
176
163
  result["error"] = {
177
164
  "output_type": "error",
@@ -189,8 +176,6 @@ class AsyncSpecialCommandHandler:
189
176
  "traceback": [f"Shell command '{command}' failed"]
190
177
  }
191
178
 
192
- print(f"[SPECIAL_CMD] Returning result: status={result['status']}, return_code={return_code}", file=sys.stderr, flush=True)
193
-
194
179
  # If pip install/uninstall occurred, notify clients to refresh packages
195
180
  try:
196
181
  if websocket and return_code == 0 and (command.startswith('pip install') or command.startswith('pip uninstall') or 'pip install' in command or 'pip uninstall' in command):
@@ -228,52 +213,40 @@ class AsyncSpecialCommandHandler:
228
213
  # Cancel stream tasks first
229
214
  for task in self.stream_tasks:
230
215
  if not task.done():
231
- print(f"[SPECIAL_CMD] Cancelling stream task", file=sys.stderr, flush=True)
232
216
  task.cancel()
233
217
 
234
218
  # Interrupt async subprocess
235
219
  if self.current_process:
236
220
  try:
237
- print(f"[SPECIAL_CMD] Interrupting async subprocess PID={self.current_process.pid}", file=sys.stderr, flush=True)
238
221
  self.current_process.terminate()
239
222
 
240
223
  # Give it a moment to terminate gracefully
241
224
  try:
242
225
  await asyncio.wait_for(self.current_process.wait(), timeout=1.0)
243
- print(f"[SPECIAL_CMD] Async subprocess terminated gracefully", file=sys.stderr, flush=True)
244
226
  except asyncio.TimeoutError:
245
227
  # Force kill if it doesn't terminate
246
- print(f"[SPECIAL_CMD] Async subprocess didn't terminate, force killing", file=sys.stderr, flush=True)
247
228
  self.current_process.kill()
248
229
  await self.current_process.wait()
249
- print(f"[SPECIAL_CMD] Async subprocess killed", file=sys.stderr, flush=True)
250
230
 
251
- except Exception as e:
252
- print(f"[SPECIAL_CMD] Error interrupting async subprocess: {e}", file=sys.stderr, flush=True)
231
+ except Exception:
232
+ pass
253
233
 
254
234
  # Interrupt sync subprocess
255
235
  if self.current_process_sync:
256
236
  try:
257
- print(f"[SPECIAL_CMD] Interrupting sync subprocess PID={self.current_process_sync.pid}", file=sys.stderr, flush=True)
258
237
  self.sync_interrupted = True # Set flag so shell commands know to stop
259
238
  self.current_process_sync.terminate()
260
239
 
261
240
  # Give it a moment to terminate gracefully
262
241
  try:
263
242
  self.current_process_sync.wait(timeout=1.0)
264
- print(f"[SPECIAL_CMD] Sync subprocess terminated gracefully", file=sys.stderr, flush=True)
265
243
  except subprocess.TimeoutExpired:
266
244
  # Force kill if it doesn't terminate
267
- print(f"[SPECIAL_CMD] Sync subprocess didn't terminate, force killing", file=sys.stderr, flush=True)
268
245
  self.current_process_sync.kill()
269
246
  self.current_process_sync.wait()
270
- print(f"[SPECIAL_CMD] Sync subprocess killed", file=sys.stderr, flush=True)
271
247
 
272
- except Exception as e:
273
- print(f"[SPECIAL_CMD] Error interrupting sync subprocess: {e}", file=sys.stderr, flush=True)
274
-
275
- if not self.current_process and not self.current_process_sync:
276
- print(f"[SPECIAL_CMD] No subprocess to interrupt", file=sys.stderr, flush=True)
248
+ except Exception:
249
+ pass
277
250
 
278
251
  async def _stream_output(self, stream, stream_type: str, result: Dict[str, Any],
279
252
  websocket: Optional[WebSocket] = None,