podstack 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
podstack/gpu_runner.py ADDED
@@ -0,0 +1,1141 @@
1
+ """
2
+ Podstack GPU Runner
3
+
4
+ Handles actual GPU provisioning and remote code execution.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import inspect
11
+ import textwrap
12
+ import logging
13
+ import threading
14
+ from typing import Optional, Dict, Any, Callable, Union, List, Iterator
15
+ from dataclasses import dataclass, field
16
+ import httpx
17
+
18
# Module-level logger, registered under the fixed name "podstack.gpu_runner".
# No handlers or levels are configured here; that is left to the application.
logger = logging.getLogger("podstack.gpu_runner")
20
+
21
+
22
def is_jupyter() -> bool:
    """Report whether the interpreter is hosted by a Jupyter kernel.

    Returns:
        True only when IPython is importable and the active shell class is
        named ``ZMQInteractiveShell`` (Jupyter notebook or qtconsole).
        Terminal IPython, a plain interpreter, or a missing IPython
        installation all yield False.
    """
    try:
        from IPython import get_ipython
        active_shell = get_ipython()
    except (ImportError, NameError):
        # IPython is not installed (or not usable): definitely not Jupyter.
        return False

    if active_shell is None:
        return False

    # Every other shell class, including 'TerminalInteractiveShell', is
    # treated as "not a notebook".
    return active_shell.__class__.__name__ == 'ZMQInteractiveShell'
36
+
37
+
38
class OutputStreamer:
    """Handles real-time output streaming from GPU executions.

    The streamer connects to the execution's Server-Sent Events endpoint,
    decodes each ``data:`` line as JSON, remembers the output and the most
    recent status event, and optionally echoes output to the local
    stdout/stderr.
    """

    # Event types whose "content" field carries program output.
    _OUTPUT_TYPES = ("stdout", "stderr", "output")

    def __init__(self, execution_id: str, api_url: str, headers: Dict[str, str], timeout: float = 30.0):
        """
        Args:
            execution_id: ID of the execution whose output is streamed.
            api_url: Base URL of the API (no trailing path).
            headers: HTTP headers sent with the stream request.
            timeout: Timeout in seconds for everything except reads; reads
                are unbounded because the stream may stay silent for long.
        """
        self.execution_id = execution_id
        self.api_url = api_url
        self.headers = headers
        self.timeout = timeout
        self._output_buffer: List[str] = []
        self._status = "running"
        self._final_result: Dict[str, Any] = {}
        self._jupyter_widget = None

    def stream(self, show_output: bool = True) -> Iterator[Dict[str, Any]]:
        """
        Stream output events from the execution.

        Each event is recorded (and displayed, if requested) *before* it is
        yielded, so the streamer's state is complete even when the consumer
        stops iterating on that very event.

        Args:
            show_output: Echo output events to stdout/stderr as they arrive.

        Yields:
            Dict with event data (type, content, timestamp, etc.)

        Raises:
            RuntimeError: The endpoint answered with a non-200 status.
            ConnectionError: The connection could not be established.
            TimeoutError: The connection attempt timed out.
        """
        url = f"{self.api_url}/api/v1/executions/{self.execution_id}/stream"

        try:
            with httpx.Client(timeout=httpx.Timeout(self.timeout, read=None)) as client:
                with client.stream("GET", url, headers=self.headers) as response:
                    if response.status_code != 200:
                        raise RuntimeError(f"Failed to connect to stream: HTTP {response.status_code}")

                    for line in response.iter_lines():
                        event = self._parse_sse_line(line)
                        if event is None:
                            continue
                        self._record_event(event, show_output)
                        yield event

        except httpx.ConnectError as e:
            raise ConnectionError(f"Failed to connect to stream: {e}") from e
        except httpx.TimeoutException as e:
            raise TimeoutError("Stream connection timed out") from e

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[Dict[str, Any]]:
        """Decode one SSE line into an event dict.

        Returns None for anything that is not a usable event: blank lines,
        non-``data`` fields (comments, ``event:`` lines), malformed JSON and
        JSON payloads that are not objects.
        """
        import json  # local import: only this helper needs it

        if not line or not line.startswith("data:"):
            return None

        # SSE allows "data:{...}" as well as "data: {...}"; drop at most one
        # leading space after the colon.
        payload = line[len("data:"):]
        if payload.startswith(" "):
            payload = payload[1:]

        try:
            event = json.loads(payload)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse SSE event: {e}")
            return None

        if not isinstance(event, dict):
            logger.warning(f"Ignoring non-object SSE event: {event!r}")
            return None
        return event

    def _record_event(self, event: Dict[str, Any], show_output: bool) -> None:
        """Update status/output bookkeeping for one event and echo output."""
        if "status" in event:
            self._status = event["status"]
            self._final_result = event

        if event.get("type") in self._OUTPUT_TYPES:
            content = event.get("content", "")
            if content:
                # Buffer regardless of show_output so get_full_output() is
                # complete even for silent streams.
                self._output_buffer.append(content)
                if show_output:
                    self._display_output(content, event.get("type", "stdout"))

    def _display_output(self, content: str, output_type: str = "stdout"):
        """Write content to stderr for "stderr" events, otherwise to stdout.

        Jupyter and terminal sessions are handled identically: writing to
        the standard streams makes the text appear in the cell or console.
        """
        target = sys.stderr if output_type == "stderr" else sys.stdout
        target.write(content)
        target.flush()

    def get_full_output(self) -> str:
        """Get the complete buffered output."""
        return "".join(self._output_buffer)

    def get_final_result(self) -> Dict[str, Any]:
        """Get the most recent status event ({} if none was received)."""
        return self._final_result
124
+
125
+
126
class PodstackError(Exception):
    """Root of the Podstack SDK exception hierarchy.

    Catch this type to handle any error the SDK raises through its own
    exception classes.
    """
129
+
130
+
131
class PodstackTimeoutError(PodstackError):
    """Signals that a GPU execution exceeded its time budget.

    Attributes:
        execution_id: ID of the execution that timed out.
        timeout: The time limit, in seconds, that was exceeded.
        last_status: Last status observed before giving up.
        message: Text of the error (custom or generated).
    """

    def __init__(self, execution_id: str, timeout: int, last_status: str, message: str = None):
        # An empty or missing message falls back to a generated summary.
        if not message:
            message = f"Execution {execution_id} timed out after {timeout}s (last status: {last_status})"
        super().__init__(message)
        self.execution_id = execution_id
        self.timeout = timeout
        self.last_status = last_status
        self.message = message
139
+
140
+
141
class PodstackExecutionError(PodstackError):
    """Signals that a GPU execution ended in failure.

    Attributes:
        execution_id: ID of the failed execution.
        error: Error description.
        output: Output captured from the execution, if any.
    """

    def __init__(self, execution_id: str, error: str, output: str = None):
        pieces = [f"Execution {execution_id} failed: {error}"]
        if output:
            # Only the first 1000 characters of output go into the message;
            # the full text stays available on the `output` attribute.
            pieces.append(f"\n\nOutput:\n{output[:1000]}")
        super().__init__("".join(pieces))
        self.execution_id = execution_id
        self.error = error
        self.output = output
151
+
152
+
153
class PodstackProvisioningError(PodstackError):
    """Signals that a GPU could not be provisioned, or not in time.

    Attributes:
        execution_id: ID of the execution that was waiting for a GPU.
        message: Human-readable explanation.
    """

    def __init__(self, execution_id: str, message: str):
        super().__init__(message)
        self.execution_id = execution_id
        self.message = message
159
+
160
+
161
@dataclass
class GPUExecutionResult:
    """Result of a GPU execution."""
    execution_id: str
    status: str
    output: str = ""
    error: Optional[str] = None
    gpu_seconds: float = 0.0
    cost_paise: int = 0  # filled from the "actual_cost_paise" response key
    gpu_type: str = ""
    gpu_count: int = 1
    success: bool = False  # True only when status == "completed"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GPUExecutionResult":
        """Build a result from an API response dict.

        Keys that are missing *or explicitly null* fall back to the field
        defaults, so a response such as ``{"gpu_seconds": null}`` cannot
        leave ``None`` in a numeric or string field.

        Args:
            data: Decoded JSON object describing the execution.

        Returns:
            A populated GPUExecutionResult; ``success`` is derived from the
            status being exactly "completed".
        """
        def pick(key: str, default: Any) -> Any:
            value = data.get(key)
            return default if value is None else value

        status = pick("status", "unknown")
        return cls(
            execution_id=pick("execution_id", ""),
            status=status,
            output=pick("output", ""),
            error=data.get("error"),
            gpu_seconds=pick("gpu_seconds", 0.0),
            cost_paise=pick("actual_cost_paise", 0),
            gpu_type=pick("gpu_type", ""),
            gpu_count=pick("gpu_count", 1),
            success=status == "completed",
        )
187
+
188
+
189
+ class GPURunner:
190
+ """
191
+ GPU Runner for executing code on remote GPU instances.
192
+
193
+ Usage:
194
+ runner = GPURunner(api_key="...", project_id="...")
195
+
196
+ # Execute code on GPU
197
+ result = runner.run('''
198
+ import torch
199
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
200
+ ''', gpu="L40S", fraction=100)
201
+
202
+ print(result.output)
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ api_key: str = None,
208
+ project_id: str = None,
209
+ api_url: str = None,
210
+ timeout: float = 30.0
211
+ ):
212
+ """
213
+ Initialize the GPU Runner.
214
+
215
+ Args:
216
+ api_key: API key or JWT token for authentication
217
+ project_id: Project ID for billing and tracking
218
+ api_url: Notebook service API URL
219
+ timeout: HTTP request timeout in seconds
220
+ """
221
+ self.api_key = api_key or os.environ.get("PODSTACK_API_KEY")
222
+ self.project_id = project_id or os.environ.get("PODSTACK_PROJECT_ID")
223
+
224
+ # Determine API URL - prefer internal cluster URL if running inside K8s
225
+ if api_url:
226
+ self.api_url = api_url
227
+ elif os.environ.get("PODSTACK_NOTEBOOK_SERVICE_URL"):
228
+ self.api_url = os.environ.get("PODSTACK_NOTEBOOK_SERVICE_URL")
229
+ elif os.environ.get("KUBERNETES_SERVICE_HOST"):
230
+ # Running inside K8s cluster - use internal service URL
231
+ self.api_url = "http://podstack-svc-podstack-services-notebook.podstack.svc.cluster.local:8084"
232
+ else:
233
+ # External access
234
+ self.api_url = "https://cloud.podstack.ai/notebooks"
235
+ self.timeout = timeout
236
+
237
+ if not self.api_key:
238
+ raise ValueError("API key is required. Set PODSTACK_API_KEY or pass api_key.")
239
+ if not self.project_id:
240
+ raise ValueError("Project ID is required. Set PODSTACK_PROJECT_ID or pass project_id.")
241
+
242
+ # Debug info
243
+ if os.environ.get("PODSTACK_DEBUG"):
244
+ print(f"[Podstack Debug] API URL: {self.api_url}")
245
+ print(f"[Podstack Debug] Project ID: {self.project_id}")
246
+
247
+ def _get_headers(self) -> Dict[str, str]:
248
+ """Get request headers."""
249
+ return {
250
+ "Authorization": f"Bearer {self.api_key}",
251
+ "Content-Type": "application/json",
252
+ "User-Agent": "podstack-python-sdk/1.0.3"
253
+ }
254
+
255
+ def _build_annotation(
256
+ self,
257
+ gpu: str = "L40S",
258
+ count: int = 1,
259
+ fraction: int = 100,
260
+ timeout: int = 3600,
261
+ env: str = None,
262
+ pip: Union[str, list] = None,
263
+ uv: Union[str, list] = None,
264
+ conda: Union[str, list] = None,
265
+ requirements: str = None,
266
+ use_uv: bool = False
267
+ ) -> str:
268
+ """Build the @podstack annotation string."""
269
+ parts = [f"#@podstack gpu={gpu}"]
270
+ if count > 1:
271
+ parts.append(f"count={count}")
272
+ if fraction != 100:
273
+ parts.append(f"fraction={fraction}")
274
+ if timeout != 3600:
275
+ parts.append(f"timeout={timeout}")
276
+ if env:
277
+ parts.append(f"env={env}")
278
+
279
+ # Handle pip packages
280
+ if pip:
281
+ if isinstance(pip, list):
282
+ pip = ",".join(pip)
283
+ parts.append(f"pip={pip}")
284
+
285
+ # Handle uv packages
286
+ if uv:
287
+ if isinstance(uv, list):
288
+ uv = ",".join(uv)
289
+ parts.append(f"uv={uv}")
290
+
291
+ # Handle conda packages
292
+ if conda:
293
+ if isinstance(conda, list):
294
+ conda = ",".join(conda)
295
+ parts.append(f"conda={conda}")
296
+
297
+ # Handle requirements.txt
298
+ if requirements:
299
+ parts.append(f"requirements={requirements}")
300
+
301
+ # Use uv for pip installs
302
+ if use_uv:
303
+ parts.append("use_uv=true")
304
+
305
+ return " ".join(parts)
306
+
307
+ def submit(
308
+ self,
309
+ code: str,
310
+ gpu: str = "L40S",
311
+ count: int = 1,
312
+ fraction: int = 100,
313
+ timeout: int = 3600,
314
+ env: str = None,
315
+ pip: Union[str, list] = None,
316
+ uv: Union[str, list] = None,
317
+ conda: Union[str, list] = None,
318
+ requirements: str = None,
319
+ use_uv: bool = False,
320
+ add_annotation: bool = True
321
+ ) -> Dict[str, Any]:
322
+ """
323
+ Submit code for GPU execution (non-blocking).
324
+
325
+ Args:
326
+ code: Python code to execute
327
+ gpu: GPU type (L40S, A100-40G, A100-80G, H100, A10, T4)
328
+ count: Number of GPUs (1-8)
329
+ fraction: GPU time-slice percentage (25, 50, 75, 100)
330
+ timeout: Maximum execution time in seconds
331
+ env: Environment preset (ml, nlp, cv, audio, tabular, rl, scientific)
332
+ pip: Pip packages to install (string "pkg1,pkg2" or list ["pkg1", "pkg2"])
333
+ uv: UV packages to install (faster than pip)
334
+ conda: Conda packages to install (string or list)
335
+ requirements: Path to requirements.txt file to install
336
+ use_uv: Use uv instead of pip for installation (faster)
337
+ add_annotation: Whether to add @podstack annotation to code
338
+
339
+ Returns:
340
+ Dict with execution_id and status
341
+ """
342
+ # Read requirements.txt if provided
343
+ requirements_content = None
344
+ if requirements and os.path.exists(requirements):
345
+ with open(requirements, 'r') as f:
346
+ requirements_content = f.read()
347
+
348
+ # Add annotation if not present
349
+ if add_annotation and not code.strip().startswith("#@podstack"):
350
+ annotation = self._build_annotation(gpu, count, fraction, timeout, env, pip, uv, conda, requirements, use_uv)
351
+ code = f"{annotation}\n\n{code}"
352
+
353
+ # Build installation code for packages
354
+ install_parts = []
355
+
356
+ # Helper function for streaming subprocess output
357
+ install_helper = '''
358
+ # Helper function to stream install output with progress
359
+ import subprocess
360
+ import sys
361
+
362
+ def _stream_install(cmd, description):
363
+ """Run command and stream output in real-time."""
364
+ print(f"\\n{'='*60}")
365
+ print(f"[Podstack] {description}")
366
+ print(f"{'='*60}")
367
+ print(f"$ {' '.join(cmd)}\\n")
368
+ sys.stdout.flush()
369
+
370
+ process = subprocess.Popen(
371
+ cmd,
372
+ stdout=subprocess.PIPE,
373
+ stderr=subprocess.STDOUT,
374
+ text=True,
375
+ bufsize=1
376
+ )
377
+
378
+ for line in process.stdout:
379
+ print(line, end='')
380
+ sys.stdout.flush()
381
+
382
+ process.wait()
383
+ if process.returncode != 0:
384
+ raise subprocess.CalledProcessError(process.returncode, cmd)
385
+
386
+ print(f"\\n[Podstack] ✓ {description} - Complete!")
387
+ print(f"{'='*60}\\n")
388
+ sys.stdout.flush()
389
+ '''
390
+
391
+ # Handle pip packages (use uv if use_uv=True)
392
+ if pip:
393
+ if isinstance(pip, list):
394
+ pip_list = pip
395
+ else:
396
+ pip_list = [p.strip() for p in pip.replace(",", " ").split() if p.strip()]
397
+ if use_uv:
398
+ install_parts.append(f'''
399
+ # Install pip packages using uv (faster)
400
+ _pip_packages = {repr(pip_list)}
401
+ _stream_install(
402
+ ["uv", "pip", "install", "--system", "--progress"] + _pip_packages,
403
+ f"Installing pip packages with uv: {{' '.join(_pip_packages)}}"
404
+ )
405
+ ''')
406
+ else:
407
+ install_parts.append(f'''
408
+ # Install pip packages
409
+ _pip_packages = {repr(pip_list)}
410
+ _stream_install(
411
+ [sys.executable, "-m", "pip", "install", "--progress-bar", "on"] + _pip_packages,
412
+ f"Installing pip packages: {{' '.join(_pip_packages)}}"
413
+ )
414
+ ''')
415
+
416
+ # Handle uv packages (always use uv)
417
+ if uv:
418
+ if isinstance(uv, list):
419
+ uv_list = uv
420
+ else:
421
+ uv_list = [u.strip() for u in uv.replace(",", " ").split() if u.strip()]
422
+ install_parts.append(f'''
423
+ # Install packages using uv (fast package manager)
424
+ _uv_packages = {repr(uv_list)}
425
+ _stream_install(
426
+ ["uv", "pip", "install", "--system", "--progress"] + _uv_packages,
427
+ f"Installing packages with uv: {{' '.join(_uv_packages)}}"
428
+ )
429
+ ''')
430
+
431
+ # Handle conda/mamba packages (using micromamba)
432
+ if conda:
433
+ if isinstance(conda, list):
434
+ conda_list = conda
435
+ else:
436
+ conda_list = [c.strip() for c in conda.replace(",", " ").split() if c.strip()]
437
+ install_parts.append(f'''
438
+ # Install conda packages using micromamba (fast conda alternative)
439
+ import shutil
440
+ _conda_packages = {repr(conda_list)}
441
+ # Use micromamba if available, fall back to mamba, then conda
442
+ if shutil.which("micromamba"):
443
+ _pkg_mgr = "micromamba"
444
+ elif shutil.which("mamba"):
445
+ _pkg_mgr = "mamba"
446
+ else:
447
+ _pkg_mgr = "conda"
448
+ _stream_install(
449
+ [_pkg_mgr, "install", "-y"] + _conda_packages,
450
+ f"Installing conda packages: {{' '.join(_conda_packages)}}"
451
+ )
452
+ ''')
453
+
454
+ # Handle requirements.txt (use uv if use_uv=True)
455
+ if requirements_content:
456
+ if use_uv:
457
+ install_parts.append(f'''
458
+ # Install from requirements.txt using uv (faster)
459
+ _requirements = """{requirements_content}"""
460
+
461
+ with open("/tmp/requirements.txt", "w") as f:
462
+ f.write(_requirements)
463
+
464
+ print("\\n[Podstack] requirements.txt contents:")
465
+ print(_requirements)
466
+
467
+ _stream_install(
468
+ ["uv", "pip", "install", "--system", "--progress", "-r", "/tmp/requirements.txt"],
469
+ "Installing from requirements.txt with uv"
470
+ )
471
+ ''')
472
+ else:
473
+ install_parts.append(f'''
474
+ # Install from requirements.txt
475
+ _requirements = """{requirements_content}"""
476
+
477
+ with open("/tmp/requirements.txt", "w") as f:
478
+ f.write(_requirements)
479
+
480
+ print("\\n[Podstack] requirements.txt contents:")
481
+ print(_requirements)
482
+
483
+ _stream_install(
484
+ [sys.executable, "-m", "pip", "install", "--progress-bar", "on", "-r", "/tmp/requirements.txt"],
485
+ "Installing from requirements.txt"
486
+ )
487
+ ''')
488
+
489
+ # Prepend all installation code with helper function
490
+ if install_parts:
491
+ install_code = install_helper + "\n".join(install_parts) + "\nprint('\\n[Podstack] ✓ All packages installed successfully!\\n')\n\n"
492
+ code = install_code + code
493
+
494
+ url = f"{self.api_url}/api/v1/executions/submit"
495
+
496
+ with httpx.Client(timeout=self.timeout) as client:
497
+ try:
498
+ response = client.post(
499
+ url,
500
+ headers=self._get_headers(),
501
+ json={
502
+ "code": code,
503
+ "project_id": self.project_id
504
+ }
505
+ )
506
+ except httpx.ConnectError as e:
507
+ raise ConnectionError(f"Failed to connect to {url}: {e}")
508
+ except httpx.TimeoutException:
509
+ raise TimeoutError(f"Request to {url} timed out")
510
+
511
+ if response.status_code == 401:
512
+ raise PermissionError("Authentication failed. Check your API key.")
513
+ elif response.status_code == 402:
514
+ raise ValueError("Insufficient balance. Please add funds to your wallet.")
515
+ elif response.status_code >= 400:
516
+ # Try to parse JSON error, fallback to text
517
+ try:
518
+ error_data = response.json()
519
+ error_msg = error_data.get("error", error_data.get("message", response.text))
520
+ except:
521
+ error_msg = response.text[:500] if response.text else f"HTTP {response.status_code}"
522
+ raise RuntimeError(f"Submission failed ({response.status_code}): {error_msg}")
523
+
524
+ # Parse successful response
525
+ try:
526
+ return response.json()
527
+ except:
528
+ raise RuntimeError(f"Invalid JSON response from {url}: {response.text[:200]}")
529
+
530
+ def status(self, execution_id: str) -> Dict[str, Any]:
531
+ """
532
+ Get execution status.
533
+
534
+ Args:
535
+ execution_id: The execution ID
536
+
537
+ Returns:
538
+ Dict with status, queue_position, etc.
539
+ """
540
+ url = f"{self.api_url}/api/v1/executions/{execution_id}/status"
541
+
542
+ with httpx.Client(timeout=self.timeout) as client:
543
+ try:
544
+ response = client.get(url, headers=self._get_headers())
545
+ except httpx.ConnectError as e:
546
+ raise ConnectionError(f"Failed to connect to {url}: {e}")
547
+
548
+ if response.status_code >= 400:
549
+ try:
550
+ error_msg = response.json().get("error", response.text)
551
+ except:
552
+ error_msg = response.text[:500] if response.text else f"HTTP {response.status_code}"
553
+ raise RuntimeError(f"Failed to get status: {error_msg}")
554
+
555
+ try:
556
+ return response.json()
557
+ except:
558
+ raise RuntimeError(f"Invalid JSON response: {response.text[:200]}")
559
+
560
+ def result(self, execution_id: str, timeout: float = None) -> GPUExecutionResult:
561
+ """
562
+ Get execution result (blocks until complete).
563
+
564
+ Args:
565
+ execution_id: The execution ID
566
+ timeout: Max time to wait in seconds
567
+
568
+ Returns:
569
+ GPUExecutionResult object
570
+ """
571
+ url = f"{self.api_url}/api/v1/executions/{execution_id}/result"
572
+ request_timeout = timeout or 3600
573
+
574
+ with httpx.Client(timeout=request_timeout) as client:
575
+ response = client.get(url, headers=self._get_headers())
576
+
577
+ if response.status_code >= 400:
578
+ raise RuntimeError(f"Failed to get result: {response.text}")
579
+
580
+ return GPUExecutionResult.from_dict(response.json())
581
+
582
+ def stream_output(self, execution_id: str, show_output: bool = True) -> Iterator[Dict[str, Any]]:
583
+ """
584
+ Stream real-time output from a running execution.
585
+
586
+ This connects to the Server-Sent Events (SSE) endpoint and yields
587
+ output events as they arrive. Ideal for Jupyter notebooks.
588
+
589
+ Args:
590
+ execution_id: The execution ID to stream
591
+ show_output: Whether to print output to stdout/stderr (default: True)
592
+
593
+ Yields:
594
+ Dict with event data:
595
+ - type: "stdout", "stderr", "output", or "status"
596
+ - content: The output content (for output events)
597
+ - status: Execution status (for status events)
598
+ - timestamp: Event timestamp
599
+ - execution_id: The execution ID
600
+
601
+ Example:
602
+ for event in runner.stream_output("exec_123"):
603
+ if event.get("status") == "completed":
604
+ print(f"Done! Cost: {event.get('cost_cents', 0)/100:.2f}")
605
+ """
606
+ streamer = OutputStreamer(
607
+ execution_id=execution_id,
608
+ api_url=self.api_url,
609
+ headers=self._get_headers(),
610
+ timeout=self.timeout
611
+ )
612
+ yield from streamer.stream(show_output=show_output)
613
+
614
+ def run(
615
+ self,
616
+ code: str,
617
+ gpu: str = "L40S",
618
+ count: int = 1,
619
+ fraction: int = 100,
620
+ timeout: int = 3600,
621
+ env: str = None,
622
+ pip: Union[str, list] = None,
623
+ uv: Union[str, list] = None,
624
+ conda: Union[str, list] = None,
625
+ requirements: str = None,
626
+ use_uv: bool = False,
627
+ wait: bool = True,
628
+ poll_interval: float = 2.0,
629
+ max_retries: int = 3,
630
+ provisioning_timeout: int = 300,
631
+ cancel_on_timeout: bool = True,
632
+ stream: bool = None
633
+ ) -> GPUExecutionResult:
634
+ """
635
+ Execute code on GPU and optionally wait for completion.
636
+
637
+ Args:
638
+ code: Python code to execute
639
+ gpu: GPU type (L40S, A100-40G, A100-80G, H100, A10, T4)
640
+ count: Number of GPUs (1-8)
641
+ fraction: GPU time-slice percentage (25, 50, 75, 100)
642
+ timeout: Maximum execution time in seconds
643
+ env: Environment preset (ml, nlp, cv, audio, tabular, rl, scientific)
644
+ pip: Pip packages to install (string "pkg1,pkg2" or list ["pkg1", "pkg2"])
645
+ uv: UV packages to install (faster than pip)
646
+ conda: Conda packages to install (string or list)
647
+ requirements: Path to requirements.txt file
648
+ use_uv: Use uv instead of pip for installation (faster)
649
+ wait: Whether to wait for completion
650
+ poll_interval: Seconds between status checks when waiting
651
+ max_retries: Max retries for transient network errors (default: 3)
652
+ provisioning_timeout: Max time to wait for GPU provisioning (default: 300s)
653
+ cancel_on_timeout: Whether to cancel execution on timeout (default: True)
654
+ stream: Stream output in real-time (default: True in Jupyter, False otherwise)
655
+
656
+ Returns:
657
+ GPUExecutionResult object
658
+
659
+ Raises:
660
+ PodstackTimeoutError: If execution times out
661
+ PodstackProvisioningError: If GPU provisioning fails or times out
662
+ PodstackExecutionError: If execution fails
663
+ ConnectionError: If unable to connect to API
664
+ ValueError: If parameters are invalid
665
+ """
666
+ # Submit the code
667
+ submission = self.submit(code, gpu, count, fraction, timeout, env, pip, uv, conda, requirements, use_uv)
668
+ execution_id = submission.get("execution_id")
669
+
670
+ if not execution_id:
671
+ raise RuntimeError(f"No execution_id in response: {submission}")
672
+
673
+ print(f"[Podstack] Execution submitted: {execution_id}")
674
+
675
+ if not wait:
676
+ return GPUExecutionResult(
677
+ execution_id=execution_id,
678
+ status="submitted",
679
+ gpu_type=gpu,
680
+ gpu_count=count
681
+ )
682
+
683
+ # Determine if we should stream output
684
+ should_stream = stream if stream is not None else is_jupyter()
685
+
686
+ if should_stream:
687
+ return self._run_with_streaming(execution_id, gpu, count, timeout, max_retries, cancel_on_timeout)
688
+ else:
689
+ return self._run_with_polling(execution_id, gpu, count, timeout, poll_interval, max_retries, provisioning_timeout, cancel_on_timeout)
690
+
691
+ def _run_with_streaming(
692
+ self,
693
+ execution_id: str,
694
+ gpu: str,
695
+ count: int,
696
+ timeout: int,
697
+ max_retries: int,
698
+ cancel_on_timeout: bool
699
+ ) -> GPUExecutionResult:
700
+ """Run execution with real-time output streaming."""
701
+ print(f"[Podstack] Waiting for GPU runner ({gpu} x{count})...")
702
+
703
+ start_time = time.time()
704
+ output_buffer = []
705
+ final_status = "unknown"
706
+ final_event = {}
707
+
708
+ try:
709
+ for event in self.stream_output(execution_id, show_output=True):
710
+ elapsed = time.time() - start_time
711
+ if elapsed > timeout:
712
+ if cancel_on_timeout:
713
+ try:
714
+ self.cancel(execution_id)
715
+ print(f"\n[Podstack] Execution cancelled due to timeout")
716
+ except Exception as e:
717
+ logger.warning(f"Failed to cancel execution: {e}")
718
+
719
+ raise PodstackTimeoutError(
720
+ execution_id=execution_id,
721
+ timeout=timeout,
722
+ last_status=final_status,
723
+ message=(
724
+ f"Execution {execution_id} timed out after {timeout}s.\n"
725
+ f"Last status: {final_status}\n"
726
+ f"Suggestions:\n"
727
+ f" - Increase timeout parameter (current: {timeout}s)\n"
728
+ f" - Check execution logs: runner.result('{execution_id}')"
729
+ )
730
+ )
731
+
732
+ # Track output
733
+ if event.get("type") in ("stdout", "stderr", "output"):
734
+ content = event.get("content", "")
735
+ if content:
736
+ output_buffer.append(content)
737
+
738
+ # Track status
739
+ if "status" in event:
740
+ new_status = event["status"]
741
+ if new_status != final_status:
742
+ final_status = new_status
743
+ if final_status == "provisioning":
744
+ print(f"\n[Podstack] Provisioning GPU runner...")
745
+ elif final_status == "running":
746
+ print(f"\n[Podstack] Running on GPU...")
747
+
748
+ # Check for terminal status
749
+ if final_status in ("completed", "failed", "timeout", "cancelled"):
750
+ final_event = event
751
+ break
752
+
753
+ except (ConnectionError, httpx.ConnectError) as e:
754
+ # Try to recover and get the result
755
+ logger.warning(f"Stream connection lost: {e}")
756
+ print(f"\n[Podstack] Stream connection lost, fetching final result...")
757
+
758
+ # Get final result
759
+ result = None
760
+ for attempt in range(max_retries):
761
+ try:
762
+ result = self.result(execution_id)
763
+ break
764
+ except (ConnectionError, httpx.ConnectError, httpx.TimeoutException) as e:
765
+ if attempt == max_retries - 1:
766
+ raise ConnectionError(f"Failed to get result after {max_retries} attempts: {e}")
767
+ time.sleep(2 * (attempt + 1))
768
+
769
+ if result.success:
770
+ print(f"\n[Podstack] Completed in {result.gpu_seconds:.1f}s (cost: ₹{result.cost_paise/100:.2f})")
771
+ else:
772
+ error_msg = result.error or 'Unknown error'
773
+ print(f"\n[Podstack] Failed: {error_msg}")
774
+
775
+ return result
776
+
777
+ def _run_with_polling(
778
+ self,
779
+ execution_id: str,
780
+ gpu: str,
781
+ count: int,
782
+ timeout: int,
783
+ poll_interval: float,
784
+ max_retries: int,
785
+ provisioning_timeout: int,
786
+ cancel_on_timeout: bool
787
+ ) -> GPUExecutionResult:
788
+ """Run execution with polling (non-streaming mode)."""
789
+ print(f"[Podstack] Waiting for GPU runner ({gpu} x{count})...")
790
+ start_time = time.time()
791
+ provisioning_start = None
792
+ last_status = ""
793
+ consecutive_errors = 0
794
+ last_error = None
795
+
796
+ while True:
797
+ elapsed = time.time() - start_time
798
+
799
+ # Check for overall timeout
800
+ if elapsed > timeout:
801
+ if cancel_on_timeout:
802
+ try:
803
+ self.cancel(execution_id)
804
+ print(f"[Podstack] Execution cancelled due to timeout")
805
+ except Exception as e:
806
+ logger.warning(f"Failed to cancel execution: {e}")
807
+
808
+ raise PodstackTimeoutError(
809
+ execution_id=execution_id,
810
+ timeout=timeout,
811
+ last_status=last_status,
812
+ message=(
813
+ f"Execution {execution_id} timed out after {timeout}s.\n"
814
+ f"Last status: {last_status}\n"
815
+ f"Suggestions:\n"
816
+ f" - Increase timeout parameter (current: {timeout}s)\n"
817
+ f" - Check execution logs: runner.result('{execution_id}')\n"
818
+ f" - Cancel and retry: runner.cancel('{execution_id}')"
819
+ )
820
+ )
821
+
822
+ # Try to get status with retry on transient errors
823
+ try:
824
+ status_data = self.status(execution_id)
825
+ consecutive_errors = 0 # Reset on success
826
+ except (ConnectionError, httpx.ConnectError, httpx.TimeoutException) as e:
827
+ consecutive_errors += 1
828
+ last_error = str(e)
829
+ if consecutive_errors >= max_retries:
830
+ raise ConnectionError(
831
+ f"Lost connection to API after {consecutive_errors} retries. "
832
+ f"Last error: {last_error}\n"
833
+ f"Execution ID: {execution_id}"
834
+ )
835
+ logger.warning(f"Transient error (attempt {consecutive_errors}/{max_retries}): {e}")
836
+ time.sleep(poll_interval * consecutive_errors) # Exponential backoff
837
+ continue
838
+ except Exception as e:
839
+ # For other errors, fail immediately
840
+ raise RuntimeError(f"Failed to get status for {execution_id}: {e}")
841
+
842
+ current_status = status_data.get("status", "unknown")
843
+
844
+ if current_status != last_status:
845
+ last_status = current_status
846
+ if current_status == "pending":
847
+ print(f"[Podstack] Pending...")
848
+ elif current_status == "queued":
849
+ pos = status_data.get("queue_position", "?")
850
+ print(f"[Podstack] Queued (position: {pos})")
851
+ elif current_status == "provisioning":
852
+ provisioning_start = time.time()
853
+ print(f"[Podstack] Provisioning GPU runner...")
854
+ elif current_status == "running":
855
+ print(f"[Podstack] Running on GPU...")
856
+ elif current_status == "streaming":
857
+ print(f"[Podstack] Streaming output...")
858
+ elif current_status in ("completed", "failed", "timeout", "cancelled"):
859
+ pass # Terminal states handled below
860
+ else:
861
+ # Unknown status - log but continue
862
+ logger.warning(f"Unknown status: {current_status}")
863
+ print(f"[Podstack] Status: {current_status}")
864
+
865
+ # Check for provisioning timeout
866
+ if provisioning_start and current_status == "provisioning":
867
+ provisioning_elapsed = time.time() - provisioning_start
868
+ if provisioning_elapsed > provisioning_timeout:
869
+ if cancel_on_timeout:
870
+ try:
871
+ self.cancel(execution_id)
872
+ except Exception:
873
+ pass
874
+ raise PodstackProvisioningError(
875
+ execution_id=execution_id,
876
+ message=(
877
+ f"GPU provisioning timed out after {provisioning_timeout}s for {execution_id}.\n"
878
+ f"This may indicate:\n"
879
+ f" - No available {gpu} GPUs at this time\n"
880
+ f" - Cluster resource constraints\n"
881
+ f" - Network issues in the cluster\n"
882
+ f"Suggestions:\n"
883
+ f" - Try a different GPU type\n"
884
+ f" - Retry after a few minutes\n"
885
+ f" - Contact support if the issue persists"
886
+ )
887
+ )
888
+
889
+ # Check for terminal states
890
+ if current_status in ("completed", "failed", "timeout", "cancelled"):
891
+ break
892
+
893
+ time.sleep(poll_interval)
894
+
895
+ # Get final result with retry
896
+ result = None
897
+ for attempt in range(max_retries):
898
+ try:
899
+ result = self.result(execution_id)
900
+ break
901
+ except (ConnectionError, httpx.ConnectError, httpx.TimeoutException) as e:
902
+ if attempt == max_retries - 1:
903
+ raise ConnectionError(f"Failed to get result after {max_retries} attempts: {e}")
904
+ time.sleep(poll_interval * (attempt + 1))
905
+
906
+ if result.success:
907
+ print(f"[Podstack] Completed in {result.gpu_seconds:.1f}s (cost: ₹{result.cost_paise/100:.2f})")
908
+ else:
909
+ error_msg = result.error or 'Unknown error'
910
+ print(f"[Podstack] Failed: {error_msg}")
911
+ # Include partial output in the error for debugging
912
+ if result.output:
913
+ print(f"[Podstack] Output (last 500 chars):\n{result.output[-500:]}")
914
+
915
+ return result
916
+
917
def cancel(self, execution_id: str) -> Dict[str, Any]:
    """
    Cancel a running execution.

    Sends a POST to ``/api/v1/executions/{execution_id}/cancel`` on the
    configured API URL.

    Args:
        execution_id: The execution ID to cancel

    Returns:
        The decoded JSON body of the response, or ``{"status": "cancelled"}``
        when the response body cannot be decoded as JSON.

    Raises:
        ConnectionError: The connection to the API could not be established.
        TimeoutError: The request timed out.
        RuntimeError: The API answered with an HTTP status >= 400.
    """
    url = f"{self.api_url}/api/v1/executions/{execution_id}/cancel"

    with httpx.Client(timeout=self.timeout) as client:
        try:
            response = client.post(url, headers=self._get_headers())
        except httpx.ConnectError as e:
            # Chain the cause so the underlying transport error stays visible.
            raise ConnectionError(f"Failed to connect to {url}: {e}") from e
        except httpx.TimeoutException as e:
            raise TimeoutError(f"Request to {url} timed out") from e

        if response.status_code >= 400:
            try:
                error_msg = response.json().get("error", response.text)
            except Exception:
                # Best effort: the body may not be JSON, or may be JSON that is
                # not an object. Exception (not a bare except) keeps
                # KeyboardInterrupt/SystemExit propagating.
                error_msg = response.text[:500] if response.text else f"HTTP {response.status_code}"
            raise RuntimeError(f"Failed to cancel execution: {error_msg}")

        try:
            return response.json()
        except Exception:
            # A successful status with an undecodable body is still treated
            # as a successful cancellation.
            return {"status": "cancelled"}
