more-compute 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kernel_run.py +283 -0
- more_compute-0.1.0.dist-info/METADATA +163 -0
- more_compute-0.1.0.dist-info/RECORD +26 -0
- more_compute-0.1.0.dist-info/WHEEL +5 -0
- more_compute-0.1.0.dist-info/entry_points.txt +2 -0
- more_compute-0.1.0.dist-info/licenses/LICENSE +21 -0
- more_compute-0.1.0.dist-info/top_level.txt +2 -0
- morecompute/__init__.py +6 -0
- morecompute/cli.py +31 -0
- morecompute/execution/__init__.py +5 -0
- morecompute/execution/__main__.py +10 -0
- morecompute/execution/executor.py +381 -0
- morecompute/execution/worker.py +244 -0
- morecompute/notebook.py +81 -0
- morecompute/process_worker.py +209 -0
- morecompute/server.py +641 -0
- morecompute/services/pod_manager.py +503 -0
- morecompute/services/prime_intellect.py +316 -0
- morecompute/static/styles.css +1056 -0
- morecompute/utils/__init__.py +17 -0
- morecompute/utils/cache_util.py +23 -0
- morecompute/utils/error_utils.py +322 -0
- morecompute/utils/notebook_util.py +44 -0
- morecompute/utils/python_environment_util.py +197 -0
- morecompute/utils/special_commands.py +458 -0
- morecompute/utils/system_environment_util.py +134 -0
morecompute/server.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
1
|
+
from cachetools import TTLCache
|
|
2
|
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
|
|
3
|
+
from fastapi.responses import PlainTextResponse
|
|
4
|
+
from fastapi.staticfiles import StaticFiles
|
|
5
|
+
import os
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import importlib.metadata as importlib_metadata
|
|
9
|
+
import zmq
|
|
10
|
+
import textwrap
|
|
11
|
+
|
|
12
|
+
from .notebook import Notebook
|
|
13
|
+
from .execution import NextZmqExecutor
|
|
14
|
+
from .utils.python_environment_util import PythonEnvironmentDetector
|
|
15
|
+
from .utils.system_environment_util import DeviceMetrics
|
|
16
|
+
from .utils.error_utils import ErrorUtils
|
|
17
|
+
from .utils.cache_util import make_cache_key
|
|
18
|
+
from .utils.notebook_util import coerce_cell_source
|
|
19
|
+
from .services.prime_intellect import PrimeIntellectService, CreatePodRequest, PodResponse
|
|
20
|
+
from .services.pod_manager import PodKernelManager
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Root directory that all file-system endpoints are sandboxed to
# (see resolve_path); overridable via MORECOMPUTE_ROOT.
BASE_DIR = Path(os.getenv("MORECOMPUTE_ROOT", Path.cwd())).resolve()
# Directory of this installed package (location of server.py itself).
PACKAGE_DIR = Path(__file__).resolve().parent
# Static assets (icons, images); defaults to <root>/assets, overridable
# via MORECOMPUTE_ASSETS_DIR. Mounted below only if it exists.
ASSETS_DIR = Path(os.getenv("MORECOMPUTE_ASSETS_DIR", BASE_DIR / "assets")).resolve()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def resolve_path(requested_path: str) -> Path:
    """Resolve *requested_path* against BASE_DIR, rejecting escapes.

    An empty path means the root itself. Raises HTTP 400 when the
    resolved path would fall outside the notebook root.
    """
    candidate = (BASE_DIR / (requested_path or ".")).resolve()
    try:
        candidate.relative_to(BASE_DIR)
    except ValueError:
        raise HTTPException(status_code=400, detail="Path outside notebook root")
    return candidate
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
app = FastAPI()

