more-compute 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,503 @@
1
+ import asyncio
2
+ import subprocess
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import tarfile
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ from .prime_intellect import PrimeIntellectService, PodResponse
11
+
12
+ if TYPE_CHECKING:
13
+ from ..execution.executor import NextZmqExecutor
14
+
15
class PodKernelManager:
    """
    Manage a remote GPU pod connection (Prime Intellect is the only
    provider today; others may be added later) together with the SSH
    tunnels that carry ZMQ execution traffic to the pod.
    """
    pi_service: PrimeIntellectService  # Prime Intellect API client
    pod: PodResponse | None  # currently connected pod, if any
    ssh_tunnel_proc: subprocess.Popen[bytes] | None  # live `ssh -N -L` process
    local_cmd_port: int  # local end of the REQ/REP tunnel
    local_pub_port: int  # local end of the PUB/SUB tunnel
    remote_cmd_port: int  # remote REQ/REP port on the pod
    remote_pub_port: int  # remote PUB/SUB port on the pod
    executor: "NextZmqExecutor | None"  # attached executor, if any
    _ssh_key_cache: str | None  # memoized SSH private-key path
30
+ def __init__(
31
+ self,
32
+ pi_service: PrimeIntellectService,
33
+ local_cmd_port: int = 15555,
34
+ local_pub_port: int = 15556,
35
+ remote_cmd_port: int = 5555,
36
+ remote_pub_port: int = 5556
37
+ ) -> None:
38
+ """
39
+ Initialize pod manager
40
+
41
+ args:
42
+ pi_service : Prime Intellect API service
43
+ local_cmd_port: Local port for REQ/REP tunnel
44
+ local_pub_port: Local port for PUB/SUB tunnel
45
+ remote_cmd_port: Remote port for REQ/REP socket
46
+ remote_pub_port: Remote port for PUB/SUB socket
47
+ """
48
+ self.pi_service = pi_service
49
+ self.pod = None
50
+ self.ssh_tunnel_proc = None
51
+ self.local_cmd_port = local_cmd_port
52
+ self.local_pub_port = local_pub_port
53
+ self.remote_cmd_port = remote_cmd_port
54
+ self.remote_pub_port = remote_pub_port
55
+ self.executor = None
56
+ self._ssh_key_cache = None
57
+
58
+ def _get_ssh_key(self) -> str | None:
59
+ """
60
+ Get SSH key path, checking environment variable and common locations.
61
+ Returns None if no key is found.
62
+ """
63
+ if self._ssh_key_cache is not None:
64
+ return self._ssh_key_cache
65
+
66
+ # Check environment variable first
67
+ ssh_key = os.getenv("MORECOMPUTE_SSH_KEY")
68
+ if ssh_key:
69
+ expanded = os.path.expanduser(ssh_key)
70
+ if os.path.exists(expanded):
71
+ self._ssh_key_cache = expanded
72
+ return expanded
73
+
74
+ # Try common SSH key paths (including Prime Intellect's recommended name)
75
+ common_keys = [
76
+ "~/.ssh/primeintellect_ed25519", # Prime Intellect's recommended name
77
+ "~/.ssh/id_ed25519",
78
+ "~/.ssh/id_rsa",
79
+ ]
80
+ for key_path in common_keys:
81
+ expanded_path = os.path.expanduser(key_path)
82
+ if os.path.exists(expanded_path):
83
+ self._ssh_key_cache = expanded_path
84
+ return expanded_path
85
+
86
+ return None
87
+
88
+ async def connect_to_pod(self, pod_id:str) -> dict[str, object]:
89
+ """
90
+ Connects to existing pod and set up ssh tunnel
91
+ args:
92
+ pod_id: the pod identifier
93
+
94
+ Response:
95
+ dict with connection status
96
+ """
97
+ import sys
98
+
99
+ self.pod = await self.pi_service.get_pod(pod_id)
100
+
101
+ print(f"[POD MANAGER] Pod status: {self.pod.status}", file=sys.stderr, flush=True)
102
+ print(f"[POD MANAGER] SSH connection: {self.pod.sshConnection}", file=sys.stderr, flush=True)
103
+
104
+ if not self.pod.sshConnection:
105
+ return{
106
+ "status":"error",
107
+ "message": f"Pod is not ready yet. Status: {self.pod.status}. SSH connection info is not available. Please wait for pod to finish provisioning."
108
+ }
109
+
110
+ # Validate SSH connection string is not empty/whitespace
111
+ if not self.pod.sshConnection.strip():
112
+ return{
113
+ "status":"error",
114
+ "message": f"Pod SSH connection is empty. Status: {self.pod.status}"
115
+ }
116
+
117
+ # Parse SSH connection string
118
+ # Format can be: "ssh root@ip -p port" OR "root@ip -p port"
119
+ ssh_parts = self.pod.sshConnection.split()
120
+
121
+ # Find the part containing @ (user@host)
122
+ host_part = None
123
+ for part in ssh_parts:
124
+ if "@" in part:
125
+ host_part = part
126
+ break
127
+
128
+ if not host_part:
129
+ return{
130
+ "status":"error",
131
+ "message": f"Invalid SSH connection format (no user@host found): {self.pod.sshConnection}"
132
+ }
133
+
134
+ # Extract host from user@host
135
+ ssh_host = host_part.split("@")[1]
136
+ ssh_port = "22"
137
+
138
+ # Extract port if specified with -p flag
139
+ if "-p" in ssh_parts:
140
+ port_idx = ssh_parts.index("-p")
141
+ if port_idx + 1 < len(ssh_parts):
142
+ ssh_port = ssh_parts[port_idx + 1]
143
+
144
+ print(f"[POD MANAGER] Parsed SSH host: {ssh_host}, port: {ssh_port}", file=sys.stderr, flush=True)
145
+
146
+ #deploy worker code to pod
147
+ deploy_result = await self._deploy_worker(ssh_host, ssh_port)
148
+ if deploy_result.get("status") == "error":
149
+ return deploy_result
150
+
151
+ #create ssh tunnel for ZMQ ports
152
+ tunnel_result = await self._create_ssh_tunnel(ssh_host, ssh_port)
153
+ if tunnel_result.get("status") == "error":
154
+ return tunnel_result
155
+
156
+ #start remote worker
157
+ worker_result = await self._start_remote_worker(ssh_host, ssh_port)
158
+ if worker_result.get("status") == "error":
159
+ await self.disconnect()
160
+ return worker_result
161
+
162
+ return {
163
+ "status": "ok",
164
+ "message": f"Connected to pod {pod_id}",
165
+ "ssh_host": ssh_host,
166
+ "ssh_port": ssh_port,
167
+ "tunnel_ports": {
168
+ "cmd": f"localhost:{self.local_cmd_port}",
169
+ "pub": f"localhost:{self.local_pub_port}"
170
+ }
171
+ }
172
+
173
+ async def _deploy_worker(self, ssh_host: str, ssh_port: str) -> dict[str,object]:
174
+ """
175
+ Deploy worker code to remote pod via Secure Copy Protocol.
176
+
177
+ args:
178
+ ssh_host: SSH host address
179
+ ssh_port: SSH port
180
+
181
+ returns:
182
+ dict with deployment status
183
+ """
184
+ try:
185
+ # Create temporary tarball of morecompute package
186
+ project_root = Path(__file__).parent.parent.parent
187
+ morecompute_dir = project_root / "morecompute"
188
+
189
+ with tempfile.NamedTemporaryFile(suffix='.tar.gz', delete=False) as tmp:
190
+ tmp_path = tmp.name
191
+
192
+ with tarfile.open(tmp_path, 'w:gz') as tar:
193
+ tar.add(morecompute_dir, arcname='morecompute')
194
+
195
+ # Build SSH command with optional key override
196
+ scp_cmd = ["scp", "-P", ssh_port]
197
+
198
+ ssh_key = self._get_ssh_key()
199
+ if ssh_key:
200
+ scp_cmd.extend(["-i", ssh_key])
201
+
202
+ scp_cmd.extend([
203
+ "-o", "StrictHostKeyChecking=no",
204
+ "-o", "UserKnownHostsFile=/dev/null",
205
+ "-o", "BatchMode=yes", # Prevent password prompts, fail fast if key auth doesn't work
206
+ "-o", "ConnectTimeout=10",
207
+ tmp_path,
208
+ f"root@{ssh_host}:/tmp/morecompute.tar.gz"
209
+ ])
210
+
211
+ result = subprocess.run(
212
+ scp_cmd,
213
+ capture_output=True,
214
+ text=True,
215
+ timeout=60
216
+ )
217
+
218
+ if result.returncode != 0:
219
+ error_msg = result.stderr.lower()
220
+ if "permission denied" in error_msg or "publickey" in error_msg:
221
+ return {
222
+ "status": "error",
223
+ "message": (
224
+ "SSH authentication failed. Please add your SSH public key to Prime Intellect:\n"
225
+ "1. Visit https://app.primeintellect.ai/dashboard/tokens\n"
226
+ "2. Upload your public key (~/.ssh/id_ed25519.pub or ~/.ssh/id_rsa.pub)\n"
227
+ "3. Try connecting again"
228
+ )
229
+ }
230
+ elif "host key verification failed" in error_msg:
231
+ return {
232
+ "status": "error",
233
+ "message": f"SSH host verification failed. This is unusual. Error: {result.stderr}"
234
+ }
235
+ else:
236
+ return {
237
+ "status": "error",
238
+ "message": f"Failed to copy worker code to pod. SSH Error: {result.stderr}"
239
+ }
240
+
241
+ # Extract on remote and install dependencies
242
+ ssh_cmd = ["ssh", "-p", ssh_port]
243
+
244
+ if ssh_key:
245
+ ssh_cmd.extend(["-i", ssh_key])
246
+
247
+ ssh_cmd.extend([
248
+ "-o", "StrictHostKeyChecking=no",
249
+ "-o", "UserKnownHostsFile=/dev/null",
250
+ "-o", "BatchMode=yes",
251
+ "-o", "ConnectTimeout=10",
252
+ f"root@{ssh_host}",
253
+ (
254
+ "cd /tmp && "
255
+ "tar -xzf morecompute.tar.gz && "
256
+ "pip install --quiet pyzmq && "
257
+ "echo 'Deployment complete'"
258
+ )
259
+ ])
260
+
261
+ result = subprocess.run(
262
+ ssh_cmd,
263
+ capture_output=True,
264
+ text=True,
265
+ timeout=120
266
+ )
267
+
268
+ # Cleanup local tarball
269
+ os.unlink(tmp_path)
270
+
271
+ if result.returncode != 0:
272
+ return {
273
+ "status": "error",
274
+ "message": f"Failed to extract/setup worker: {result.stderr}"
275
+ }
276
+
277
+ return {"status": "ok", "message": "Worker deployed successfully"}
278
+
279
+ except Exception as e:
280
+ return {
281
+ "status": "error",
282
+ "message": f"Deployment error: {str(e)}"
283
+ }
284
+
285
+ async def _create_ssh_tunnel(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
286
+ """
287
+ Create SSH tunnel for ZMQ ports.
288
+
289
+ args:
290
+ ssh_host: SSH host address
291
+ ssh_port: SSH port
292
+
293
+ returns:
294
+ dict with tunnel status
295
+ """
296
+ try:
297
+ # Create SSH tunnel: local ports -> remote ports
298
+ ssh_key = self._get_ssh_key()
299
+ tunnel_cmd = ["ssh", "-p", ssh_port]
300
+
301
+ if ssh_key:
302
+ tunnel_cmd.extend(["-i", ssh_key])
303
+
304
+ tunnel_cmd.extend([
305
+ "-o", "StrictHostKeyChecking=no",
306
+ "-o", "UserKnownHostsFile=/dev/null",
307
+ "-o", "BatchMode=yes",
308
+ "-o", "ServerAliveInterval=60",
309
+ "-o", "ServerAliveCountMax=3",
310
+ "-N", # No command execution
311
+ "-L", f"{self.local_cmd_port}:localhost:{self.remote_cmd_port}",
312
+ "-L", f"{self.local_pub_port}:localhost:{self.remote_pub_port}",
313
+ f"root@{ssh_host}"
314
+ ])
315
+
316
+ self.ssh_tunnel_proc = subprocess.Popen(
317
+ tunnel_cmd,
318
+ stdout=subprocess.DEVNULL,
319
+ stderr=subprocess.DEVNULL
320
+ )
321
+
322
+ # Wait briefly for tunnel to establish
323
+ await asyncio.sleep(2)
324
+
325
+ if self.ssh_tunnel_proc.poll() is not None:
326
+ return {
327
+ "status": "error",
328
+ "message": "SSH tunnel failed to establish"
329
+ }
330
+
331
+ return {
332
+ "status": "ok",
333
+ "message": "SSH tunnel created",
334
+ "pid": self.ssh_tunnel_proc.pid
335
+ }
336
+
337
+ except Exception as e:
338
+ return {
339
+ "status": "error",
340
+ "message": f"Tunnel creation error: {str(e)}"
341
+ }
342
+
343
+ async def _start_remote_worker(self, ssh_host: str, ssh_port: str) -> dict[str, object]:
344
+ """
345
+ Start ZMQ worker on remote pod.
346
+
347
+ args:
348
+ ssh_host: SSH host address
349
+ ssh_port: SSH port
350
+
351
+ returns:
352
+ dict with worker start status
353
+ """
354
+ try:
355
+ # Start worker in background on remote pod
356
+ ssh_key = self._get_ssh_key()
357
+ worker_cmd = ["ssh", "-p", ssh_port]
358
+
359
+ if ssh_key:
360
+ worker_cmd.extend(["-i", ssh_key])
361
+
362
+ worker_cmd.extend([
363
+ "-o", "StrictHostKeyChecking=no",
364
+ "-o", "UserKnownHostsFile=/dev/null",
365
+ "-o", "BatchMode=yes",
366
+ "-o", "ConnectTimeout=10",
367
+ f"root@{ssh_host}",
368
+ (
369
+ f"cd /tmp && "
370
+ f"export MC_ZMQ_CMD_ADDR=tcp://0.0.0.0:{self.remote_cmd_port} && "
371
+ f"export MC_ZMQ_PUB_ADDR=tcp://0.0.0.0:{self.remote_pub_port} && "
372
+ f"export PYTHONPATH=/tmp:$PYTHONPATH && "
373
+ f"nohup {sys.executable} -m morecompute.execution.worker "
374
+ f"> /tmp/worker.log 2>&1 & "
375
+ f"echo $!"
376
+ )
377
+ ])
378
+
379
+ result = subprocess.run(
380
+ worker_cmd,
381
+ capture_output=True,
382
+ text=True,
383
+ timeout=30
384
+ )
385
+
386
+ if result.returncode != 0:
387
+ return {
388
+ "status": "error",
389
+ "message": f"Failed to start remote worker: {result.stderr}"
390
+ }
391
+
392
+ remote_pid = result.stdout.strip()
393
+
394
+ # Wait for worker to be ready
395
+ await asyncio.sleep(2)
396
+
397
+ return {
398
+ "status": "ok",
399
+ "message": "Remote worker started",
400
+ "remote_pid": remote_pid
401
+ }
402
+
403
+ except Exception as e:
404
+ return {
405
+ "status": "error",
406
+ "message": f"Worker start error: {str(e)}"
407
+ }
408
+
409
+ def get_executor_addresses(self) -> dict[str, str]:
410
+ """
411
+ Get ZMQ addresses for executor to connect to tunneled ports.
412
+
413
+ returns:
414
+ dict with cmd_addr and pub_addr
415
+ """
416
+ return {
417
+ "cmd_addr": f"tcp://127.0.0.1:{self.local_cmd_port}",
418
+ "pub_addr": f"tcp://127.0.0.1:{self.local_pub_port}"
419
+ }
420
+
421
+ def attach_executor(self, executor: "NextZmqExecutor") -> None:
422
+ """
423
+ Attach an executor instance to this pod manager.
424
+
425
+ args:
426
+ executor: The executor to attach
427
+ """
428
+ self.executor = executor
429
+
430
+ async def disconnect(self) -> dict[str, object]:
431
+ """
432
+ Disconnect from pod and cleanup tunnels.
433
+
434
+ returns:
435
+ dict with disconnection status
436
+ """
437
+ messages = []
438
+ if self.ssh_tunnel_proc:
439
+ try:
440
+ self.ssh_tunnel_proc.terminate()
441
+ try:
442
+ self.ssh_tunnel_proc.wait(timeout=5)
443
+ except subprocess.TimeoutExpired:
444
+ self.ssh_tunnel_proc.kill()
445
+ messages.append("SSH tunnel closed")
446
+ except Exception as e:
447
+ messages.append(f"Error closing tunnel: {str(e)}")
448
+ finally:
449
+ self.ssh_tunnel_proc = None
450
+
451
+ # Note: We don't kill remote worker as it may be used by other connections
452
+ # The pod itself should clean up when terminated
453
+
454
+ self.pod = None
455
+
456
+ return {
457
+ "status": "ok",
458
+ "messages": messages
459
+ }
460
+
461
+ async def get_status(self) -> dict[str, object]:
462
+ """
463
+ Get current connection status.
464
+
465
+ returns:
466
+ dict with status information
467
+ """
468
+ if not self.pod:
469
+ return {
470
+ "connected": False,
471
+ "pod": None
472
+ }
473
+
474
+ # Check tunnel status
475
+ tunnel_alive = False
476
+ if self.ssh_tunnel_proc:
477
+ tunnel_alive = self.ssh_tunnel_proc.poll() is None
478
+
479
+ # Get updated pod info
480
+ try:
481
+ updated_pod = await self.pi_service.get_pod(self.pod.id)
482
+ pod_status = updated_pod.status
483
+ except Exception:
484
+ pod_status = "unknown"
485
+
486
+ return {
487
+ "connected": True,
488
+ "pod": {
489
+ "id": self.pod.id,
490
+ "name": self.pod.name,
491
+ "status": pod_status,
492
+ "gpu_type": self.pod.gpuName,
493
+ "gpu_count": self.pod.gpuCount,
494
+ "price_hr": self.pod.priceHr,
495
+ "ssh_connection": self.pod.sshConnection
496
+ },
497
+ "tunnel": {
498
+ "alive": tunnel_alive,
499
+ "local_cmd_port": self.local_cmd_port,
500
+ "local_pub_port": self.local_pub_port
501
+ },
502
+ "executor_attached": self.executor is not None
503
+ }