948
+
949
def run_function(
    self,
    func: Callable,
    gpu: str = "L40S",
    count: int = 1,
    fraction: int = 100,
    timeout: int = 3600,
    env: Optional[str] = None,
    pip: Optional[Union[str, list]] = None,
    **kwargs
) -> GPUExecutionResult:
    """
    Execute a function on remote GPU.

    The function's source text is fetched with ``inspect.getsource``,
    dedented, and followed by a generated call ``func(**kwargs)`` whose
    non-None return value is printed. The resulting script is submitted
    through ``self.run``.

    Args:
        func: Function to execute; its source must be retrievable and its
            ``__name__`` must be a valid identifier (so no lambdas)
        gpu: GPU type
        count: Number of GPUs
        fraction: GPU fraction percentage
        timeout: Max execution time
        env: Environment preset
        pip: Additional pip packages
        **kwargs: Arguments to pass to the function. Each value is embedded
            in the script via ``repr()``, so its repr must be valid Python
            source.

    Returns:
        GPUExecutionResult object (whatever ``self.run`` returns)

    Raises:
        ValueError: The source of ``func`` cannot be retrieved, ``func`` has
            no usable name, or a keyword argument cannot be written as
            Python source.
    """
    # Get function source code. getsource raises OSError when the source is
    # unavailable and TypeError for built-in objects.
    try:
        source = inspect.getsource(func)
    except (OSError, TypeError) as exc:
        raise ValueError(
            "Cannot get source code for function. Define it in a file, not interactively."
        ) from exc

    # Dedent if needed (e.g. functions defined inside another scope)
    source = textwrap.dedent(source)

    # The generated script calls the function by name, so the name has to be
    # usable as one; a lambda's name is "<lambda>".
    func_name = getattr(func, "__name__", "")
    if not func_name.isidentifier():
        raise ValueError(
            f"Cannot run {func_name or func!r} remotely: "
            "a named function (defined with 'def') is required."
        )

    # Serialize kwargs. Check locally that each repr() parses in the exact
    # call position it will occupy, instead of submitting a script that can
    # only fail with a SyntaxError after a GPU has been provisioned.
    rendered_args = []
    for arg_name, arg_value in kwargs.items():
        arg_source = repr(arg_value)
        try:
            compile(f"_f({arg_name}={arg_source})", "<kwarg>", "eval")
        except SyntaxError as exc:
            raise ValueError(
                f"Cannot serialize argument {arg_name!r} for remote execution: "
                f"repr() of {type(arg_value).__name__} is not valid Python source."
            ) from exc
        rendered_args.append(f"{arg_name}={arg_source}")
    kwargs_str = ", ".join(rendered_args)

    # Build the execution code line by line so the script text does not
    # depend on how this method itself is indented.
    code = "\n".join([
        "",
        source,
        "",
        "# Execute the function",
        f"__result__ = {func_name}({kwargs_str})",
        "if __result__ is not None:",
        "    print(__result__)",
        "",
    ])

    return self.run(
        code=code,
        gpu=gpu,
        count=count,
        fraction=fraction,
        timeout=timeout,
        env=env,
        pip=pip
    )
