podstack 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- podstack/__init__.py +222 -0
- podstack/annotations.py +725 -0
- podstack/client.py +322 -0
- podstack/exceptions.py +125 -0
- podstack/execution.py +291 -0
- podstack/gpu_runner.py +1141 -0
- podstack/models.py +274 -0
- podstack/notebook.py +410 -0
- podstack/registry/__init__.py +402 -0
- podstack/registry/client.py +957 -0
- podstack/registry/exceptions.py +107 -0
- podstack/registry/experiment.py +227 -0
- podstack/registry/model.py +273 -0
- podstack/registry/model_utils.py +231 -0
- podstack-1.2.0.dist-info/METADATA +299 -0
- podstack-1.2.0.dist-info/RECORD +27 -0
- podstack-1.2.0.dist-info/WHEEL +5 -0
- podstack-1.2.0.dist-info/licenses/LICENSE +21 -0
- podstack-1.2.0.dist-info/top_level.txt +2 -0
- podstack_gpu/__init__.py +126 -0
- podstack_gpu/app.py +675 -0
- podstack_gpu/exceptions.py +35 -0
- podstack_gpu/image.py +325 -0
- podstack_gpu/runner.py +746 -0
- podstack_gpu/secret.py +189 -0
- podstack_gpu/utils.py +203 -0
- podstack_gpu/volume.py +198 -0
podstack/models.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Podstack Data Models
|
|
3
|
+
|
|
4
|
+
Data classes representing various resources in the Podstack API.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Optional, List, Dict, Any
|
|
10
|
+
from enum import Enum
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GPUType(str, Enum):
    """GPU models that can be requested from Podstack.

    Members subclass ``str``, so each one compares equal to its identifier
    string (``GPUType.A10 == "A10"``).
    """

    A10 = "A10"
    A100_40GB = "A100_40GB"
    A100_80GB = "A100_80GB"
    H100 = "H100"
    H100_80GB = "H100_80GB"
    L4 = "L4"
    L40S = "L40S"
    T4 = "T4"
    RTX_4090 = "RTX_4090"
    RTX_6000_ADA = "RTX_6000_ADA"

    @property
    def memory_gb(self) -> int:
        """GPU memory in gigabytes, or 0 if the type is missing from the table."""
        # NOTE(review): static table baked into the SDK; confirm it stays in
        # sync with the memory figures the API reports in GPUInfo.
        gigabytes = dict(
            A10=24,
            A100_40GB=40,
            A100_80GB=80,
            H100=80,
            H100_80GB=80,
            L4=24,
            L40S=48,
            T4=16,
            RTX_4090=24,
            RTX_6000_ADA=48,
        )
        return gigabytes.get(self.value, 0)

    @property
    def price_per_hour_cents(self) -> int:
        """Hourly price in cents, or 0 if the type is missing from the table."""
        # NOTE(review): hard-coded prices; confirm they match live API pricing.
        cents = dict(
            A10=90,
            A100_40GB=250,
            A100_80GB=350,
            H100=500,
            H100_80GB=550,
            L4=70,
            L40S=120,
            T4=50,
            RTX_4090=80,
            RTX_6000_ADA=130,
        )
        return cents.get(self.value, 0)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Environment(str, Enum):
    """Pre-configured notebook environments.

    Members subclass ``str``, so each one compares equal to its lowercase
    identifier (``Environment.JAX == "jax"``).
    """

    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
    RAPIDS = "rapids"
    HUGGINGFACE = "huggingface"
    CUSTOM = "custom"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class GPUInfo:
    """Availability and pricing snapshot for one GPU type, as reported by the API.

    Attributes:
        type: GPU type identifier.
        name: Name reported by the API.
        memory_gb: GPU memory in gigabytes.
        price_per_hour_cents: Hourly price in cents.
        available: Available count reported by the API.
        total: Total count reported by the API.
    """
    type: str
    name: str
    memory_gb: int
    price_per_hour_cents: int
    available: int
    total: int

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GPUInfo":
        """Build a GPUInfo from an API payload.

        Every field is required; the first missing key (checked in field
        order) raises ``KeyError``. Extra keys are ignored.
        """
        required_keys = ("type", "name", "memory_gb", "price_per_hour_cents", "available", "total")
        return cls(**{key: data[key] for key in required_keys})
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class Project:
    """A project that groups notebooks.

    ``created_at`` and ``updated_at`` are parsed with
    ``datetime.fromisoformat`` after rewriting any ``Z`` to ``+00:00``.
    """
    id: str
    name: str
    description: Optional[str]
    created_at: datetime
    updated_at: datetime
    notebook_count: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Project":
        """Build a Project from an API payload.

        ``id``, ``name``, ``created_at`` and ``updated_at`` are required;
        ``description`` defaults to None, ``notebook_count`` to 0 and
        ``metadata`` to an empty dict.
        """
        def parse_timestamp(raw: str) -> datetime:
            return datetime.fromisoformat(raw.replace("Z", "+00:00"))

        return cls(
            id=data["id"],
            name=data["name"],
            description=data.get("description"),
            created_at=parse_timestamp(data["created_at"]),
            updated_at=parse_timestamp(data["updated_at"]),
            notebook_count=data.get("notebook_count", 0),
            metadata=data.get("metadata", {}),
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclass
class Version:
    """A saved notebook version (checkpoint)."""
    id: str
    notebook_id: str
    version_number: int
    message: Optional[str]
    created_at: datetime
    cells_hash: str
    parent_version_id: Optional[str] = None
    branch_name: str = "main"
    is_auto_checkpoint: bool = False

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Version":
        """Build a Version from an API payload.

        ``id``, ``notebook_id``, ``version_number``, ``created_at`` and
        ``cells_hash`` are required. Optional keys fall back to: message and
        parent_version_id -> None, branch_name -> "main",
        is_auto_checkpoint -> False.
        """
        # Keys are read in the same order as the dataclass fields, so the
        # first missing required key is the one reported in the KeyError.
        version_id = data["id"]
        owner_notebook = data["notebook_id"]
        number = data["version_number"]
        commit_message = data.get("message")
        created = datetime.fromisoformat(data["created_at"].replace("Z", "+00:00"))
        digest = data["cells_hash"]
        parent = data.get("parent_version_id")
        branch = data.get("branch_name", "main")
        auto = data.get("is_auto_checkpoint", False)
        return cls(
            id=version_id,
            notebook_id=owner_notebook,
            version_number=number,
            message=commit_message,
            created_at=created,
            cells_hash=digest,
            parent_version_id=parent,
            branch_name=branch,
            is_auto_checkpoint=auto,
        )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class WalletBalance:
    """Snapshot of a user's wallet.

    Monetary fields are integer cents; the ``balance`` and
    ``available_balance`` properties divide by 100.
    """
    balance_cents: int
    currency: str
    last_updated: datetime
    pending_charges_cents: int = 0

    @property
    def balance(self) -> float:
        """Balance in dollars (``balance_cents / 100``)."""
        return self.balance_cents / 100

    @property
    def available_balance(self) -> float:
        """Balance in dollars after subtracting pending charges."""
        spendable_cents = self.balance_cents - self.pending_charges_cents
        return spendable_cents / 100

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WalletBalance":
        """Build a WalletBalance from an API payload.

        ``balance_cents`` and ``last_updated`` are required; ``currency``
        defaults to ``"USD"`` and ``pending_charges_cents`` to 0.
        """
        cents = data["balance_cents"]
        currency_code = data.get("currency", "USD")
        stamp = data["last_updated"].replace("Z", "+00:00")
        updated_at = datetime.fromisoformat(stamp)
        pending = data.get("pending_charges_cents", 0)
        return cls(
            balance_cents=cents,
            currency=currency_code,
            last_updated=updated_at,
            pending_charges_cents=pending,
        )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass
class UsageRecord:
    """One billing usage entry: GPU time and its cost for a given date."""
    date: str
    gpu_type: str
    gpu_seconds: int
    cost_cents: int
    notebook_id: Optional[str] = None

    @property
    def cost(self) -> float:
        """Cost in dollars (``cost_cents / 100``)."""
        return self.cost_cents / 100

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UsageRecord":
        """Build a UsageRecord from an API payload.

        All keys except ``notebook_id`` are required; ``notebook_id``
        defaults to None.
        """
        mandatory = {key: data[key] for key in ("date", "gpu_type", "gpu_seconds", "cost_cents")}
        return cls(**mandatory, notebook_id=data.get("notebook_id"))
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
@dataclass
class UsageSummary:
    """Summary of usage for a period.

    Attributes:
        total_cost_cents: Total cost in cents.
        total_gpu_seconds: Total GPU time in seconds.
        breakdown: Individual usage records making up the totals.
        start_date: Period start as returned by the API, if any.
        end_date: Period end as returned by the API, if any.
    """
    total_cost_cents: int
    total_gpu_seconds: int
    breakdown: List[UsageRecord]
    start_date: Optional[str] = None
    end_date: Optional[str] = None

    @property
    def total_cost(self) -> float:
        """Total cost in dollars (``total_cost_cents / 100``)."""
        return self.total_cost_cents / 100

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UsageSummary":
        """Build a UsageSummary from an API payload.

        ``total_cost_cents`` and ``total_gpu_seconds`` are required.
        ``breakdown`` may be absent or ``null``; both yield an empty list.
        """
        return cls(
            total_cost_cents=data["total_cost_cents"],
            total_gpu_seconds=data["total_gpu_seconds"],
            # `.get("breakdown", [])` alone returns None for an explicit JSON
            # null, and iterating None raises TypeError; `or []` covers both.
            breakdown=[UsageRecord.from_dict(r) for r in data.get("breakdown") or []],
            start_date=data.get("start_date"),
            end_date=data.get("end_date"),
        )
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@dataclass
class Webhook:
    """Webhook configuration.

    Attributes:
        id: Webhook ID.
        url: URL the webhook is registered for.
        events: Event names the webhook is subscribed to.
        created_at: Creation timestamp (a trailing ``Z`` is parsed as UTC).
        is_active: Whether the webhook is active (defaults to True).
        secret: Secret value returned by the API, if any. Excluded from
            ``repr()`` so it does not end up in logs or tracebacks.
    """
    id: str
    url: str
    events: List[str]
    created_at: datetime
    is_active: bool = True
    # repr=False: the dataclass-generated __repr__ would otherwise print the
    # secret verbatim whenever a Webhook is logged or shown in a traceback.
    secret: Optional[str] = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Webhook":
        """Build a Webhook from an API payload.

        ``id``, ``url``, ``events`` and ``created_at`` are required;
        ``is_active`` defaults to True and ``secret`` to None.
        """
        return cls(
            id=data["id"],
            url=data["url"],
            events=data["events"],
            created_at=datetime.fromisoformat(data["created_at"].replace("Z", "+00:00")),
            is_active=data.get("is_active", True),
            secret=data.get("secret"),
        )
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@dataclass
class Cell:
    """A single notebook cell together with its outputs."""
    id: str
    cell_type: str  # "code" or "markdown"
    source: str
    outputs: List[Dict[str, Any]] = field(default_factory=list)
    execution_count: Optional[int] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Cell":
        """Build a Cell from an API payload.

        ``id``, ``cell_type`` and ``source`` are required; ``outputs``
        defaults to [], ``execution_count`` to None, ``metadata`` to {}.
        """
        core = {key: data[key] for key in ("id", "cell_type", "source")}
        return cls(
            **core,
            outputs=data.get("outputs", []),
            execution_count=data.get("execution_count"),
            metadata=data.get("metadata", {}),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict with the same keys ``from_dict`` reads.

        ``outputs`` and ``metadata`` are returned by reference, not copied.
        """
        field_names = ("id", "cell_type", "source", "outputs", "execution_count", "metadata")
        return {name: getattr(self, name) for name in field_names}
|
podstack/notebook.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Podstack Notebook Module
|
|
3
|
+
|
|
4
|
+
Handles notebook operations including creation, execution, and versioning.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Optional, Dict, Any, List, TYPE_CHECKING, Union
|
|
10
|
+
from enum import Enum
|
|
11
|
+
import asyncio
|
|
12
|
+
|
|
13
|
+
from .models import Cell, Version
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .client import Client
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class NotebookStatus(str, Enum):
    """Status value reported by the API for a notebook.

    Members subclass ``str``, so each one equals its wire value
    (``NotebookStatus.RUNNING == "running"``).
    """

    CREATING = "creating"
    RUNNING = "running"
    IDLE = "idle"
    STOPPING = "stopping"
    STOPPED = "stopped"
    FAILED = "failed"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class Notebook:
    """
    Represents a GPU notebook.

    Attributes:
        id: Unique notebook ID
        name: Notebook name
        status: Current status
        gpu_type: Type of GPU allocated
        environment: Environment preset
        created_at: Creation timestamp
        startup_time_ms: Time to start in milliseconds
        endpoint: WebSocket endpoint for real-time communication
        jupyter_url: URL to access JupyterLab
        project_id: Associated project ID
        idle_timeout_minutes: Auto-shutdown after idle minutes
        auto_shutdown_enabled: Whether auto-shutdown is enabled
        metadata: Custom metadata
        cells: Notebook cells (empty when the payload has none)

    The async methods use the client passed to ``from_dict``; on an unbound
    instance they raise ``RuntimeError("Notebook not bound to client")``.
    """
    id: str
    name: str
    status: NotebookStatus
    gpu_type: str
    environment: str
    created_at: Optional[datetime] = None
    startup_time_ms: Optional[float] = None
    endpoint: Optional[str] = None
    jupyter_url: Optional[str] = None
    project_id: Optional[str] = None
    idle_timeout_minutes: int = 30
    auto_shutdown_enabled: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)
    cells: List[Cell] = field(default_factory=list)

    _client: Optional["Client"] = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Optional["Client"] = None) -> "Notebook":
        """Build a Notebook from an API payload, optionally bound to ``client``.

        ``id``, ``name``, ``status``, ``gpu_type`` and ``environment`` are
        required; ``status`` must be a valid NotebookStatus value (else
        ValueError). A missing/empty ``created_at`` or ``cells`` is tolerated.
        """
        created_at = None
        if data.get("created_at"):
            created_at = datetime.fromisoformat(data["created_at"].replace("Z", "+00:00"))

        cells = []
        if data.get("cells"):
            cells = [Cell.from_dict(c) for c in data["cells"]]

        return cls(
            id=data["id"],
            name=data["name"],
            status=NotebookStatus(data["status"]),
            gpu_type=data["gpu_type"],
            environment=data["environment"],
            created_at=created_at,
            startup_time_ms=data.get("startup_time_ms"),
            endpoint=data.get("endpoint"),
            jupyter_url=data.get("jupyter_url"),
            project_id=data.get("project_id"),
            idle_timeout_minutes=data.get("idle_timeout_minutes", 30),
            auto_shutdown_enabled=data.get("auto_shutdown_enabled", True),
            metadata=data.get("metadata", {}),
            cells=cells,
            _client=client
        )

    def _require_client(self) -> "Client":
        """Return the bound API client, raising RuntimeError if there is none."""
        if not self._client:
            raise RuntimeError("Notebook not bound to client")
        return self._client

    @property
    def is_running(self) -> bool:
        """Check if notebook is running (status RUNNING or IDLE)"""
        return self.status in (NotebookStatus.RUNNING, NotebookStatus.IDLE)

    @property
    def is_ready(self) -> bool:
        """Check if notebook is ready for execution (status RUNNING)"""
        return self.status == NotebookStatus.RUNNING

    async def refresh(self) -> "Notebook":
        """Re-fetch this notebook from the API and update it in place.

        Every dataclass field except the bound client is overwritten from the
        ``GET /notebooks/{id}`` payload (the same payload ``NotebooksAPI.get``
        builds a Notebook from), so no attribute is left stale after
        server-side changes such as ``restore_version``.

        Returns:
            ``self``, for chaining.
        """
        from dataclasses import fields

        client = self._require_client()
        data = await client._request("GET", f"/notebooks/{self.id}")
        updated = Notebook.from_dict(data, client)

        for f in fields(self):
            if f.name != "_client":
                setattr(self, f.name, getattr(updated, f.name))

        return self

    async def wait_ready(self, timeout: float = 60, poll_interval: float = 0.5) -> "Notebook":
        """
        Wait for notebook to finish creating.

        Polls via ``refresh()`` while the status is CREATING. Note that the
        loop exits on any other status, so the returned notebook is not
        guaranteed to be RUNNING (it only is guaranteed not to be FAILED).

        Args:
            timeout: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            Updated Notebook object

        Raises:
            NotebookStateError: if still CREATING after ``timeout`` seconds,
                or if the notebook ends up FAILED.
        """
        from .exceptions import NotebookStateError

        # get_running_loop() is the supported way to reach the loop from
        # inside a coroutine (get_event_loop() is discouraged there).
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while self.status == NotebookStatus.CREATING:
            if loop.time() > deadline:
                raise NotebookStateError(self.id, self.status.value, "running")

            await asyncio.sleep(poll_interval)
            await self.refresh()

        if self.status == NotebookStatus.FAILED:
            raise NotebookStateError(self.id, "failed", "running")

        return self

    async def execute(
        self,
        code: str,
        timeout_seconds: int = 300,
        wait: bool = True
    ) -> "Execution":
        """
        Execute code in this notebook.

        Args:
            code: Python code to execute
            timeout_seconds: Maximum execution time; sent to the API and, when
                ``wait`` is true, also used as the local wait timeout
            wait: Whether to wait for completion

        Returns:
            Execution object, with ``notebook_id`` set to this notebook's ID
        """
        from .execution import Execution

        client = self._require_client()

        data = await client._request(
            "POST",
            f"/notebooks/{self.id}/execute",
            json={"code": code, "timeout_seconds": timeout_seconds}
        )

        execution = Execution.from_dict(data, client)
        execution.notebook_id = self.id

        if wait:
            await execution.wait(timeout=timeout_seconds)

        return execution

    async def stop(self) -> "Notebook":
        """Stop the notebook, then refresh and return ``self``"""
        client = self._require_client()
        await client._request("POST", f"/notebooks/{self.id}/stop")
        await self.refresh()
        return self

    async def start(self) -> "Notebook":
        """Start a stopped notebook, then refresh and return ``self``"""
        client = self._require_client()
        await client._request("POST", f"/notebooks/{self.id}/start")
        await self.refresh()
        return self

    async def delete(self) -> None:
        """Delete the notebook"""
        client = self._require_client()
        await client._request("DELETE", f"/notebooks/{self.id}")

    # Versioning methods
    async def save(self, message: Optional[str] = None) -> Version:
        """
        Save notebook state (create a version).

        The locally held cells are sent along (``None`` when there are none).

        Args:
            message: Optional commit message

        Returns:
            Version object
        """
        client = self._require_client()

        data = await client._request(
            "POST",
            f"/notebooks/{self.id}/versions",
            json={
                "message": message,
                "cells": [c.to_dict() for c in self.cells] if self.cells else None
            }
        )
        return Version.from_dict(data)

    async def list_versions(self, branch: Optional[str] = None, limit: int = 20) -> List[Version]:
        """
        List notebook versions.

        Args:
            branch: Filter by branch name
            limit: Maximum results

        Returns:
            List of Version objects
        """
        client = self._require_client()

        params: Dict[str, Any] = {"limit": limit}
        if branch:
            params["branch"] = branch

        data = await client._request(
            "GET",
            f"/notebooks/{self.id}/versions",
            params=params
        )
        return [Version.from_dict(v) for v in data.get("versions", [])]

    async def get_version(self, version_id: str) -> Version:
        """Get a specific version"""
        client = self._require_client()

        data = await client._request(
            "GET",
            f"/notebooks/{self.id}/versions/{version_id}"
        )
        return Version.from_dict(data)

    async def restore_version(self, version_id: str) -> "Notebook":
        """
        Restore notebook to a previous version.

        Args:
            version_id: Version to restore

        Returns:
            Updated Notebook object (``self``, refreshed from the API)
        """
        client = self._require_client()

        await client._request(
            "POST",
            f"/notebooks/{self.id}/versions/{version_id}/restore"
        )
        await self.refresh()
        return self

    async def create_branch(self, name: str, from_version_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a branch from a version.

        Args:
            name: Branch name
            from_version_id: Version to branch from (default: latest)

        Returns:
            Branch info, as the raw API response dict
        """
        client = self._require_client()

        data = await client._request(
            "POST",
            f"/notebooks/{self.id}/branches",
            json={"name": name, "from_version_id": from_version_id}
        )
        return data

    async def list_branches(self) -> List[Dict[str, Any]]:
        """List notebook branches (raw dicts from the API)"""
        client = self._require_client()

        data = await client._request("GET", f"/notebooks/{self.id}/branches")
        return data.get("branches", [])

    def __str__(self) -> str:
        return f"Notebook({self.id}, name={self.name!r}, status={self.status.value})"
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class NotebooksAPI:
    """API for managing notebooks.

    Every Notebook returned here is bound to this API's client, so its async
    methods (refresh, execute, stop, ...) can be called directly.
    """

    def __init__(self, client: "Client"):
        self._client = client

    async def create(
        self,
        name: str,
        gpu_type: str = "A10",
        environment: str = "pytorch",
        project_id: Optional[str] = None,
        idle_timeout_minutes: int = 30,
        auto_shutdown_enabled: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
        wait_ready: bool = True
    ) -> Notebook:
        """
        Create a new notebook.

        Args:
            name: Notebook name
            gpu_type: GPU type (A10, A100, H100, etc.)
            environment: Environment preset (pytorch, tensorflow, jax, etc.)
            project_id: Optional project ID (omitted from the request if falsy)
            idle_timeout_minutes: Auto-shutdown after idle minutes
            auto_shutdown_enabled: Whether to enable auto-shutdown
            metadata: Custom metadata (omitted from the request if empty)
            wait_ready: Whether to wait for notebook to finish creating

        Returns:
            Notebook object
        """
        payload: Dict[str, Any] = {
            "name": name,
            "gpu_type": gpu_type,
            "environment": environment,
            "idle_timeout_minutes": idle_timeout_minutes,
            "auto_shutdown_enabled": auto_shutdown_enabled
        }

        if project_id:
            payload["project_id"] = project_id
        if metadata:
            payload["metadata"] = metadata

        data = await self._client._request("POST", "/notebooks", json=payload)
        notebook = Notebook.from_dict(data, self._client)

        if wait_ready:
            await notebook.wait_ready()

        return notebook

    async def get(self, notebook_id: str) -> Notebook:
        """Get a notebook by ID"""
        data = await self._client._request("GET", f"/notebooks/{notebook_id}")
        return Notebook.from_dict(data, self._client)

    async def list(
        self,
        status: Optional[Union[NotebookStatus, str]] = None,
        project_id: Optional[str] = None,
        limit: int = 20,
        offset: int = 0
    ) -> List[Notebook]:
        """
        List notebooks.

        Args:
            status: Filter by status; a NotebookStatus member or its string
                value (e.g. ``"running"``)
            project_id: Filter by project
            limit: Maximum results
            offset: Pagination offset

        Returns:
            List of Notebook objects

        Raises:
            ValueError: if ``status`` is a string that is not a valid
                NotebookStatus value.
        """
        params: Dict[str, Any] = {"limit": limit, "offset": offset}
        if status:
            # NotebookStatus(...) returns members unchanged and validates raw
            # strings; a plain str has no `.value`, which used to crash here.
            params["status"] = NotebookStatus(status).value
        if project_id:
            params["project_id"] = project_id

        data = await self._client._request("GET", "/notebooks", params=params)
        return [Notebook.from_dict(n, self._client) for n in data.get("notebooks", [])]

    async def delete(self, notebook_id: str) -> None:
        """Delete a notebook"""
        await self._client._request("DELETE", f"/notebooks/{notebook_id}")
|