podstack 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- podstack/__init__.py +222 -0
- podstack/annotations.py +725 -0
- podstack/client.py +322 -0
- podstack/exceptions.py +125 -0
- podstack/execution.py +291 -0
- podstack/gpu_runner.py +1141 -0
- podstack/models.py +274 -0
- podstack/notebook.py +410 -0
- podstack/registry/__init__.py +402 -0
- podstack/registry/client.py +957 -0
- podstack/registry/exceptions.py +107 -0
- podstack/registry/experiment.py +227 -0
- podstack/registry/model.py +273 -0
- podstack/registry/model_utils.py +231 -0
- podstack-1.2.0.dist-info/METADATA +299 -0
- podstack-1.2.0.dist-info/RECORD +27 -0
- podstack-1.2.0.dist-info/WHEEL +5 -0
- podstack-1.2.0.dist-info/licenses/LICENSE +21 -0
- podstack-1.2.0.dist-info/top_level.txt +2 -0
- podstack_gpu/__init__.py +126 -0
- podstack_gpu/app.py +675 -0
- podstack_gpu/exceptions.py +35 -0
- podstack_gpu/image.py +325 -0
- podstack_gpu/runner.py +746 -0
- podstack_gpu/secret.py +189 -0
- podstack_gpu/utils.py +203 -0
- podstack_gpu/volume.py +198 -0
podstack_gpu/runner.py
ADDED
|
@@ -0,0 +1,746 @@
|
|
|
1
|
+
"""GPU Runner - Main interface for running code on cloud GPUs."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import json
|
|
6
|
+
import threading
|
|
7
|
+
from typing import Optional, Dict, Any, Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
import requests
|
|
11
|
+
|
|
12
|
+
from .exceptions import (
|
|
13
|
+
PodstackError,
|
|
14
|
+
AuthenticationError,
|
|
15
|
+
ValidationError,
|
|
16
|
+
ExecutionError,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class StatusUpdate:
    """A single status event describing an execution's progress.

    Only ``execution_id``, ``status`` and ``timestamp`` must be supplied;
    every other attribute is optional and defaults to ``None``. Instances
    are usually produced by :meth:`from_dict` from a decoded JSON object.
    """

    execution_id: str
    status: str
    timestamp: str
    gpu_type: Optional[str] = None
    gpu_count: Optional[int] = None
    gpu_fraction: Optional[int] = None
    queue_position: Optional[int] = None
    estimated_wait_seconds: Optional[int] = None
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    gpu_seconds: Optional[float] = None
    cost_cents: Optional[int] = None
    error: Optional[str] = None
    progress: Optional[float] = None
    message: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StatusUpdate":
        """Construct a StatusUpdate from a mapping of field names to values.

        A missing ``execution_id`` or ``timestamp`` becomes ``""`` and a
        missing ``status`` becomes ``"unknown"``. Any other missing key
        becomes ``None``. Keys that are not fields are ignored.
        """
        optional_keys = (
            "gpu_type",
            "gpu_count",
            "gpu_fraction",
            "queue_position",
            "estimated_wait_seconds",
            "started_at",
            "completed_at",
            "gpu_seconds",
            "cost_cents",
            "error",
            "progress",
            "message",
        )
        extras = {key: data.get(key) for key in optional_keys}
        return cls(
            execution_id=data.get("execution_id", ""),
            status=data.get("status", "unknown"),
            timestamp=data.get("timestamp", ""),
            **extras,
        )

    @property
    def is_terminal(self) -> bool:
        """True when ``status`` is completed, failed, timeout or cancelled."""
        terminal_states = ("completed", "failed", "timeout", "cancelled")
        return self.status in terminal_states
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class ExecutionResult:
    """Final (or submission-time) outcome of a GPU execution.

    ``cost_cents`` is the charge in the smallest currency unit (paise);
    :attr:`cost_inr` converts it to rupees.
    """

    execution_id: str
    status: str
    output: str
    gpu_seconds: float
    cost_cents: int
    error: Optional[str] = None

    @property
    def success(self) -> bool:
        """True only when the status is exactly ``"completed"``."""
        return "completed" == self.status

    @property
    def cost_inr(self) -> float:
        """Cost in INR (paise converted to rupees); 0.0 when no cost is recorded."""
        if not self.cost_cents:
            return 0.0
        return self.cost_cents / 100

    @property
    def cost_dollars(self) -> float:
        """Deprecated alias of :attr:`cost_inr`.

        Despite its name this returns the INR amount, identical to
        ``cost_inr``; it is kept only for backwards compatibility.
        """
        rupees = self.cost_inr
        return rupees
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class GPURunner:
    """
    Run Python code on cloud GPUs.

    Code is submitted to the executions API with a ``#@podstack`` annotation
    line prepended that carries the requested GPU configuration. Completion is
    followed over the server-sent-events status stream, with periodic polling
    as a fallback when the stream cannot be used.

    Example:
        gpu = GPURunner(token, project_id)
        result = gpu.run('''
    import torch
    print(torch.cuda.is_available())
    ''', gpu="L40S", fraction=25)
    """

    DEFAULT_API_URL = "https://cloud.podstack.ai/notebooks"

    # Statuses the wait loops treat as final.
    _TERMINAL_STATUSES = ("completed", "failed", "timeout", "cancelled")

    # Extra seconds allowed on top of the execution timeout before the client
    # stops waiting (covers queueing/provisioning time).
    _PROVISIONING_GRACE_SECONDS = 60

    # Timeout for the short, non-streaming status/result requests issued while
    # waiting, so a single hung request cannot stall the wait deadline.
    _REQUEST_TIMEOUT_SECONDS = 30

    # Human-readable text for known statuses. "queued" is formatted separately
    # in _describe_status because it includes the queue position.
    _STATUS_MESSAGES = {
        "pending": "Status: pending",
        "provisioning": "Provisioning GPU...",
        "running": "Running on GPU...",
        "completed": "Completed!",
        "failed": "Failed!",
        "timeout": "Timed out!",
        "cancelled": "Cancelled",
    }

    def __init__(
        self,
        token: str,
        project_id: str,
        api_url: Optional[str] = None,
        timeout: int = 300,
    ):
        """
        Initialize GPU Runner.

        Args:
            token: API token (psk_xxx for platform tokens, pjt_xxx for project tokens)
            project_id: Project UUID
            api_url: API base URL (optional, defaults to PODSTACK_API_URL env var or production URL)
            timeout: Default execution timeout in seconds
        """
        self.token = token
        self.project_id = project_id
        # Priority: explicit api_url > PODSTACK_API_URL env var > default URL
        self.api_url = (api_url or os.environ.get("PODSTACK_API_URL") or self.DEFAULT_API_URL).rstrip("/")
        self.default_timeout = timeout

        # A single session reuses connections and carries the auth header on every call.
        self._session = requests.Session()
        self._session.headers.update({
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        })

    def run(
        self,
        code: str,
        gpu: str = "L40S",
        count: int = 1,
        fraction: int = 100,
        timeout: Optional[int] = None,
        memory: Optional[int] = None,
        env: Optional[str] = None,
        pip: Optional[list] = None,
        wait: bool = True,
        poll_interval: float = 2.0,
    ) -> ExecutionResult:
        """
        Run code on a cloud GPU.

        Args:
            code: Python code to execute
            gpu: GPU type (A10, L40, L40S, A100-40G, A100-80G, H100)
            count: Number of GPUs (1-8)
            fraction: GPU fraction percentage (25, 50, 75, 100)
            timeout: Max execution time in seconds (falsy values use the runner default)
            memory: GPU memory limit in GB (optional)
            env: Environment preset (ml, cv, nlp, audio, tabular, rl, scientific)
            pip: List of pip packages to install
            wait: If True, wait for completion. If False, return immediately.
            poll_interval: Seconds between status checks when waiting

        Returns:
            ExecutionResult with output and status. With ``wait=False`` the
            result has status ``"submitted"`` and no output.

        Raises:
            AuthenticationError: token rejected (401) or no project access (403).
            ValidationError: the API rejected the request (400).
            ExecutionError: the execution failed, or waiting timed out.
            PodstackError: other API or connection errors.
        """
        timeout = timeout or self.default_timeout

        annotation = self._build_annotation(
            gpu=gpu,
            count=count,
            fraction=fraction,
            timeout=timeout,
            memory=memory,
            env=env,
            pip=pip,
        )

        # The server reads the GPU configuration from the annotation line.
        full_code = f"{annotation}\n{code}"

        execution_id = self._submit(full_code)

        print(f"Submitted execution: {execution_id}")
        print(f" GPU: {gpu} @ {fraction}%")

        if not wait:
            return ExecutionResult(
                execution_id=execution_id,
                status="submitted",
                output="",
                gpu_seconds=0,
                cost_cents=0,
            )

        return self._wait_for_completion(execution_id, poll_interval, timeout)

    def _build_annotation(
        self,
        gpu: str,
        count: int,
        fraction: int,
        timeout: int,
        memory: Optional[int] = None,
        env: Optional[str] = None,
        pip: Optional[list] = None,
    ) -> str:
        """Build the ``#@podstack`` annotation comment; falsy optional values are omitted."""
        parts = [f"#@podstack gpu={gpu} count={count} fraction={fraction} timeout={timeout}"]

        if memory:
            parts.append(f"memory={memory}")
        if env:
            parts.append(f"env={env}")
        if pip:
            parts.append(f"pip={','.join(pip)}")

        return " ".join(parts)

    @staticmethod
    def _error_detail(response, default: str):
        """Return the ``detail`` field of a JSON error body, or ``default`` if unavailable."""
        try:
            data = response.json()
        except ValueError:  # body is not JSON
            return default
        if isinstance(data, dict):
            return data.get("detail", default)
        return default

    def _submit(self, code: str) -> str:
        """Submit code for execution and return the new execution ID.

        Raises:
            AuthenticationError: on 401/403.
            ValidationError: on 400.
            PodstackError: on connection failure, any other non-202 status,
                or a 202 body without an ``execution_id``.
        """
        url = f"{self.api_url}/api/v1/executions/submit"

        try:
            # NOTE(review): the code travels in the query string, which can hit
            # server/proxy URL-length limits for large scripts. Left unchanged
            # because the server's accepted request format is not visible here.
            response = self._session.post(
                url,
                params={
                    "code": code,
                    "project_id": self.project_id,
                },
            )
        except requests.RequestException as e:
            raise PodstackError(f"Failed to connect to API: {e}") from e

        if response.status_code == 401:
            raise AuthenticationError("Invalid or expired token")
        if response.status_code == 403:
            raise AuthenticationError("No access to this project")
        if response.status_code == 400:
            raise ValidationError(self._error_detail(response, "Invalid request"))
        if response.status_code != 202:
            raise PodstackError(f"API error: {response.status_code} - {response.text}")

        try:
            return response.json()["execution_id"]
        except (ValueError, KeyError, TypeError) as e:
            raise PodstackError(f"Unexpected submit response: {response.text}") from e

    def _wait_for_completion(
        self,
        execution_id: str,
        poll_interval: float,
        timeout: int,
    ) -> ExecutionResult:
        """Wait for completion via SSE, falling back to polling if the stream is unusable.

        An ``ExecutionError`` from the SSE path (execution failed or timed
        out) is re-raised as-is. Falling back would only wait again and report
        the same outcome twice.
        """
        try:
            return self._wait_with_sse(execution_id, timeout)
        except ExecutionError:
            raise
        except Exception as e:
            print(f" SSE streaming unavailable ({e}), falling back to polling...")
            return self._wait_with_polling(execution_id, poll_interval, timeout)

    @staticmethod
    def _iter_sse_data(response):
        """Yield each JSON object carried on a ``data:`` line of an SSE response.

        Blank lines, ``:`` comment lines (heartbeats) and other SSE fields are
        skipped, as are payloads that are not valid JSON objects.
        """
        # SSE streams are UTF-8 by definition; without a charset in the
        # Content-Type, requests would decode text/* bodies as ISO-8859-1.
        response.encoding = "utf-8"
        for raw in response.iter_lines(decode_unicode=True):
            if not raw:
                continue
            line = raw.strip()
            if not line.startswith("data:"):
                continue
            try:
                payload = json.loads(line[5:].strip())
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict):
                yield payload

    def _fetch_result(self, result_url: str):
        """Best-effort fetch of the final result.

        Returns ``(output, result_data)``. ``output`` joins the ``content`` of
        every ``stdout``/``result`` output entry with newlines, and
        ``result_data`` is the decoded JSON object. On any failure (network
        error, non-200 status, malformed body) returns ``("", None)`` so the
        caller can still report the terminal status it already has.
        """
        try:
            response = self._session.get(result_url, timeout=self._REQUEST_TIMEOUT_SECONDS)
            if response.status_code != 200:
                return "", None
            result_data = response.json()
            if not isinstance(result_data, dict):
                return "", None
            output = "\n".join(
                o.get("content") or ""
                for o in (result_data.get("output") or [])
                if o.get("output_type") in ("stdout", "result")
            )
        except (requests.RequestException, ValueError, AttributeError, TypeError):
            # Deliberately best-effort: missing output must not mask the status.
            return "", None
        return output, result_data

    def _wait_with_sse(
        self,
        execution_id: str,
        timeout: int,
    ) -> ExecutionResult:
        """Wait for completion using Server-Sent Events for real-time updates.

        Raises:
            ExecutionError: the execution failed, or the stream timed out.
            PodstackError: the stream could not be opened, broke, or ended
                without a terminal status.
        """
        url = f"{self.api_url}/api/v1/executions/{execution_id}/status/stream"
        result_url = f"{self.api_url}/api/v1/executions/{execution_id}/result"

        last_status = None
        final_update: Optional[StatusUpdate] = None

        try:
            response = self._session.get(
                url, stream=True, timeout=timeout + self._PROVISIONING_GRACE_SECONDS
            )
            try:
                if response.status_code != 200:
                    raise PodstackError(f"Failed to connect to status stream: {response.text}")

                for data in self._iter_sse_data(response):
                    update = StatusUpdate.from_dict(data)

                    # Only print on status changes; heartbeats carry no news.
                    if update.status != last_status and update.status != "heartbeat":
                        self._print_status_update(update)
                        last_status = update.status

                    if update.is_terminal:
                        final_update = update
                        break
            finally:
                # Breaking out of iter_lines() early leaves the stream open;
                # close it so the pooled connection is released.
                response.close()
        except requests.exceptions.Timeout:
            raise ExecutionError(
                f"Timed out waiting for execution after {timeout}s",
                execution_id=execution_id,
                status="timeout",
            )
        except requests.exceptions.RequestException as e:
            raise PodstackError(f"Connection error during status streaming: {e}")

        if final_update is None:
            raise PodstackError("Status stream ended without terminal state")

        output, _ = self._fetch_result(result_url)

        result = ExecutionResult(
            execution_id=execution_id,
            status=final_update.status,
            output=output,
            gpu_seconds=final_update.gpu_seconds or 0,
            cost_cents=final_update.cost_cents or 0,
            error=final_update.error,
        )

        self._print_result(result)

        if final_update.status == "failed":
            raise ExecutionError(
                result.error or "Execution failed",
                execution_id=execution_id,
                status=final_update.status,
            )

        return result

    def _wait_with_polling(
        self,
        execution_id: str,
        poll_interval: float,
        timeout: int,
    ) -> ExecutionResult:
        """Fallback: poll the status endpoint until a terminal status or the deadline.

        Request-level failures are printed as warnings and retried.

        Raises:
            ExecutionError: the execution failed, or the deadline passed.
            PodstackError: the status endpoint returned a non-200 status.
        """
        url = f"{self.api_url}/api/v1/executions/{execution_id}/status"
        result_url = f"{self.api_url}/api/v1/executions/{execution_id}/result"

        deadline_seconds = timeout + self._PROVISIONING_GRACE_SECONDS
        # monotonic() is immune to wall-clock adjustments during the wait.
        start_time = time.monotonic()
        last_status = None

        while True:
            elapsed = time.monotonic() - start_time
            if elapsed > deadline_seconds:
                raise ExecutionError(
                    f"Timed out waiting for execution after {elapsed:.0f}s",
                    execution_id=execution_id,
                    status="timeout",
                )

            try:
                response = self._session.get(url, timeout=self._REQUEST_TIMEOUT_SECONDS)
                if response.status_code != 200:
                    raise PodstackError(f"Failed to get status: {response.text}")

                data = response.json()
                status = data["status"]

                if status != last_status:
                    self._print_status(status, data)
                    last_status = status

                if status in self._TERMINAL_STATUSES:
                    break

            except requests.RequestException as e:
                print(f" Warning: {e}")

            time.sleep(poll_interval)

        output, result_data = self._fetch_result(result_url)
        if result_data is None:
            # No result payload: take usage/error fields from the last status payload.
            result_data = data

        result = ExecutionResult(
            execution_id=execution_id,
            status=status,
            output=output,
            gpu_seconds=result_data.get("gpu_seconds", 0) or 0,
            cost_cents=result_data.get("cost_cents", 0) or 0,
            error=result_data.get("error"),
        )

        self._print_result(result)

        if status == "failed":
            raise ExecutionError(
                result.error or "Execution failed",
                execution_id=execution_id,
                status=status,
            )

        return result

    @classmethod
    def _describe_status(cls, status: str, queue_position=None) -> str:
        """Return the human-readable line (without indentation) for a status."""
        if status == "queued":
            position = "?" if queue_position is None else queue_position
            return f"Status: queued (position {position})"
        return cls._STATUS_MESSAGES.get(status, f"Status: {status}")

    def _print_status_update(self, update: StatusUpdate):
        """Print a status update from the SSE stream, preferring the server's message."""
        if update.message:
            print(f" {update.message}")
        else:
            print(f" {self._describe_status(update.status, update.queue_position)}")

        if update.progress is not None and update.status == "running":
            print(f" Progress: {update.progress:.1f}%")

    def _print_status(self, status: str, data: Dict[str, Any]):
        """Print a status change observed while polling."""
        print(f" {self._describe_status(status, data.get('queue_position'))}")

    def _print_result(self, result: ExecutionResult):
        """Print execution result summary."""
        if result.success:
            print("\nExecution completed successfully!")
            print(f" GPU time: {result.gpu_seconds:.1f}s")
            print(f" Cost: ₹{result.cost_inr:.4f}")
            if result.output:
                print(f"\nOutput:\n{result.output}")
        else:
            print(f"\nExecution {result.status}")
            if result.error:
                print(f" Error: {result.error}")

    def get_status(self, execution_id: str) -> Dict[str, Any]:
        """Get current status of an execution.

        Raises:
            PodstackError: the API returned a non-200 status.
        """
        url = f"{self.api_url}/api/v1/executions/{execution_id}/status"

        response = self._session.get(url)
        if response.status_code != 200:
            raise PodstackError(f"Failed to get status: {response.text}")

        return response.json()

    def cancel(self, execution_id: str, reason: str = "Cancelled by user") -> bool:
        """Cancel a running execution; returns True if the API answered 200."""
        url = f"{self.api_url}/api/v1/executions/{execution_id}/cancel"

        response = self._session.post(url, params={"reason": reason})
        return response.status_code == 200

    def list_gpus(self) -> Dict[str, Any]:
        """List available GPU types and pricing.

        Raises:
            PodstackError: the API returned a non-200 status.
        """
        url = f"{self.api_url}/api/v1/gpu-types"

        response = self._session.get(url)
        if response.status_code != 200:
            raise PodstackError(f"Failed to get GPU types: {response.text}")

        return response.json()

    def stream_status(
        self,
        execution_id: str,
        callback: Callable[[StatusUpdate], None],
        timeout: int = 3600,
    ) -> None:
        """
        Stream status updates for an execution with a callback.

        This method blocks until the execution reaches a terminal state.
        Use this for real-time status tracking in your application.

        Args:
            execution_id: Execution ID to track
            callback: Function called with each non-heartbeat StatusUpdate
            timeout: Request timeout in seconds passed to the stream connection

        Raises:
            ExecutionError: the stream request timed out.
            PodstackError: the stream could not be opened.

        Example:
            def on_status(update):
                print(f"Status: {update.status}")
                if update.queue_position:
                    print(f"Queue position: {update.queue_position}")

            runner.stream_status(execution_id, on_status)
        """
        url = f"{self.api_url}/api/v1/executions/{execution_id}/status/stream"

        try:
            response = self._session.get(url, stream=True, timeout=timeout)
            try:
                if response.status_code != 200:
                    raise PodstackError(f"Failed to connect to status stream: {response.text}")

                for data in self._iter_sse_data(response):
                    update = StatusUpdate.from_dict(data)

                    if update.status != "heartbeat":
                        callback(update)

                    if update.is_terminal:
                        break
            finally:
                response.close()
        except requests.exceptions.Timeout:
            raise ExecutionError(
                f"Timed out waiting for execution after {timeout}s",
                execution_id=execution_id,
                status="timeout",
            )

    def stream_status_async(
        self,
        execution_id: str,
        callback: Callable[[StatusUpdate], None],
        timeout: int = 3600,
    ) -> threading.Thread:
        """
        Stream status updates in a background thread.

        Returns immediately with a thread handle. The callback is called
        from the background (daemon) thread for each status update.

        Args:
            execution_id: Execution ID to track
            callback: Function called with each StatusUpdate
            timeout: Maximum time to wait in seconds

        Returns:
            Thread object (already started)

        Example:
            updates = []
            def on_status(update):
                updates.append(update)
                print(f"Status: {update.status}")

            thread = runner.stream_status_async(execution_id, on_status)
            # Do other work...
            thread.join()  # Wait for completion
        """
        thread = threading.Thread(
            target=self.stream_status,
            args=(execution_id, callback, timeout),
            daemon=True,
        )
        thread.start()
        return thread

    def stream_logs(
        self,
        execution_id: str,
        timeout: int = 3600,
    ):
        """
        Stream real-time logs in Jupyter notebook with live output display.

        This method displays logs in real-time as they are generated on the GPU.
        Works with all @podstack annotations - logs will stream automatically.
        When IPython is not installed, delegates to plain-text streaming.

        Args:
            execution_id: Execution ID to stream logs for
            timeout: Request timeout in seconds passed to the stream connection

        Raises:
            PodstackError: the log stream could not be opened.

        Example (in Jupyter notebook):
            # Submit execution
            gpu = GPURunner(token, project_id)
            result = gpu.run(code, wait=False)

            # Stream logs in real-time
            gpu.stream_logs(result.execution_id)

        The logs will update live in the notebook output area.
        """
        try:
            # IPython is only used to detect a notebook-capable environment;
            # all output goes through print().
            import IPython.display  # noqa: F401
        except ImportError:
            print("Warning: IPython not available. Falling back to simple streaming.")
            self._stream_logs_simple(execution_id, timeout)
            return

        print("Streaming logs (press Stop button or wait for completion)...")
        print("=" * 60)

        url = f"{self.api_url}/api/v1/executions/{execution_id}/status/stream"

        response = self._session.get(url, stream=True, timeout=timeout)
        try:
            if response.status_code != 200:
                raise PodstackError(f"Failed to connect to log stream: {response.text}")

            for data in self._iter_sse_data(response):
                if data.get("status") == "log":
                    log = data.get("data")
                    if not isinstance(log, dict):
                        log = {}
                    output_type = log.get("output_type", "stdout")
                    content = log.get("content") or ""

                    if output_type == "stderr":
                        # ANSI red for stderr
                        print(f"\033[91m{content}\033[0m", end="", flush=True)
                    else:
                        print(content, end="", flush=True)

                update = StatusUpdate.from_dict(data)
                if update.is_terminal:
                    print("\n" + "=" * 60)
                    print(f"Execution {update.status}")
                    if update.error:
                        print(f"Error: {update.error}")
                    if update.gpu_seconds:
                        print(f"GPU time: {update.gpu_seconds:.1f}s")
                    if update.cost_cents:
                        print(f"Cost: ₹{update.cost_cents / 100:.4f}")
                    break
        finally:
            response.close()

    def _stream_logs_simple(self, execution_id: str, timeout: int):
        """Plain-text log streaming for environments without IPython.

        Raises:
            PodstackError: the log stream could not be opened.
        """
        url = f"{self.api_url}/api/v1/executions/{execution_id}/status/stream"

        response = self._session.get(url, stream=True, timeout=timeout)
        try:
            if response.status_code != 200:
                raise PodstackError(f"Failed to connect to log stream: {response.text}")

            for data in self._iter_sse_data(response):
                if data.get("status") == "log":
                    log = data.get("data")
                    if not isinstance(log, dict):
                        log = {}
                    # Flush so partial lines appear immediately in a terminal.
                    print(log.get("content") or "", end="", flush=True)

                update = StatusUpdate.from_dict(data)
                if update.is_terminal:
                    print(f"\nExecution {update.status}")
                    break
        finally:
            response.close()
|