# TTL caches for expensive upstream calls; entries expire automatically.
# (Keyword spacing normalized to PEP 8: no spaces around '=' in kwargs.)
gpu_cache = TTLCache(maxsize=50, ttl=60)           # GPU availability (1 minute)
pod_cache = TTLCache(maxsize=100, ttl=300)         # pod listings/details (5 minutes)
packages_cache = TTLCache(maxsize=1, ttl=300)      # 5 minutes cache for packages
environments_cache = TTLCache(maxsize=1, ttl=300)  # 5 minutes cache for environments

# Mount assets directory for icons, images, etc.
if ASSETS_DIR.exists():
    app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")

# Global instances for the application state.
notebook_path_env = os.getenv("MORECOMPUTE_NOTEBOOK_PATH")
if notebook_path_env:
    notebook = Notebook(file_path=notebook_path_env)
else:
    notebook = Notebook()
error_utils = ErrorUtils()
executor = NextZmqExecutor(error_utils=error_utils)
metrics = DeviceMetrics()
|
|
57
|
+
|
|
58
|
+
# Initialize Prime Intellect service if API key is provided.
# Check environment variable first, then .env file (commonly gitignored).
prime_api_key = os.getenv("PRIME_INTELLECT_API_KEY")
if not prime_api_key:
    env_path = BASE_DIR / ".env"
    if env_path.exists():
        try:
            with env_path.open("r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("PRIME_INTELLECT_API_KEY="):
                        # Take everything after the first '=', dropping
                        # optional surrounding single or double quotes.
                        prime_api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
                        break
        except Exception:
            # Best-effort: an unreadable .env simply means no key is found.
            pass

# None when unconfigured; every GPU endpoint checks this and returns 503.
prime_intellect = PrimeIntellectService(api_key=prime_api_key) if prime_api_key else None
# Created lazily on the first pod connection (see connect_to_pod).
pod_manager: PodKernelManager | None = None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@app.get("/api/packages")
async def list_installed_packages(force_refresh: bool = False):
    """
    Return installed packages for the current Python runtime.

    Args:
        force_refresh: If True, bypass cache and fetch fresh data

    Returns:
        {"packages": [{"name": ..., "version": ...}, ...]} sorted
        case-insensitively by name.
    """
    cache_key = "packages_list"

    # Check cache first unless force refresh is requested
    if not force_refresh and cache_key in packages_cache:
        return packages_cache[cache_key]

    try:
        packages = []
        for dist in importlib_metadata.distributions():
            # Metadata lookups are case-insensitive, so "Name" suffices.
            # (Bug fix: the previous "Summary" fallback returned the package
            # *description* in place of its name.)
            name = dist.metadata.get("Name")
            version = dist.version
            if name and version:
                packages.append({"name": str(name), "version": str(version)})
        packages.sort(key=lambda p: p["name"].lower())

        result = {"packages": packages}
        packages_cache[cache_key] = result
        return result
    except Exception as exc:
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(status_code=500, detail=f"Failed to list packages: {exc}") from exc
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@app.get("/api/metrics")
async def get_metrics():
    """Return metrics for every device reported by DeviceMetrics."""
    try:
        return metrics.get_all_devices()
    except Exception as exc:
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(status_code=500, detail=f"Failed to get metrics: {exc}") from exc
|
|
113
|
+
|
|
114
|
+
@app.get("/api/environments")
async def get_environments(full: bool = True, force_refresh: bool = False):
    """
    Return available Python environments.

    Args:
        full: If True (default), performs comprehensive scan (conda, system, venv).
              Takes a few seconds but finds all environments.
        force_refresh: If True, bypass cache and fetch fresh data
    """
    cache_key = f"environments_{full}"

    # Check cache first unless force refresh is requested
    if not force_refresh and cache_key in environments_cache:
        return environments_cache[cache_key]

    try:
        detector = PythonEnvironmentDetector()
        environments = detector.detect_all_environments()
        current_env = detector.get_current_environment()

        result = {
            "status": "success",
            "environments": environments,
            "current": current_env,
        }

        environments_cache[cache_key] = result  # Cache the result
        return result

    except Exception as exc:
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(status_code=500, detail=f"Failed to detect environments: {exc}") from exc
|
|
145
|
+
|
|
146
|
+
@app.get("/api/files")
async def list_files(path: str = "."):
    """List the entries of a directory inside the notebook root.

    Directories sort before files, each group case-insensitively by name.
    Raises 404 for missing/non-directory paths and 403 on permission errors.
    """
    directory = resolve_path(path)
    if not directory.exists() or not directory.is_dir():
        raise HTTPException(status_code=404, detail="Directory not found")

    items: list[dict[str, str | int]] = []
    try:
        for entry in sorted(directory.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
            try:
                stat = entry.stat()
            except FileNotFoundError:
                # Entry vanished between listing and stat (or is a broken
                # symlink); skip it instead of failing the whole request.
                continue
            item_path = entry.relative_to(BASE_DIR)
            items.append({
                "name": entry.name,
                "path": str(item_path).replace("\\", "/"),
                "type": "directory" if entry.is_dir() else "file",
                "size": stat.st_size,
                "modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
            })
    except PermissionError as exc:
        raise HTTPException(status_code=403, detail=f"Permission denied: {exc}") from exc

    return {
        "root": str(BASE_DIR),
        "path": str(directory.relative_to(BASE_DIR)) if directory != BASE_DIR else ".",
        "items": items,
    }
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@app.post("/api/fix-indentation")
async def fix_indentation(request: Request):
    """Fix indentation in Python code using textwrap.dedent()."""
    try:
        body = await request.json()
        code = body.get("code", "")
        # dedent strips the longest common leading whitespace from all lines.
        fixed_code = textwrap.dedent(code)
        return {"fixed_code": fixed_code}
    except Exception as exc:
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(status_code=500, detail=f"Failed to fix indentation: {exc}") from exc
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@app.get("/api/file")
async def read_file(path: str, max_bytes: int = 256_000):
    """Return up to *max_bytes* of a file as UTF-8 text (lossy decode).

    Oversized files are truncated and a trailing marker is appended.
    Raises 404 for missing/non-file paths and 403 on permission errors.
    """
    file_path = resolve_path(path)
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    try:
        # Read one byte past the limit so truncation is detectable cheaply.
        with file_path.open("rb") as f:
            content = f.read(max_bytes + 1)
    except PermissionError as exc:
        raise HTTPException(status_code=403, detail=f"Permission denied: {exc}") from exc

    truncated = len(content) > max_bytes
    if truncated:
        content = content[:max_bytes]

    text = content.decode("utf-8", errors="replace")
    if truncated:
        text += "\n\n… (truncated)"

    return PlainTextResponse(text)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class WebSocketManager:
    """Manages WebSocket connections and message handling.

    Holds references to the module-level executor and notebook and routes
    incoming JSON messages to per-type handler coroutines.
    """

    def __init__(self) -> None:
        # Connected clients; the dict's keys act as an ordered set.
        self.clients: dict[WebSocket, None] = {}
        self.executor = executor
        self.notebook = notebook

    async def connect(self, websocket: WebSocket):
        """Accept a client and send it the current notebook state."""
        await websocket.accept()
        self.clients[websocket] = None
        # Send the initial notebook state to the new client
        await websocket.send_json({
            "type": "notebook_data",
            "data": self.notebook.get_notebook_data()
        })

    def disconnect(self, websocket: WebSocket):
        """Unregister a client; safe to call even if already removed."""
        # Bug fix: pop() instead of del so a double disconnect (e.g. the
        # loop and an error path both cleaning up) cannot raise KeyError.
        self.clients.pop(websocket, None)

    async def broadcast_notebook_update(self):
        """Send the entire notebook state to all connected clients."""
        updated_data = self.notebook.get_notebook_data()
        # Iterate a snapshot: a disconnect during an await must not mutate
        # the dict we are iterating over.
        for client in list(self.clients):
            await client.send_json({
                "type": "notebook_updated",
                "data": updated_data
            })

    async def handle_message_loop(self, websocket: WebSocket):
        """Main loop to handle incoming WebSocket messages."""
        while True:
            try:
                message = await websocket.receive_json()
                await self._handle_message(websocket, message)
            except WebSocketDisconnect:
                self.disconnect(websocket)
                break
            except Exception as e:
                # Surface unexpected handler errors to the client but keep
                # the connection alive.
                await self._send_error(websocket, f"Unhandled error: {e}")

    async def _handle_message(self, websocket: WebSocket, message: dict):
        """Dispatch one message to the handler registered for its type."""
        message_type = message.get("type")
        data = message.get("data", {})

        handlers = {
            "execute_cell": self._handle_execute_cell,
            "add_cell": self._handle_add_cell,
            "delete_cell": self._handle_delete_cell,
            "update_cell": self._handle_update_cell,
            "interrupt_kernel": self._handle_interrupt_kernel,
            "reset_kernel": self._handle_reset_kernel,
            "load_notebook": self._handle_load_notebook,
            "save_notebook": self._handle_save_notebook,
        }

        handler = handlers.get(message_type)
        if handler:
            await handler(websocket, data)
        else:
            await self._send_error(websocket, f"Unknown message type: {message_type}")

    async def _handle_execute_cell(self, websocket: WebSocket, data: dict):
        """Run one cell, streaming start/error/complete events back."""
        import sys
        cell_index = data.get("cell_index")
        if cell_index is None or not (0 <= cell_index < len(self.notebook.cells)):
            await self._send_error(websocket, "Invalid cell index.")
            return

        source = coerce_cell_source(self.notebook.cells[cell_index].get('source', ''))

        await websocket.send_json({
            "type": "execution_start",
            "data": {"cell_index": cell_index, "execution_count": getattr(self.executor, 'execution_count', 0) + 1}
        })

        try:
            result = await self.executor.execute_cell(cell_index, source, websocket)
        except Exception as e:
            error_msg = str(e)
            print(f"[SERVER ERROR] execute_cell failed: {error_msg}", file=sys.stderr, flush=True)

            # Send error to frontend; also build a synthetic error result so
            # the completion message below carries consistent data.
            result = {
                'status': 'error',
                'execution_count': None,
                'execution_time': '0ms',
                'outputs': [],
                'error': {
                    'output_type': 'error',
                    'ename': type(e).__name__,
                    'evalue': error_msg,
                    'traceback': [f'{type(e).__name__}: {error_msg}', 'Worker failed to start or crashed. Check server logs.']
                }
            }
            await websocket.send_json({
                "type": "execution_error",
                "data": {
                    "cell_index": cell_index,
                    "error": result['error']
                }
            })

        # Persist outputs on the in-memory notebook and signal completion.
        # This deliberately runs on both the success and the failure path.
        self.notebook.cells[cell_index]['outputs'] = result.get('outputs', [])
        self.notebook.cells[cell_index]['execution_count'] = result.get('execution_count')

        await websocket.send_json({
            "type": "execution_complete",
            "data": { "cell_index": cell_index, "result": result }
        })

    async def _handle_add_cell(self, websocket: WebSocket, data: dict):
        """Insert a new cell (default type: code) and broadcast the state."""
        index = data.get('index', len(self.notebook.cells))
        cell_type = data.get('cell_type', 'code')
        self.notebook.add_cell(index=index, cell_type=cell_type)
        await self.broadcast_notebook_update()

    async def _handle_delete_cell(self, websocket: WebSocket, data: dict):
        """Delete a cell by index and broadcast the state."""
        index = data.get('cell_index')
        if index is not None:
            self.notebook.delete_cell(index)
            await self.broadcast_notebook_update()

    async def _handle_update_cell(self, websocket: WebSocket, data: dict):
        """Update a cell's source in memory only (no save, no broadcast)."""
        index = data.get('cell_index')
        source = data.get('source')
        if index is not None and source is not None:
            self.notebook.update_cell(index, source)
            # NOTE: edits are intentionally not persisted here; saving
            # happens only via the explicit save_notebook message.

    async def _handle_load_notebook(self, websocket: WebSocket, data: dict):
        # In a real app, this would load from a file path in `data`
        # For now, it just sends the current state back to the requester
        await websocket.send_json({
            "type": "notebook_data",
            "data": self.notebook.get_notebook_data()
        })

    async def _handle_save_notebook(self, websocket: WebSocket, data: dict):
        """Persist the notebook to disk and confirm to the requester."""
        try:
            self.notebook.save_to_file()
            await websocket.send_json({"type": "notebook_saved", "data": {"file_path": self.notebook.file_path}})
        except Exception as exc:
            await self._send_error(websocket, f"Failed to save notebook: {exc}")

    async def _handle_interrupt_kernel(self, websocket: WebSocket, data: dict):
        """Interrupt the running cell and tell the client it was stopped."""
        import sys
        # `data` is always a dict (see _handle_message), so .get() cannot
        # raise; the previous try/except wrapper here was dead code.
        cell_index = data.get('cell_index')

        print(f"[SERVER] Interrupt request received for cell {cell_index}", file=sys.stderr, flush=True)

        # Perform the interrupt (this may take up to 1 second)
        await self.executor.interrupt_kernel(cell_index=cell_index)

        print(f"[SERVER] Interrupt completed, sending error message", file=sys.stderr, flush=True)

        # Inform all clients that the currently running cell (if any) is interrupted
        try:
            await websocket.send_json({
                "type": "execution_error",
                "data": {
                    "cell_index": cell_index,
                    "error": {
                        "output_type": "error",
                        "ename": "KeyboardInterrupt",
                        "evalue": "Execution interrupted by user",
                        "traceback": ["KeyboardInterrupt: Execution was stopped by user"]
                    }
                }
            })
            await websocket.send_json({
                "type": "execution_complete",
                "data": {
                    "cell_index": cell_index,
                    "result": {
                        "status": "error",
                        "execution_count": None,
                        "execution_time": "interrupted",
                        "outputs": [],
                        "error": {
                            "output_type": "error",
                            "ename": "KeyboardInterrupt",
                            "evalue": "Execution interrupted by user",
                            "traceback": ["KeyboardInterrupt: Execution was stopped by user"]
                        }
                    }
                }
            })
            print(f"[SERVER] Error messages sent for cell {cell_index}", file=sys.stderr, flush=True)
        except Exception as e:
            print(f"[SERVER] Failed to send error messages: {e}", file=sys.stderr, flush=True)

    async def _handle_reset_kernel(self, websocket: WebSocket, data: dict):
        """Restart the kernel, clear all outputs, and broadcast the state."""
        self.executor.reset_kernel()
        self.notebook.clear_all_outputs()
        await self.broadcast_notebook_update()

    async def _send_error(self, websocket: WebSocket, error_message: str):
        """Send a generic error payload to one client."""
        await websocket.send_json({"type": "error", "data": {"error": error_message}})
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# Single module-level manager shared by every WebSocket connection.
manager = WebSocketManager()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    # Register the client (sends initial notebook state), then service
    # messages until the client disconnects.
    await manager.connect(websocket)
    await manager.handle_message_loop(websocket)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# --- GPU connection API ---

@app.get("/api/gpu/config")
async def get_gpu_config():
    """Report whether a Prime Intellect API key is configured."""
    configured = prime_intellect is not None
    return {"configured": configured}
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
@app.post("/api/gpu/config")
async def set_gpu_config(request: Request):
    """Save Prime Intellect API key to .env file (commonly gitignored) and reinitialize service."""
    global prime_intellect

    try:
        body = await request.json()
        api_key = body.get("api_key", "").strip()
        if not api_key:
            raise HTTPException(status_code=400, detail="API key is required")

        env_path = BASE_DIR / ".env"

        # Read existing .env content
        existing_lines = []
        if env_path.exists():
            with env_path.open("r", encoding="utf-8") as f:
                existing_lines = f.readlines()

        # Remove any existing PRIME_INTELLECT_API_KEY lines
        new_lines = [line for line in existing_lines if not line.strip().startswith("PRIME_INTELLECT_API_KEY=")]
        # Bug fix: if the file's last line has no trailing newline, the
        # appended key line would be fused onto it. Normalize first.
        if new_lines and not new_lines[-1].endswith("\n"):
            new_lines[-1] += "\n"
        # Add the new API key
        new_lines.append(f"PRIME_INTELLECT_API_KEY={api_key}\n")
        # Write back to .env
        with env_path.open("w", encoding="utf-8") as f:
            f.writelines(new_lines)
        # Reinitialize the service with the new key.
        prime_intellect = PrimeIntellectService(api_key=api_key)

        return {"configured": True, "message": "API key saved successfully"}

    except HTTPException:
        raise
    except Exception as exc:
        # Chain the cause so the original traceback survives into logs.
        raise HTTPException(status_code=500, detail=f"Failed to save API key: {exc}") from exc
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
@app.get("/api/gpu/availability")
async def get_gpu_availability(
    regions: list[str] | None = None,
    gpu_count: int | None = None,
    gpu_type: str | None = None,
    security: str | None = None
):
    """Get available GPU resources from Prime Intellect."""
    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")

    cache_key = make_cache_key(
        "gpu_avail",
        regions=regions,
        gpu_count=gpu_count,
        gpu_type=gpu_type,
        security=security,
    )

    # EAFP: serve from the TTL cache, fetching only on a miss.
    try:
        return gpu_cache[cache_key]
    except KeyError:
        fresh = await prime_intellect.get_gpu_availability(regions, gpu_count, gpu_type, security)
        gpu_cache[cache_key] = fresh
        return fresh
|
|
492
|
+
|
|
493
|
+
@app.get("/api/gpu/pods")
async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int = 0):
    """Get list of user's GPU pods."""
    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")

    cache_key = make_cache_key("gpu_pod", status=status, limit=limit, offset=offset)

    # EAFP: serve from the TTL cache, fetching from the API only on a miss.
    try:
        return pod_cache[cache_key]
    except KeyError:
        listing = await prime_intellect.get_pods(status, limit, offset)
        pod_cache[cache_key] = listing
        return listing
|
|
513
|
+
|
|
514
|
+
@app.post("/api/gpu/pods")
async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
    """Create a new GPU pod."""
    import sys
    print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)

    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")

    try:
        result = await prime_intellect.create_pod(pod_request)
        print(f"[CREATE POD] Success: {result}", file=sys.stderr, flush=True)
        # Any cached pod listing is now stale.
        pod_cache.clear()

        return result
    except HTTPException as e:
        # Map upstream failures onto friendlier client-facing errors.
        if e.status_code == 402:
            raise HTTPException(
                status_code=402,
                detail="Insufficient funds in your Prime Intellect wallet. Please add credits at https://app.primeintellect.ai/dashboard/billing"
            )
        if e.status_code in (401, 403):
            raise HTTPException(
                status_code=e.status_code,
                detail="Authentication failed. Please check your Prime Intellect API key."
            )
        print(f"[CREATE POD] Error: {e}", file=sys.stderr, flush=True)
        raise
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
@app.get("/api/gpu/pods/{pod_id}")
async def get_gpu_pod(pod_id: str) -> PodResponse:
    """Get details of a specific GPU pod."""
    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")

    cache_key = make_cache_key("gpu_pod_detail", pod_id=pod_id)

    # EAFP: serve from the TTL cache, fetching only on a miss.
    try:
        return pod_cache[cache_key]
    except KeyError:
        detail = await prime_intellect.get_pod(pod_id)
        pod_cache[cache_key] = detail
        return detail
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
@app.delete("/api/gpu/pods/{pod_id}")
async def delete_gpu_pod(pod_id: str):
    """Delete a GPU pod and invalidate cached pod data."""
    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")

    outcome = await prime_intellect.delete_pod(pod_id)
    # Listings and details are stale now; drop all cached entries.
    pod_cache.clear()
    return outcome
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
@app.post("/api/gpu/pods/{pod_id}/connect")
async def connect_to_pod(pod_id: str):
    """Connect to a GPU pod and establish SSH tunnel for remote execution."""
    global pod_manager

    if not prime_intellect:
        raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
    # Lazily create the manager on first use.
    if pod_manager is None:
        pod_manager = PodKernelManager(pi_service=prime_intellect)

    # Disconnect from any existing pod first, may need to fix later for multi-pod
    if pod_manager.pod is not None:
        await pod_manager.disconnect()

    # Connect to the new pod
    result = await pod_manager.connect_to_pod(pod_id)

    if result.get("status") == "ok":
        pod_manager.attach_executor(executor)
        # Addresses point at the local ends of the SSH tunnel.
        addresses = pod_manager.get_executor_addresses()
        executor.cmd_addr = addresses["cmd_addr"]
        executor.pub_addr = addresses["pub_addr"]

        # Reconnect executor sockets to tunneled ports.
        # NOTE: order matters — close the old socket before replacing it,
        # and re-apply the empty SUBSCRIBE filter on the fresh SUB socket.
        executor.req.close(0)  # type: ignore[reportAttributeAccessIssue]
        executor.req = executor.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        executor.req.connect(executor.cmd_addr)  # type: ignore[reportAttributeAccessIssue]

        executor.sub.close(0)  # type: ignore[reportAttributeAccessIssue]
        executor.sub = executor.ctx.socket(zmq.SUB)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        executor.sub.connect(executor.pub_addr)  # type: ignore[reportAttributeAccessIssue]
        executor.sub.setsockopt_string(zmq.SUBSCRIBE, '')  # type: ignore[reportAttributeAccessIssue]

    return result
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
@app.post("/api/gpu/pods/disconnect")
async def disconnect_from_pod():
    """Disconnect from current GPU pod."""
    global pod_manager

    # Nothing to do when no pod was ever connected.
    if pod_manager is None or pod_manager.pod is None:
        return {"status": "ok", "message": "No active connection"}

    result = await pod_manager.disconnect()

    # Reset executor to local addresses
    executor.cmd_addr = os.getenv('MC_ZMQ_CMD_ADDR', 'tcp://127.0.0.1:5555')
    executor.pub_addr = os.getenv('MC_ZMQ_PUB_ADDR', 'tcp://127.0.0.1:5556')

    # Reconnect to local worker
    # NOTE: mirrors the socket rewiring in connect_to_pod — close old
    # sockets first, then re-apply the empty SUBSCRIBE filter on SUB.
    executor.req.close(0)  # type: ignore[reportAttributeAccessIssue]
    executor.req = executor.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
    executor.req.connect(executor.cmd_addr)  # type: ignore[reportAttributeAccessIssue]

    executor.sub.close(0)  # type: ignore[reportAttributeAccessIssue]
    executor.sub = executor.ctx.socket(zmq.SUB)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
    executor.sub.connect(executor.pub_addr)  # type: ignore[reportAttributeAccessIssue]
    executor.sub.setsockopt_string(zmq.SUBSCRIBE, '')  # type: ignore[reportAttributeAccessIssue]

    return result
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
@app.get("/api/gpu/pods/connection/status")
async def get_pod_connection_status():
    """Get status of current pod connection."""
    if pod_manager is not None:
        return await pod_manager.get_status()
    # No manager yet: nothing has ever connected.
    return {"connected": False, "pod": None}
|