more-compute 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
morecompute/server.py ADDED
@@ -0,0 +1,641 @@
+ from cachetools import TTLCache
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException
+ from fastapi.responses import PlainTextResponse
+ from fastapi.staticfiles import StaticFiles
+ import os
+ from datetime import datetime, timezone
+ from pathlib import Path
+ import importlib.metadata as importlib_metadata
+ import zmq
+ import textwrap
+
+ from .notebook import Notebook
+ from .execution import NextZmqExecutor
+ from .utils.python_environment_util import PythonEnvironmentDetector
+ from .utils.system_environment_util import DeviceMetrics
+ from .utils.error_utils import ErrorUtils
+ from .utils.cache_util import make_cache_key
+ from .utils.notebook_util import coerce_cell_source
+ from .services.prime_intellect import PrimeIntellectService, CreatePodRequest, PodResponse
+ from .services.pod_manager import PodKernelManager
+
+
+ BASE_DIR = Path(os.getenv("MORECOMPUTE_ROOT", Path.cwd())).resolve()
+ PACKAGE_DIR = Path(__file__).resolve().parent
+ ASSETS_DIR = Path(os.getenv("MORECOMPUTE_ASSETS_DIR", BASE_DIR / "assets")).resolve()
+
+
+ def resolve_path(requested_path: str) -> Path:
+     relative = requested_path or "."
+     target = (BASE_DIR / relative).resolve()
+     try:
+         target.relative_to(BASE_DIR)
+     except ValueError:
+         raise HTTPException(status_code=400, detail="Path outside notebook root")
+     return target
+
+
+ app = FastAPI()
+ gpu_cache = TTLCache(maxsize=50, ttl=60)
+ pod_cache = TTLCache(maxsize=100, ttl=300)
+ packages_cache = TTLCache(maxsize=1, ttl=300)  # 5 minutes cache for packages
+ environments_cache = TTLCache(maxsize=1, ttl=300)  # 5 minutes cache for environments
+
+ # Mount assets directory for icons, images, etc.
+ if ASSETS_DIR.exists():
+     app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
+
+ # Global instances for the application state
+ notebook_path_env = os.getenv("MORECOMPUTE_NOTEBOOK_PATH")
+ if notebook_path_env:
+     notebook = Notebook(file_path=notebook_path_env)
+ else:
+     notebook = Notebook()
+ error_utils = ErrorUtils()
+ executor = NextZmqExecutor(error_utils=error_utils)
+ metrics = DeviceMetrics()
+
+ # Initialize Prime Intellect service if API key is provided
+ # Check environment variable first, then .env file (commonly gitignored)
+ prime_api_key = os.getenv("PRIME_INTELLECT_API_KEY")
+ if not prime_api_key:
+     env_path = BASE_DIR / ".env"
+     if env_path.exists():
+         try:
+             with env_path.open("r", encoding="utf-8") as f:
+                 for line in f:
+                     line = line.strip()
+                     if line.startswith("PRIME_INTELLECT_API_KEY="):
+                         prime_api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
+                         break
+         except Exception:
+             pass
+
+ prime_intellect = PrimeIntellectService(api_key=prime_api_key) if prime_api_key else None
+ pod_manager: PodKernelManager | None = None
+
+
+ @app.get("/api/packages")
+ async def list_installed_packages(force_refresh: bool = False):
+     """
+     Return installed packages for the current Python runtime.
+     Args:
+         force_refresh: If True, bypass cache and fetch fresh data
+     """
+     cache_key = "packages_list"
+
+     # Check cache first unless force refresh is requested
+     if not force_refresh and cache_key in packages_cache:
+         return packages_cache[cache_key]
+
+     try:
+         packages = []
+         for dist in importlib_metadata.distributions():
+             name = dist.metadata.get("Name")
+             version = dist.version
+             if name and version:
+                 packages.append({"name": str(name), "version": str(version)})
+         packages.sort(key=lambda p: p["name"].lower())
+
+         result = {"packages": packages}
+         packages_cache[cache_key] = result
+         return result
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=f"Failed to list packages: {exc}")
+
+
+ @app.get("/api/metrics")
+ async def get_metrics():
+     try:
+         return metrics.get_all_devices()
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=f"Failed to get metrics: {exc}")
+
+ @app.get("/api/environments")
+ async def get_environments(full: bool = True, force_refresh: bool = False):
+     """
+     Return available Python environments.
+     Args:
+         full: If True (default), performs comprehensive scan (conda, system, venv).
+             Takes a few seconds but finds all environments.
+         force_refresh: If True, bypass cache and fetch fresh data
+     """
+     cache_key = f"environments_{full}"
+
+     # Check cache first unless force refresh is requested
+     if not force_refresh and cache_key in environments_cache:
+         return environments_cache[cache_key]
+
+     try:
+         detector = PythonEnvironmentDetector()
+         environments = detector.detect_all_environments()
+         current_env = detector.get_current_environment()
+
+         result = {
+             "status": "success",
+             "environments": environments,
+             "current": current_env
+         }
+
+         environments_cache[cache_key] = result  # Cache the result
+         return result
+
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=f"Failed to detect environments: {exc}")
+
+ @app.get("/api/files")
+ async def list_files(path: str = "."):
+     directory = resolve_path(path)
+     if not directory.exists() or not directory.is_dir():
+         raise HTTPException(status_code=404, detail="Directory not found")
+
+     items: list[dict[str, str | int]] = []
+     try:
+         for entry in sorted(directory.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())):
+             stat = entry.stat()
+             item_path = entry.relative_to(BASE_DIR)
+             items.append({
+                 "name": entry.name,
+                 "path": str(item_path).replace("\\", "/"),
+                 "type": "directory" if entry.is_dir() else "file",
+                 "size": stat.st_size,
+                 "modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
+             })
+     except PermissionError as exc:
+         raise HTTPException(status_code=403, detail=f"Permission denied: {exc}")
+
+     return {
+         "root": str(BASE_DIR),
+         "path": str(directory.relative_to(BASE_DIR)) if directory != BASE_DIR else ".",
+         "items": items,
+     }
+
+
+ @app.post("/api/fix-indentation")
+ async def fix_indentation(request: Request):
+     """Fix indentation in Python code using textwrap.dedent()."""
+     try:
+         body = await request.json()
+         code = body.get("code", "")
+         fixed_code = textwrap.dedent(code)
+         return {"fixed_code": fixed_code}
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=f"Failed to fix indentation: {exc}")
+
+
+ @app.get("/api/file")
+ async def read_file(path: str, max_bytes: int = 256_000):
+     file_path = resolve_path(path)
+     if not file_path.exists() or not file_path.is_file():
+         raise HTTPException(status_code=404, detail="File not found")
+
+     try:
+         with file_path.open("rb") as f:
+             content = f.read(max_bytes + 1)
+     except PermissionError as exc:
+         raise HTTPException(status_code=403, detail=f"Permission denied: {exc}")
+
+     truncated = len(content) > max_bytes
+     if truncated:
+         content = content[:max_bytes]
+
+     text = content.decode("utf-8", errors="replace")
+     if truncated:
+         text += "\n\n… (truncated)"
+
+     return PlainTextResponse(text)
+
+
+ class WebSocketManager:
+     """Manages WebSocket connections and message handling."""
+     def __init__(self) -> None:
+         self.clients: dict[WebSocket, None] = {}
+         self.executor = executor
+         self.notebook = notebook
+
+     async def connect(self, websocket: WebSocket):
+         await websocket.accept()
+         self.clients[websocket] = None
+         # Send the initial notebook state to the new client
+         await websocket.send_json({
+             "type": "notebook_data",
+             "data": self.notebook.get_notebook_data()
+         })
+
+     def disconnect(self, websocket: WebSocket):
+         del self.clients[websocket]
+
+     async def broadcast_notebook_update(self):
+         """Send the entire notebook state to all connected clients."""
+         updated_data = self.notebook.get_notebook_data()
+         for client in self.clients:
+             await client.send_json({
+                 "type": "notebook_updated",
+                 "data": updated_data
+             })
+
+     async def handle_message_loop(self, websocket: WebSocket):
+         """Main loop to handle incoming WebSocket messages."""
+         while True:
+             try:
+                 message = await websocket.receive_json()
+                 await self._handle_message(websocket, message)
+             except WebSocketDisconnect:
+                 self.disconnect(websocket)
+                 break
+             except Exception as e:
+                 await self._send_error(websocket, f"Unhandled error: {e}")
+
+     async def _handle_message(self, websocket: WebSocket, message: dict):
+         message_type = message.get("type")
+         data = message.get("data", {})
+
+         handlers = {
+             "execute_cell": self._handle_execute_cell,
+             "add_cell": self._handle_add_cell,
+             "delete_cell": self._handle_delete_cell,
+             "update_cell": self._handle_update_cell,
+             "interrupt_kernel": self._handle_interrupt_kernel,
+             "reset_kernel": self._handle_reset_kernel,
+             "load_notebook": self._handle_load_notebook,
+             "save_notebook": self._handle_save_notebook,
+         }
+
+         handler = handlers.get(message_type)
+         if handler:
+             await handler(websocket, data)
+         else:
+             await self._send_error(websocket, f"Unknown message type: {message_type}")
+
+     async def _handle_execute_cell(self, websocket: WebSocket, data: dict):
+         import sys
+         cell_index = data.get("cell_index")
+         if cell_index is None or not (0 <= cell_index < len(self.notebook.cells)):
+             await self._send_error(websocket, "Invalid cell index.")
+             return
+
+         source = coerce_cell_source(self.notebook.cells[cell_index].get('source', ''))
+
+         await websocket.send_json({
+             "type": "execution_start",
+             "data": {"cell_index": cell_index, "execution_count": getattr(self.executor, 'execution_count', 0) + 1}
+         })
+
+         try:
+             result = await self.executor.execute_cell(cell_index, source, websocket)
+         except Exception as e:
+             error_msg = str(e)
+             print(f"[SERVER ERROR] execute_cell failed: {error_msg}", file=sys.stderr, flush=True)
+
+             # Send error to frontend
+             result = {
+                 'status': 'error',
+                 'execution_count': None,
+                 'execution_time': '0ms',
+                 'outputs': [],
+                 'error': {
+                     'output_type': 'error',
+                     'ename': type(e).__name__,
+                     'evalue': error_msg,
+                     'traceback': [f'{type(e).__name__}: {error_msg}', 'Worker failed to start or crashed. Check server logs.']
+                 }
+             }
+             await websocket.send_json({
+                 "type": "execution_error",
+                 "data": {
+                     "cell_index": cell_index,
+                     "error": result['error']
+                 }
+             })
+
+         self.notebook.cells[cell_index]['outputs'] = result.get('outputs', [])
+         self.notebook.cells[cell_index]['execution_count'] = result.get('execution_count')
+
+         await websocket.send_json({
+             "type": "execution_complete",
+             "data": {"cell_index": cell_index, "result": result}
+         })
+
+     async def _handle_add_cell(self, websocket: WebSocket, data: dict):
+         index = data.get('index', len(self.notebook.cells))
+         cell_type = data.get('cell_type', 'code')
+         self.notebook.add_cell(index=index, cell_type=cell_type)
+         await self.broadcast_notebook_update()
+
+     async def _handle_delete_cell(self, websocket: WebSocket, data: dict):
+         index = data.get('cell_index')
+         if index is not None:
+             self.notebook.delete_cell(index)
+             await self.broadcast_notebook_update()
+
+     async def _handle_update_cell(self, websocket: WebSocket, data: dict):
+         index = data.get('cell_index')
+         source = data.get('source')
+         if index is not None and source is not None:
+             self.notebook.update_cell(index, source)
+             # self.notebook.save_to_file()
+             # TODO: decide whether to persist the notebook to disk on every update
+
+
+     async def _handle_load_notebook(self, websocket: WebSocket, data: dict):
+         # In a real app, this would load from a file path in `data`
+         # For now, it just sends the current state back to the requester
+         await websocket.send_json({
+             "type": "notebook_data",
+             "data": self.notebook.get_notebook_data()
+         })
+
+     async def _handle_save_notebook(self, websocket: WebSocket, data: dict):
+         try:
+             self.notebook.save_to_file()
+             await websocket.send_json({"type": "notebook_saved", "data": {"file_path": self.notebook.file_path}})
+         except Exception as exc:
+             await self._send_error(websocket, f"Failed to save notebook: {exc}")
+
+     async def _handle_interrupt_kernel(self, websocket: WebSocket, data: dict):
+         try:
+             cell_index = data.get('cell_index')
+         except Exception:
+             cell_index = None
+
+         import sys
+         print(f"[SERVER] Interrupt request received for cell {cell_index}", file=sys.stderr, flush=True)
+
+         # Perform the interrupt (this may take up to 1 second)
+         await self.executor.interrupt_kernel(cell_index=cell_index)
+
+         print("[SERVER] Interrupt completed, sending error message", file=sys.stderr, flush=True)
+
+         # Inform the requesting client that the currently running cell (if any) was interrupted
+         try:
+             await websocket.send_json({
+                 "type": "execution_error",
+                 "data": {
+                     "cell_index": cell_index,
+                     "error": {
+                         "output_type": "error",
+                         "ename": "KeyboardInterrupt",
+                         "evalue": "Execution interrupted by user",
+                         "traceback": ["KeyboardInterrupt: Execution was stopped by user"]
+                     }
+                 }
+             })
+             await websocket.send_json({
+                 "type": "execution_complete",
+                 "data": {
+                     "cell_index": cell_index,
+                     "result": {
+                         "status": "error",
+                         "execution_count": None,
+                         "execution_time": "interrupted",
+                         "outputs": [],
+                         "error": {
+                             "output_type": "error",
+                             "ename": "KeyboardInterrupt",
+                             "evalue": "Execution interrupted by user",
+                             "traceback": ["KeyboardInterrupt: Execution was stopped by user"]
+                         }
+                     }
+                 }
+             })
+             print(f"[SERVER] Error messages sent for cell {cell_index}", file=sys.stderr, flush=True)
+         except Exception as e:
+             print(f"[SERVER] Failed to send error messages: {e}", file=sys.stderr, flush=True)
+
+     async def _handle_reset_kernel(self, websocket: WebSocket, data: dict):
+         self.executor.reset_kernel()
+         self.notebook.clear_all_outputs()
+         await self.broadcast_notebook_update()
+
+     async def _send_error(self, websocket: WebSocket, error_message: str):
+         await websocket.send_json({"type": "error", "data": {"error": error_message}})
+
+
+ manager = WebSocketManager()
+
+
+ @app.websocket("/ws")
+ async def websocket_endpoint(websocket: WebSocket):
+     await manager.connect(websocket)
+     await manager.handle_message_loop(websocket)
+
+
+ # GPU connection API
+ @app.get("/api/gpu/config")
+ async def get_gpu_config():
+     """Check if Prime Intellect API is configured."""
+     return {"configured": prime_intellect is not None}
+
+
+ @app.post("/api/gpu/config")
+ async def set_gpu_config(request: Request):
+     """Save Prime Intellect API key to .env file (commonly gitignored) and reinitialize service."""
+     global prime_intellect
+
+     try:
+         body = await request.json()
+         api_key = body.get("api_key", "").strip()
+         if not api_key:
+             raise HTTPException(status_code=400, detail="API key is required")
+
+         env_path = BASE_DIR / ".env"
+
+         # Read existing .env content
+         existing_lines = []
+         if env_path.exists():
+             with env_path.open("r", encoding="utf-8") as f:
+                 existing_lines = f.readlines()
+
+         # Remove any existing PRIME_INTELLECT_API_KEY lines
+         new_lines = [line for line in existing_lines if not line.strip().startswith("PRIME_INTELLECT_API_KEY=")]
+         # Add the new API key
+         new_lines.append(f"PRIME_INTELLECT_API_KEY={api_key}\n")
+         # Write back to .env
+         with env_path.open("w", encoding="utf-8") as f:
+             f.writelines(new_lines)
+         prime_intellect = PrimeIntellectService(api_key=api_key)
+
+         return {"configured": True, "message": "API key saved successfully"}
+
+     except HTTPException:
+         raise
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=f"Failed to save API key: {exc}")
+
+
+ @app.get("/api/gpu/availability")
+ async def get_gpu_availability(
+     regions: list[str] | None = None,
+     gpu_count: int | None = None,
+     gpu_type: str | None = None,
+     security: str | None = None
+ ):
+     """Get available GPU resources from Prime Intellect."""
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+
+     cache_key = make_cache_key(
+         "gpu_avail",
+         regions=regions,
+         gpu_count=gpu_count,
+         gpu_type=gpu_type,
+         security=security
+     )
+
+     if cache_key in gpu_cache:
+         return gpu_cache[cache_key]
+
+     # Cache miss: fetch from API
+     result = await prime_intellect.get_gpu_availability(regions, gpu_count, gpu_type, security)
+     gpu_cache[cache_key] = result
+     return result
+
+ @app.get("/api/gpu/pods")
+ async def get_gpu_pods(status: str | None = None, limit: int = 100, offset: int = 0):
+     """Get list of user's GPU pods."""
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+
+     cache_key = make_cache_key(
+         "gpu_pod",
+         status=status,
+         limit=limit,
+         offset=offset
+     )
+
+     if cache_key in pod_cache:
+         return pod_cache[cache_key]
+
+     # Cache miss: fetch from API
+     result = await prime_intellect.get_pods(status, limit, offset)
+     pod_cache[cache_key] = result
+     return result
+
+ @app.post("/api/gpu/pods")
+ async def create_gpu_pod(pod_request: CreatePodRequest) -> PodResponse:
+     """Create a new GPU pod."""
+     import sys
+     print(f"[CREATE POD] Received request: {pod_request.model_dump()}", file=sys.stderr, flush=True)
+
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+
+     try:
+         result = await prime_intellect.create_pod(pod_request)
+         print(f"[CREATE POD] Success: {result}", file=sys.stderr, flush=True)
+         pod_cache.clear()
+
+         return result
+     except HTTPException as e:
+         if e.status_code == 402:
+             raise HTTPException(
+                 status_code=402,
+                 detail="Insufficient funds in your Prime Intellect wallet. Please add credits at https://app.primeintellect.ai/dashboard/billing"
+             )
+         elif e.status_code == 401 or e.status_code == 403:
+             raise HTTPException(
+                 status_code=e.status_code,
+                 detail="Authentication failed. Please check your Prime Intellect API key."
+             )
+         else:
+             print(f"[CREATE POD] Error: {e}", file=sys.stderr, flush=True)
+             raise
+
+
+ @app.get("/api/gpu/pods/{pod_id}")
+ async def get_gpu_pod(pod_id: str) -> PodResponse:
+     """Get details of a specific GPU pod."""
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+
+     cache_key = make_cache_key("gpu_pod_detail", pod_id=pod_id)
+
+     if cache_key in pod_cache:
+         return pod_cache[cache_key]
+
+     result = await prime_intellect.get_pod(pod_id)
+     pod_cache[cache_key] = result
+     return result
+
+
+ @app.delete("/api/gpu/pods/{pod_id}")
+ async def delete_gpu_pod(pod_id: str):
+     """Delete a GPU pod."""
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+
+     result = await prime_intellect.delete_pod(pod_id)
+     pod_cache.clear()
+     return result
+
+
+ @app.post("/api/gpu/pods/{pod_id}/connect")
+ async def connect_to_pod(pod_id: str):
+     """Connect to a GPU pod and establish an SSH tunnel for remote execution."""
+     global pod_manager
+
+     if not prime_intellect:
+         raise HTTPException(status_code=503, detail="Prime Intellect API key not configured")
+     if pod_manager is None:
+         pod_manager = PodKernelManager(pi_service=prime_intellect)
+
+     # Disconnect from any existing pod first; multi-pod support may need revisiting later
+     if pod_manager.pod is not None:
+         await pod_manager.disconnect()
+
+     # Connect to the new pod
+     result = await pod_manager.connect_to_pod(pod_id)
+
+     if result.get("status") == "ok":
+         pod_manager.attach_executor(executor)
+         addresses = pod_manager.get_executor_addresses()
+         executor.cmd_addr = addresses["cmd_addr"]
+         executor.pub_addr = addresses["pub_addr"]
+
+         # Reconnect executor sockets to the tunneled ports
+         executor.req.close(0)  # type: ignore[reportAttributeAccessIssue]
+         executor.req = executor.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
+         executor.req.connect(executor.cmd_addr)  # type: ignore[reportAttributeAccessIssue]
+
+         executor.sub.close(0)  # type: ignore[reportAttributeAccessIssue]
+         executor.sub = executor.ctx.socket(zmq.SUB)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
+         executor.sub.connect(executor.pub_addr)  # type: ignore[reportAttributeAccessIssue]
+         executor.sub.setsockopt_string(zmq.SUBSCRIBE, '')  # type: ignore[reportAttributeAccessIssue]
+
+     return result
+
+
+ @app.post("/api/gpu/pods/disconnect")
+ async def disconnect_from_pod():
+     """Disconnect from current GPU pod."""
+     global pod_manager
+
+     if pod_manager is None or pod_manager.pod is None:
+         return {"status": "ok", "message": "No active connection"}
+
+     result = await pod_manager.disconnect()
+
+     # Reset executor to local addresses
+     executor.cmd_addr = os.getenv('MC_ZMQ_CMD_ADDR', 'tcp://127.0.0.1:5555')
+     executor.pub_addr = os.getenv('MC_ZMQ_PUB_ADDR', 'tcp://127.0.0.1:5556')
+
+     # Reconnect to local worker
+     executor.req.close(0)  # type: ignore[reportAttributeAccessIssue]
+     executor.req = executor.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
+     executor.req.connect(executor.cmd_addr)  # type: ignore[reportAttributeAccessIssue]
+
+     executor.sub.close(0)  # type: ignore[reportAttributeAccessIssue]
+     executor.sub = executor.ctx.socket(zmq.SUB)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
+     executor.sub.connect(executor.pub_addr)  # type: ignore[reportAttributeAccessIssue]
+     executor.sub.setsockopt_string(zmq.SUBSCRIBE, '')  # type: ignore[reportAttributeAccessIssue]
+
+     return result
+
+
+ @app.get("/api/gpu/pods/connection/status")
+ async def get_pod_connection_status():
+     """Get status of current pod connection."""
+     if pod_manager is None:
+         return {"connected": False, "pod": None}
+
+     return await pod_manager.get_status()
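
For orientation only, and not part of the released package: a minimal client sketch of the /ws protocol implemented by WebSocketManager above. The server URL (ws://localhost:8000/ws) and the third-party websockets dependency are assumptions, since this file does not fix a port; the message types ("notebook_data", "execute_cell", "execution_start", "execution_error", "execution_complete") come from the handlers dict in the diff.

# Hypothetical client sketch; the URL and the `websockets` package are assumptions.
import asyncio
import json

import websockets


async def run_first_cell() -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # On connect the server pushes the full notebook state ("notebook_data").
        print(json.loads(await ws.recv())["type"])

        # Ask the server to execute cell 0. It replies with "execution_start",
        # then any streamed output events, and finally "execution_complete"
        # (preceded by "execution_error" if the worker failed).
        await ws.send(json.dumps({"type": "execute_cell", "data": {"cell_index": 0}}))
        while True:
            event = json.loads(await ws.recv())
            print(event["type"])
            if event["type"] == "execution_complete":
                break


asyncio.run(run_first_cell())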