1009
+
1010
+
1011
# Global runner instance.
# Starts as None; init() stores a configured GPURunner here, and get_runner()
# fills it with a default-constructed GPURunner on first use if it is still None.
_runner: Optional[GPURunner] = None
1013
+
1014
+
1015
def init(
    api_key: str = None,
    project_id: str = None,
    api_url: str = None
):
    """
    Build a new GPURunner and install it as the module-wide runner.

    Whatever runner was stored before is replaced. The three settings are
    handed to ``GPURunner`` unchanged, so a value left as ``None`` is
    interpreted by ``GPURunner`` itself.

    Args:
        api_key: API key for authentication
        project_id: Project ID for billing
        api_url: Notebook service URL
    """
    global _runner
    runner_settings = {
        "api_key": api_key,
        "project_id": project_id,
        "api_url": api_url,
    }
    _runner = GPURunner(**runner_settings)
1034
+
1035
+
1036
def get_runner() -> GPURunner:
    """
    Return the module-wide GPU runner, creating it on first use.

    If no runner has been stored yet, a ``GPURunner`` built with no
    arguments is stored and returned; later calls return that same object.
    """
    global _runner
    if _runner is not None:
        return _runner
    _runner = GPURunner()
    return _runner
1042
+
1043
+
1044
def run(
    code: str,
    gpu: str = "L40S",
    count: int = 1,
    fraction: int = 100,
    timeout: int = 3600,
    env: str = None,
    pip: Union[str, list] = None,
    uv: Union[str, list] = None,
    conda: Union[str, list] = None,
    requirements: str = None,
    use_uv: bool = False,
    wait: bool = True,
    stream: bool = None
) -> GPUExecutionResult:
    """
    Execute code on remote GPU using the module-wide runner.

    This is a thin wrapper: every argument is forwarded by keyword,
    unchanged, to the ``run`` method of the runner returned by
    ``get_runner()``.

    Args:
        code: Python code to execute
        gpu: GPU type (L40S, A100-40G, A100-80G, H100, A10, T4)
        count: Number of GPUs (1-8)
        fraction: GPU fraction percentage (25, 50, 75, 100)
        timeout: Max execution time in seconds
        env: Environment preset (ml, nlp, cv, audio, tabular, rl, scientific)
        pip: Pip packages - string "pkg1,pkg2" or list ["pkg1", "pkg2"]
        uv: UV packages (faster than pip) - string or list
        conda: Conda packages - string or list
        requirements: Path to requirements.txt file
        use_uv: Use uv instead of pip for all installations (faster)
        wait: Whether to wait for completion
        stream: Stream output in real-time (default: True in Jupyter, False otherwise)

    Returns:
        GPUExecutionResult object

    Examples:
        # Install single package
        podstack.run_on_gpu(code, pip="transformers")

        # Install multiple packages
        podstack.run_on_gpu(code, pip=["torch", "transformers", "datasets"])

        # Install with uv (faster)
        podstack.run_on_gpu(code, uv=["torch", "transformers"])

        # Use uv for all pip installs (faster)
        podstack.run_on_gpu(code, pip=["torch"], use_uv=True)

        # Install from requirements.txt
        podstack.run_on_gpu(code, requirements="requirements.txt")

        # Install from requirements.txt using uv (faster)
        podstack.run_on_gpu(code, requirements="requirements.txt", use_uv=True)

        # Combine pip and conda
        podstack.run_on_gpu(code, pip="transformers", conda="cudatoolkit=11.8")

        # Force streaming in non-Jupyter environment
        podstack.run(code, stream=True)
    """
    runner = get_runner()

    # Gather the hardware request, the package-installation settings and the
    # completion behaviour into one keyword set for the runner.
    submission = {
        "code": code,
        "gpu": gpu,
        "count": count,
        "fraction": fraction,
        "timeout": timeout,
        "env": env,
        "pip": pip,
        "uv": uv,
        "conda": conda,
        "requirements": requirements,
        "use_uv": use_uv,
        "wait": wait,
        "stream": stream,
    }
    return runner.run(**submission)
1120
+
1121
+
1122
def stream_output(execution_id: str, show_output: bool = True) -> Iterator[Dict[str, Any]]:
    """
    Stream real-time output from a running execution.

    Delegates to the ``stream_output`` method of the module-wide runner.
    This connects to the Server-Sent Events (SSE) endpoint and yields
    output events as they arrive. Ideal for Jupyter notebooks.

    Args:
        execution_id: The execution ID to stream
        show_output: Whether to print output to stdout/stderr (default: True)

    Yields:
        Dict with event data

    Example:
        for event in podstack.stream_output("exec_123"):
            if event.get("status") == "completed":
                print(f"Done!")
    """
    runner = get_runner()
    event_stream = runner.stream_output(execution_id, show_output=show_output)
    return event_stream