podstack 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- podstack/__init__.py +222 -0
- podstack/annotations.py +725 -0
- podstack/client.py +322 -0
- podstack/exceptions.py +125 -0
- podstack/execution.py +291 -0
- podstack/gpu_runner.py +1141 -0
- podstack/models.py +274 -0
- podstack/notebook.py +410 -0
- podstack/registry/__init__.py +402 -0
- podstack/registry/client.py +957 -0
- podstack/registry/exceptions.py +107 -0
- podstack/registry/experiment.py +227 -0
- podstack/registry/model.py +273 -0
- podstack/registry/model_utils.py +231 -0
- podstack-1.2.0.dist-info/METADATA +299 -0
- podstack-1.2.0.dist-info/RECORD +27 -0
- podstack-1.2.0.dist-info/WHEEL +5 -0
- podstack-1.2.0.dist-info/licenses/LICENSE +21 -0
- podstack-1.2.0.dist-info/top_level.txt +2 -0
- podstack_gpu/__init__.py +126 -0
- podstack_gpu/app.py +675 -0
- podstack_gpu/exceptions.py +35 -0
- podstack_gpu/image.py +325 -0
- podstack_gpu/runner.py +746 -0
- podstack_gpu/secret.py +189 -0
- podstack_gpu/utils.py +203 -0
- podstack_gpu/volume.py +198 -0
|
@@ -0,0 +1,957 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Registry Client
|
|
3
|
+
|
|
4
|
+
Client for interacting with the Podstack Registry API.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
import tempfile
|
|
10
|
+
import shutil
|
|
11
|
+
from typing import Optional, Dict, Any, List
|
|
12
|
+
import requests
|
|
13
|
+
|
|
14
|
+
from .experiment import Experiment, Run, Metric, Param
|
|
15
|
+
from .model import RegisteredModel, ModelVersion, ModelAlias, StageTransition
|
|
16
|
+
from .exceptions import (
|
|
17
|
+
RegistryError,
|
|
18
|
+
ExperimentNotFoundError,
|
|
19
|
+
RunNotFoundError,
|
|
20
|
+
ModelNotFoundError,
|
|
21
|
+
NoActiveRunError,
|
|
22
|
+
NoExperimentSetError,
|
|
23
|
+
InvalidStageError,
|
|
24
|
+
ModelVersionNotFoundError,
|
|
25
|
+
ArtifactNotFoundError,
|
|
26
|
+
ModelSerializationError,
|
|
27
|
+
)
|
|
28
|
+
from ..exceptions import AuthenticationError
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RegistryClient:
|
|
32
|
+
"""
|
|
33
|
+
Podstack Registry Client for experiment tracking and model management.
|
|
34
|
+
|
|
35
|
+
Usage:
|
|
36
|
+
client = RegistryClient(
|
|
37
|
+
api_url="https://cloud.podstack.ai/registry",
|
|
38
|
+
api_key="your-api-key",
|
|
39
|
+
project_id="your-project-id"
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Set experiment and track runs
|
|
43
|
+
client.set_experiment("my-experiment")
|
|
44
|
+
with client.start_run(name="training") as run:
|
|
45
|
+
client.log_params({"lr": 0.001})
|
|
46
|
+
client.log_metrics({"loss": 0.5}, step=1)
|
|
47
|
+
|
|
48
|
+
# Register and manage models
|
|
49
|
+
model = client.register_model("my-model", run_id=run.id)
|
|
50
|
+
client.set_model_stage("my-model", version=1, stage="production")
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
DEFAULT_API_URL = "https://cloud.podstack.ai/registry"
|
|
54
|
+
|
|
55
|
+
def __init__(
    self,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    project_id: Optional[str] = None,
    timeout: int = 30
):
    """
    Initialize the registry client.

    Args:
        api_url: Registry API URL. Defaults to the PODSTACK_REGISTRY_URL env
            var, then to ``DEFAULT_API_URL``. Trailing slashes are stripped.
        api_key: API key (psk_* token). Defaults to PODSTACK_API_KEY env var.
        project_id: Project ID. Defaults to PODSTACK_PROJECT_ID env var.
        timeout: Request timeout in seconds.
    """
    # Use an ``or`` chain rather than getenv's default argument, so an env var
    # that is set but empty still falls back to the default URL.
    base_url = api_url or os.getenv("PODSTACK_REGISTRY_URL") or self.DEFAULT_API_URL
    # _request builds f"{api_url}/api/v1{endpoint}"; a trailing "/" here
    # would produce "//api/v1" paths.
    self.api_url = base_url.rstrip("/")
    self.api_key = api_key or os.getenv("PODSTACK_API_KEY")
    self.project_id = project_id or os.getenv("PODSTACK_PROJECT_ID")
    self.timeout = timeout

    # State
    self._experiment_id: Optional[str] = None
    self._experiment: Optional[Experiment] = None
    self._active_run: Optional[Run] = None
|
|
80
|
+
|
|
81
|
+
def _get_headers(self) -> Dict[str, str]:
    """
    Build the default headers sent with every registry request.

    A JSON ``Content-Type`` is always present. ``Authorization: Bearer <key>``
    is added only when an API key is configured. ``X-Project-ID`` is added
    only when a project ID is configured.
    """
    conditional = (
        ("Authorization", f"Bearer {self.api_key}" if self.api_key else None),
        ("X-Project-ID", self.project_id if self.project_id else None),
    )
    headers = {"Content-Type": "application/json"}
    headers.update(
        (header, value) for header, value in conditional if value is not None
    )
    return headers
|
|
91
|
+
|
|
92
|
+
def _request(
    self,
    method: str,
    endpoint: str,
    json: Dict[str, Any] = None,
    params: Dict[str, Any] = None,
    files: Dict = None
) -> Dict[str, Any]:
    """
    Make an API request and decode the JSON response.

    Args:
        method: HTTP method
        endpoint: API endpoint path, appended to ``<api_url>/api/v1``
        json: JSON body
        params: Query parameters
        files: Files to upload (multipart)

    Returns:
        Decoded response body. Returns ``{}`` for 204 or empty responses.

    Raises:
        AuthenticationError: On HTTP 401.
        RegistryError: Raised in these cases:
            - ``code="not_found"`` on HTTP 404;
            - ``code="timeout"`` or ``code="connection_error"`` on
              transport failures;
            - ``code="invalid_response"`` when a success body is not JSON;
            - no code for other HTTP errors and other request failures.
    """
    url = f"{self.api_url}/api/v1{endpoint}"
    headers = self._get_headers()

    # Let requests generate the multipart Content-Type (including its boundary).
    if files:
        headers.pop("Content-Type", None)

    try:
        response = requests.request(
            method=method,
            url=url,
            headers=headers,
            json=json,
            params=params,
            files=files,
            timeout=self.timeout,
        )
    except requests.exceptions.Timeout as e:
        raise RegistryError("Request timed out", code="timeout") from e
    except requests.exceptions.ConnectionError as e:
        raise RegistryError(f"Connection error: {e}", code="connection_error") from e
    except requests.exceptions.RequestException as e:
        # Covers the remaining requests failures (e.g. invalid URL, too many
        # redirects) so callers only need to handle RegistryError.
        raise RegistryError(f"Request failed: {e}") from e

    def _error_message(default: str) -> str:
        # Error bodies can be empty, non-JSON (e.g. proxy/gateway HTML) or a
        # JSON value that is not an object. Fall back to *default* so a decode
        # error never masks the HTTP failure itself.
        try:
            payload = response.json()
        except ValueError:
            return default
        return payload.get("error", default) if isinstance(payload, dict) else default

    if response.status_code == 401:
        raise AuthenticationError("Invalid or expired API key")
    if response.status_code == 404:
        raise RegistryError(_error_message("Not found"), code="not_found")
    if response.status_code >= 400:
        raise RegistryError(_error_message(f"Request failed: {response.status_code}"))

    if response.status_code == 204 or not response.content:
        return {}
    try:
        return response.json()
    except ValueError as e:
        raise RegistryError(
            f"Invalid JSON in response (HTTP {response.status_code})",
            code="invalid_response",
        ) from e
|
|
149
|
+
|
|
150
|
+
# ==================== Experiment Methods ====================
|
|
151
|
+
|
|
152
|
+
def set_experiment(self, name: str, description: str = None) -> Experiment:
    """
    Set the active experiment, creating it if it does not exist.

    Args:
        name: Experiment name
        description: Optional description (used only when creating)

    Returns:
        Experiment object

    Raises:
        RegistryError: If the lookup fails for any reason other than
            "not found", or if creating the experiment fails.
    """
    from urllib.parse import quote

    try:
        # safe="" keeps "/", "?" and "#" in the name inside one path segment.
        data = self._request("GET", f"/experiments/name/{quote(str(name), safe='')}")
    except RegistryError as e:
        # Only a genuine 404 means the experiment is missing. Timeouts,
        # connection errors and server errors must not trigger creation of a
        # (possibly duplicate) experiment.
        if e.code != "not_found":
            raise
        data = self._request("POST", "/experiments", json={
            "name": name,
            "description": description or ""
        })

    # Accept both a bare experiment object and an {"experiment": {...}} wrapper.
    exp_data = data.get("experiment", data)
    self._experiment = Experiment.from_dict(exp_data)
    self._experiment_id = self._experiment.id
    return self._experiment
|
|
178
|
+
|
|
179
|
+
def get_experiment(self, experiment_id: str) -> Experiment:
    """
    Get an experiment by ID.

    Raises:
        ExperimentNotFoundError: If the backend returns 404 for the ID.
        RegistryError: For any other request failure.
    """
    try:
        data = self._request("GET", f"/experiments/{experiment_id}")
    except RegistryError as e:
        if e.code == "not_found":
            # Chain so the underlying HTTP error stays visible in tracebacks.
            raise ExperimentNotFoundError(experiment_id) from e
        raise
    return Experiment.from_dict(data)
|
|
188
|
+
|
|
189
|
+
def list_experiments(self, limit: int = 20, offset: int = 0) -> List[Experiment]:
    """
    Return one page of experiments in the current project.

    Args:
        limit: Maximum number of experiments to return.
        offset: Number of experiments to skip (pagination).
    """
    page = {"limit": limit, "offset": offset}
    response = self._request("GET", "/experiments", params=page)
    return list(map(Experiment.from_dict, response.get("experiments", [])))
|
|
196
|
+
|
|
197
|
+
def archive_experiment(self, experiment_id: str):
    """
    Archive an experiment by sending ``DELETE /experiments/<id>``.

    The response body is discarded and nothing is returned.
    """
    endpoint = f"/experiments/{experiment_id}"
    self._request("DELETE", endpoint)
|
|
200
|
+
|
|
201
|
+
# ==================== Run Methods ====================
|
|
202
|
+
|
|
203
|
+
def start_run(self, name: str = None, tags: dict = None) -> Run:
    """
    Create a run in the active experiment and make it the active run.

    Args:
        name: Optional run name. When falsy, ``run-<unix seconds>`` is used.
        tags: Optional tags dict (sent as ``{}`` when falsy).

    Returns:
        Run object (can be used as context manager)

    Raises:
        NoExperimentSetError: If no experiment has been set.
    """
    experiment_id = self._experiment_id
    if not experiment_id:
        raise NoExperimentSetError()

    run_name = name if name else f"run-{int(time.time())}"
    payload = {
        "experiment_id": experiment_id,
        "name": run_name,
        "tags": tags if tags else {},
    }
    response = self._request("POST", "/runs", json=payload)

    # The backend may wrap the run under a "run" key.
    run = Run.from_dict(response.get("run", response), client=self)
    self._active_run = run
    return run
|
|
225
|
+
|
|
226
|
+
def get_run(self, run_id: str) -> Run:
    """
    Get a run by ID.

    Raises:
        RunNotFoundError: If the backend returns 404 for the ID.
        RegistryError: For any other request failure.
    """
    try:
        data = self._request("GET", f"/runs/{run_id}")
    except RegistryError as e:
        if e.code == "not_found":
            # Chain so the underlying HTTP error stays visible in tracebacks.
            raise RunNotFoundError(run_id) from e
        raise
    return Run.from_dict(data, client=self)
|
|
235
|
+
|
|
236
|
+
def list_runs(
    self,
    experiment_id: str = None,
    status: str = None,
    limit: int = 20,
    offset: int = 0
) -> List[Run]:
    """
    Return one page of runs, optionally filtered.

    Args:
        experiment_id: Only include runs of this experiment (ignored if falsy).
        status: Only include runs with this status, e.g. running, completed,
            failed, killed (ignored if falsy).
        limit: Maximum number of runs to return.
        offset: Number of runs to skip (pagination).

    Returns:
        List of Run objects
    """
    query: Dict[str, Any] = {"limit": limit, "offset": offset}
    filters = {"experiment_id": experiment_id, "status": status}
    query.update({key: value for key, value in filters.items() if value})

    response = self._request("GET", "/runs", params=query)
    return [Run.from_dict(item, client=self) for item in response.get("runs", [])]
|
|
263
|
+
|
|
264
|
+
def end_run(self, status: str = "completed"):
    """
    Finish the active run and clear it from the client.

    Args:
        status: Final run status (completed, failed, killed).

    Raises:
        NoActiveRunError: If no run is active.
    """
    run = self._active_run
    if not run:
        raise NoActiveRunError()

    self._request("POST", f"/runs/{run.id}/end", json={"status": status})
    # Mirror the new status on the local object before dropping the reference.
    run.status = status
    self._active_run = None
|
|
279
|
+
|
|
280
|
+
def log_params(self, params: Dict[str, Any]):
    """
    Record parameters on the active run.

    Every value is converted with ``str()`` before being sent.

    Args:
        params: Mapping of parameter name to value.

    Raises:
        NoActiveRunError: If no run is active.
    """
    run = self._active_run
    if not run:
        raise NoActiveRunError()

    stringified = dict(zip(params.keys(), map(str, params.values())))
    self._request("POST", f"/runs/{run.id}/params", json={"params": stringified})
|
|
295
|
+
|
|
296
|
+
def log_metrics(self, metrics: Dict[str, float], step: int = None):
    """
    Record metric values on the active run.

    All entries in one call share the same millisecond timestamp and step.

    Args:
        metrics: Mapping of metric name to numeric value (sent as float).
        step: Optional step number; a falsy step is sent as 0.

    Raises:
        NoActiveRunError: If no run is active.
    """
    run = self._active_run
    if not run:
        raise NoActiveRunError()

    shared = {"step": step or 0, "timestamp": int(time.time() * 1000)}
    entries = [
        {"key": metric_name, "value": float(metric_value), **shared}
        for metric_name, metric_value in metrics.items()
    ]
    self._request("POST", f"/runs/{run.id}/metrics", json={"metrics": entries})
|
|
315
|
+
|
|
316
|
+
@staticmethod
def _get_artifact_dir(run_id: str) -> str:
    """
    Return the local directory that holds artifacts for *run_id*.

    The base directory is ``PODSTACK_ARTIFACT_DIR`` when that variable is set,
    otherwise ``~/.podstack/artifacts``.
    """
    fallback_base = os.path.join(os.path.expanduser("~"), ".podstack", "artifacts")
    return os.path.join(os.environ.get("PODSTACK_ARTIFACT_DIR", fallback_base), run_id)
|
|
324
|
+
|
|
325
|
+
def log_artifact(self, local_path: str, artifact_path: str = None):
    """
    Log an artifact file for the active run.

    Copies the file into the run's local artifact directory
    (``self._get_artifact_dir(run_id)``, by default
    ``~/.podstack/artifacts/<run_id>/``). It then records the destination path
    as a ``_artifact.<artifact_path>`` param. The backend registry service does
    not expose an artifact-upload endpoint; artifact files live on the
    local / shared filesystem.

    Args:
        local_path: Local path to the file
        artifact_path: Optional relative path within the run's artifact
            directory. Defaults to the base name of ``local_path``.

    Raises:
        NoActiveRunError: If no run is active.
        ValueError: If ``artifact_path`` is absolute or resolves outside the
            run's artifact directory.
    """
    if not self._active_run:
        raise NoActiveRunError()

    artifact_path = artifact_path or os.path.basename(local_path)
    dest_dir = self._get_artifact_dir(self._active_run.id)

    # os.path.join discards dest_dir entirely for an absolute artifact_path,
    # and ".." segments can climb out of it. Either would write (and record)
    # a file outside this run's artifact store, so reject both.
    abs_root = os.path.abspath(dest_dir)
    rel = os.path.relpath(os.path.abspath(os.path.join(abs_root, artifact_path)), abs_root)
    if (
        os.path.isabs(artifact_path)
        or rel in (os.curdir, os.pardir)
        or rel.startswith(os.pardir + os.sep)
    ):
        raise ValueError(
            "artifact_path must be a relative path inside the run's artifact "
            f"directory, got {artifact_path!r}"
        )

    dest = os.path.normpath(os.path.join(dest_dir, artifact_path))
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    shutil.copy2(local_path, dest)

    # Record the artifact reference as a param
    self._request("POST", f"/runs/{self._active_run.id}/params", json={
        "params": {f"_artifact.{artifact_path}": dest}
    })
|
|
351
|
+
|
|
352
|
+
def set_tag(self, key: str, value: str):
    """
    Attach a tag to the active run.

    The backend has no dedicated tags endpoint, so the tag is stored as a
    ``_tag.<key>`` param (value converted with ``str()``) via
    ``POST /runs/:id/params``. The original value is also put into the local
    ``run.tags`` dict.

    Args:
        key: Tag key
        value: Tag value

    Raises:
        NoActiveRunError: If no run is active.
    """
    run = self._active_run
    if not run:
        raise NoActiveRunError()

    tag_as_param = {f"_tag.{key}": str(value)}
    self._request("POST", f"/runs/{run.id}/params", json={"params": tag_as_param})
    run.tags[key] = value
|
|
372
|
+
|
|
373
|
+
def get_run_metrics(self, run_id: str) -> List[Metric]:
    """Fetch every metric recorded for *run_id* (``GET /runs/<id>/metrics``)."""
    response = self._request("GET", f"/runs/{run_id}/metrics")
    return list(map(Metric.from_dict, response.get("metrics", [])))
|
|
377
|
+
|
|
378
|
+
def get_run_params(self, run_id: str) -> List[Param]:
    """Fetch every parameter recorded for *run_id* (``GET /runs/<id>/params``)."""
    response = self._request("GET", f"/runs/{run_id}/params")
    return list(map(Param.from_dict, response.get("params", [])))
|
|
382
|
+
|
|
383
|
+
# ==================== Model Registry Methods ====================
|
|
384
|
+
|
|
385
|
+
def register_model(
    self,
    name: str,
    run_id: str = None,
    description: str = None,
    tags: dict = None
) -> RegisteredModel:
    """
    Register a new model, optionally with a first version.

    ``POST /models`` only accepts *name*, *description* and *tags*;
    ``run_id`` belongs on a version. When ``run_id`` is given, a version linked
    to that run is therefore created right after the model. Its ``source`` is
    the run's local artifact directory if that directory exists, otherwise
    ``runs/<run_id>/artifacts/model``.

    Args:
        name: Model name
        run_id: Optional run ID — creates a version linked to this run
        description: Optional description (omitted when falsy)
        tags: Optional tags dict (omitted when falsy)

    Returns:
        RegisteredModel object
    """
    payload = {"name": name}
    for field, value in (("description", description), ("tags", tags)):
        if value:
            payload[field] = value

    response = self._request("POST", "/models", json=payload)
    model = RegisteredModel.from_dict(response.get("model", response), client=self)

    if not run_id:
        return model

    local_dir = self._get_artifact_dir(run_id)
    source = local_dir if os.path.isdir(local_dir) else f"runs/{run_id}/artifacts/model"
    self.create_model_version(model.id, run_id=run_id, source=source)
    return model
|
|
429
|
+
|
|
430
|
+
def get_model(self, name_or_id: str) -> RegisteredModel:
    """
    Get a model by ID or, failing that, by name.

    Args:
        name_or_id: Model name or ID

    Returns:
        RegisteredModel object

    Raises:
        ModelNotFoundError: If neither lookup finds the model.
        RegistryError: On timeouts/connection errors, or non-404 failures of
            the name lookup.
    """
    from urllib.parse import quote

    # safe="" keeps "/", "?" and "#" inside a single path segment.
    segment = quote(str(name_or_id), safe="")

    # Try by ID first.
    try:
        data = self._request("GET", f"/models/{segment}")
    except RegistryError as e:
        # Transport failures say nothing about whether this is a name or an
        # ID; retrying by name would only repeat the failure (and the wait).
        # Other errors fall through to the name lookup, because the backend
        # may reject a name used as an ID with a non-404 status
        # (NOTE(review): e.g. invalid ID format — confirm).
        if e.code in ("timeout", "connection_error"):
            raise
    else:
        return RegisteredModel.from_dict(data.get("model", data), client=self)

    # Try by name
    try:
        data = self._request("GET", f"/models/name/{segment}")
    except RegistryError as e:
        if e.code == "not_found":
            raise ModelNotFoundError(name_or_id) from e
        raise
    return RegisteredModel.from_dict(data.get("model", data), client=self)
|
|
457
|
+
|
|
458
|
+
def list_models(self, limit: int = 20, offset: int = 0) -> List[RegisteredModel]:
    """
    Return one page of registered models.

    Args:
        limit: Maximum number of models to return.
        offset: Number of models to skip (pagination).
    """
    page = {"limit": limit, "offset": offset}
    response = self._request("GET", "/models", params=page)
    models = response.get("models", [])
    return [RegisteredModel.from_dict(entry, client=self) for entry in models]
|
|
465
|
+
|
|
466
|
+
def delete_model(self, model_id: str):
    """Delete a registered model via ``DELETE /models/<id>``; returns nothing."""
    endpoint = f"/models/{model_id}"
    self._request("DELETE", endpoint)
|
|
469
|
+
|
|
470
|
+
def create_model_version(
    self,
    model_id: str,
    run_id: str = None,
    source: str = None,
    description: str = None
) -> ModelVersion:
    """
    Create a new version of a model.

    Only truthy optional fields are included in the request body.

    Args:
        model_id: Model ID
        run_id: Optional run ID to link
        source: Optional source path
        description: Optional description

    Returns:
        ModelVersion object
    """
    candidates = (("run_id", run_id), ("source", source), ("description", description))
    payload = {field: value for field, value in candidates if value}

    response = self._request("POST", f"/models/{model_id}/versions", json=payload)
    return ModelVersion.from_dict(response.get("version", response), client=self)
|
|
500
|
+
|
|
501
|
+
def get_model_version(self, model_id: str, version: int) -> ModelVersion:
    """
    Get a specific version of a model.

    Raises:
        ModelVersionNotFoundError: If the backend returns 404 for the version.
        RegistryError: For any other request failure.
    """
    try:
        data = self._request("GET", f"/models/{model_id}/versions/{version}")
    except RegistryError as e:
        if e.code == "not_found":
            # Chain so the underlying HTTP error stays visible in tracebacks.
            raise ModelVersionNotFoundError(model_id, version) from e
        raise
    return ModelVersion.from_dict(data.get("version", data), client=self)
|
|
511
|
+
|
|
512
|
+
def list_model_versions(
    self,
    model_id: str,
    limit: int = 20,
    offset: int = 0
) -> List[ModelVersion]:
    """
    Return one page of versions of a model.

    Args:
        model_id: Model ID.
        limit: Maximum number of versions to return.
        offset: Number of versions to skip (pagination).
    """
    page = {"limit": limit, "offset": offset}
    response = self._request("GET", f"/models/{model_id}/versions", params=page)
    versions = response.get("versions", [])
    return [ModelVersion.from_dict(entry, client=self) for entry in versions]
|
|
524
|
+
|
|
525
|
+
def set_model_stage(
    self,
    model_name: str,
    version: int,
    stage: str,
    comment: str = None
) -> ModelVersion:
    """
    Transition a model version, identified by model name, to a new stage.

    Args:
        model_name: Model name (or ID; resolved via :meth:`get_model`)
        version: Version number
        stage: Target stage (development, staging, production, archived)
        comment: Optional comment

    Returns:
        Updated ModelVersion

    Raises:
        InvalidStageError: If *stage* is not one of the valid stages.
    """
    # Validate before resolving the model so an invalid stage costs no HTTP call.
    if stage not in ("development", "staging", "production", "archived"):
        raise InvalidStageError(stage)

    model = self.get_model(model_name)
    # Same PUT /models/<id>/versions/<n>/stage call as transition_model_stage;
    # delegate rather than duplicate the request and response handling.
    return self.transition_model_stage(model.id, version, stage, comment=comment)
|
|
562
|
+
|
|
563
|
+
def transition_model_stage(
    self,
    model_id: str,
    version: int,
    stage: str,
    comment: str = None
) -> ModelVersion:
    """
    Move a model version, identified by model ID, to another stage.

    Sends ``PUT /models/<id>/versions/<n>/stage``. The comment is included
    only when truthy.

    Raises:
        InvalidStageError: If *stage* is not development, staging,
            production or archived.
    """
    allowed = ("development", "staging", "production", "archived")
    if stage not in allowed:
        raise InvalidStageError(stage)

    payload = {"stage": stage}
    if comment:
        payload["comment"] = comment

    endpoint = f"/models/{model_id}/versions/{version}/stage"
    response = self._request("PUT", endpoint, json=payload)
    return ModelVersion.from_dict(response.get("version", response), client=self)
|
|
586
|
+
|
|
587
|
+
def set_model_alias(
    self,
    model_name: str,
    alias: str,
    version: int
) -> ModelAlias:
    """
    Point an alias (e.g. "champion", "challenger") at a model version.

    Args:
        model_name: Model name (or ID; resolved via :meth:`get_model`)
        alias: Alias name
        version: Version number

    Returns:
        ModelAlias object
    """
    model_id = self.get_model(model_name).id
    payload = {"alias": alias, "version": version}
    response = self._request("POST", f"/models/{model_id}/aliases", json=payload)
    return ModelAlias.from_dict(response.get("alias", response))
|
|
611
|
+
|
|
612
|
+
def get_model_aliases(self, model_id: str) -> List[ModelAlias]:
    """Fetch every alias defined for a model (``GET /models/<id>/aliases``)."""
    response = self._request("GET", f"/models/{model_id}/aliases")
    return list(map(ModelAlias.from_dict, response.get("aliases", [])))
|
|
616
|
+
|
|
617
|
+
def delete_model_alias(self, model_id: str, alias: str):
    """Remove *alias* from a model via ``DELETE /models/<id>/aliases/<alias>``."""
    endpoint = f"/models/{model_id}/aliases/{alias}"
    self._request("DELETE", endpoint)
|
|
620
|
+
|
|
621
|
+
def get_model_by_alias(self, model_id: str, alias: str) -> ModelVersion:
    """Resolve *alias* to the model version it currently points at."""
    # NOTE(review): this uses the singular "/alias/" path while the other alias
    # methods use "/aliases"; confirm against the backend routes.
    response = self._request("GET", f"/models/{model_id}/alias/{alias}")
    return ModelVersion.from_dict(response.get("version", response), client=self)
|
|
626
|
+
|
|
627
|
+
def get_stage_transitions(
    self,
    model_id: str,
    version: int = None
) -> List[StageTransition]:
    """
    Fetch stage transition history for a model or one of its versions.

    Args:
        model_id: Model ID
        version: Optional version number. When falsy, the history for the
            whole model is returned.

    Returns:
        List of StageTransition objects
    """
    endpoint = (
        f"/models/{model_id}/versions/{version}/transitions"
        if version
        else f"/models/{model_id}/transitions"
    )
    response = self._request("GET", endpoint)
    return [StageTransition.from_dict(item) for item in response.get("transitions", [])]
|
|
649
|
+
|
|
650
|
+
# ==================== Model Artifact Methods ====================
|
|
651
|
+
|
|
652
|
+
def log_model(
    self,
    model: Any,
    artifact_path: str = "model",
    framework: str = None,
    metadata: Dict[str, str] = None
):
    """
    Serialize a model to the local artifact directory for the active run.

    The model is saved under ``<artifact dir of run>/<artifact_path>/``.
    Its metadata (framework, path) is recorded as ``_model.*`` run params so
    it can be retrieved later with :meth:`load_model`.

    Args:
        model: The model object to save (PyTorch, TensorFlow, sklearn, HuggingFace, etc.)
        artifact_path: Sub-path inside the artifact dir (default: "model")
        framework: Framework name. Auto-detected if not provided.
        metadata: Optional metadata dict stored as ``_model.<key>`` run params
            (values converted with ``str()``). Keys ``framework``,
            ``artifact_path`` and ``local_dir`` are reserved and ignored.

    Raises:
        NoActiveRunError: If no run is active.
        ModelSerializationError: If model serialization fails.
    """
    from .model_utils import save_model, detect_framework

    if not self._active_run:
        raise NoActiveRunError()

    if framework is None:
        framework = detect_framework(model)

    # Serialize model into the persistent artifact directory
    artifact_dir = self._get_artifact_dir(self._active_run.id)
    model_dir = os.path.join(artifact_dir, artifact_path)
    save_model(model, model_dir, framework)

    # User metadata goes in first so the reserved keys below always win.
    # load_model reads _model.framework and _model.artifact_path back, so a
    # metadata entry named "framework" or "artifact_path" must not clobber them.
    model_params = {
        f"_model.{key}": str(value) for key, value in (metadata or {}).items()
    }
    model_params.update({
        "_model.framework": framework,
        "_model.artifact_path": artifact_path,
        "_model.local_dir": model_dir,
    })

    # Record model metadata as params (backend supports POST /runs/:id/params)
    self.log_params(model_params)
|
|
700
|
+
|
|
701
|
+
def get_model_version_by_stage(
    self,
    model_id: str,
    stage: str
) -> ModelVersion:
    """
    Return the model version currently assigned to *stage*.

    Calls the dedicated ``GET /models/:id/stage/:stage`` endpoint; *stage* is
    passed through without client-side validation.

    Args:
        model_id: Model ID.
        stage: One of development, staging, production, archived.

    Returns:
        ModelVersion object.
    """
    endpoint = f"/models/{model_id}/stage/{stage}"
    response = self._request("GET", endpoint)
    return ModelVersion.from_dict(response.get("version", response), client=self)
|
|
721
|
+
|
|
722
|
+
def load_model(
    self,
    model_name: str,
    version: int = None,
    stage: str = None,
    framework: str = None
) -> Any:
    """
    Load a previously-saved model from the local artifact directory.

    The model version is resolved in this order:
        - by *stage* via ``GET /models/:id/stage/:stage``;
        - by explicit *version* number;
        - otherwise the model's latest version.
    Framework metadata is then read from the linked run's params, and the
    model is deserialized from the local artifact directory.

    Args:
        model_name: Registered model name.
        version: Version number to load. Mutually exclusive with *stage*
            (*stage* takes precedence if both are given).
        stage: Stage to load from (e.g. "production").
        framework: Framework name for deserialization.
            Read from run params if not provided; defaults to "pickle".

    Returns:
        The loaded model object.

    Raises:
        ModelNotFoundError: If the model, its stage assignment, or any
            version is not found.
        ArtifactNotFoundError: If the local artifact directory is missing.
        RegistryError: On timeouts or connection errors.
    """
    from .model_utils import load_model_from_path

    # Resolve model and version
    registered_model = self.get_model(model_name)

    if stage:
        try:
            model_version = self.get_model_version_by_stage(
                registered_model.id, stage
            )
        except RegistryError as e:
            # A transport failure is not "model not found"; let it surface as-is.
            if e.code in ("timeout", "connection_error"):
                raise
            raise ModelNotFoundError(f"{model_name} (stage={stage})") from e
    elif version is not None:
        model_version = self.get_model_version(registered_model.id, version)
    else:
        latest = registered_model.latest_version
        # Without this guard a model with no versions would request
        # ".../versions/None" and fail with a confusing error.
        if not latest:
            raise ModelNotFoundError(f"{model_name} (no versions registered)")
        model_version = self.get_model_version(registered_model.id, latest)

    # Read model metadata written by log_model from the run's params
    artifact_path = "model"
    if model_version.run_id:
        params_dict = {p.key: p.value for p in self.get_run_params(model_version.run_id)}
        if not framework:
            framework = params_dict.get("_model.framework", "pickle")
        artifact_path = params_dict.get("_model.artifact_path", "model")

    framework = framework or "pickle"

    # Artifacts live under the run's directory; versions without a run fall
    # back to a directory keyed by the version ID.
    owner_id = model_version.run_id or model_version.id
    model_dir = os.path.join(self._get_artifact_dir(owner_id), artifact_path)

    if not os.path.exists(model_dir):
        raise ArtifactNotFoundError(owner_id, artifact_path)

    # NOTE(review): with the default "pickle" framework, deserialization can
    # execute arbitrary code; only load artifacts from a directory you trust.
    return load_model_from_path(model_dir, framework)
|
|
798
|
+
|
|
799
|
+
def log_dataset(
    self,
    name: str,
    path: str = None,
    version: str = None,
    description: str = None,
    digest: str = None,
    num_rows: int = None,
    num_features: int = None
):
    """
    Record dataset metadata on the active run.

    Everything is stored as ``dataset.<field>`` run params through
    :meth:`log_params`.

    Args:
        name: Dataset name.
        path: Dataset path or URI (e.g., "s3://bucket/data").
        version: Dataset version string.
        description: Dataset description.
        digest: Hash/digest of the dataset for reproducibility.
        num_rows: Number of rows/samples in the dataset.
        num_features: Number of features/columns.

    Raises:
        NoActiveRunError: If no run is active.
    """
    if not self._active_run:
        raise NoActiveRunError()

    params = {"dataset.name": name}
    # Descriptive fields are recorded only when non-empty.
    for field, value in (
        ("path", path),
        ("version", version),
        ("description", description),
        ("digest", digest),
    ):
        if value:
            params[f"dataset.{field}"] = value
    # Counts are recorded whenever provided, including 0.
    for field, count in (("num_rows", num_rows), ("num_features", num_features)):
        if count is not None:
            params[f"dataset.{field}"] = str(count)

    self.log_params(params)
|
|
845
|
+
|
|
846
|
+
def compare_runs(
    self,
    run_ids: List[str],
    metric_keys: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Compare multiple runs side by side.

    Args:
        run_ids: Run IDs to compare. Any iterable of IDs is accepted (list,
            tuple, set, generator). A single ID passed as a plain string
            is treated as a one-element list.
        metric_keys: Optional metric keys to restrict the comparison to.
            This also accepts any iterable or a single string. It is left
            out of the request when ``None`` or empty.

    Returns:
        The response to ``POST /runs/compare`` as returned by
        ``_request``: structured comparison data including params and
        metrics for each run.
    """
    body: Dict[str, Any] = {"run_ids": _as_str_list(run_ids)}

    keys = _as_str_list(metric_keys) if metric_keys is not None else []
    if keys:
        body["metric_keys"] = keys

    return self._request("POST", "/runs/compare", json=body)


def _as_str_list(values) -> List[str]:
    """Normalise a string or an iterable of strings into a JSON-friendly list."""
    # A bare string is iterable, but splitting it into characters is never
    # what the caller meant.
    if isinstance(values, str):
        return [values]
    # The JSON encoder rejects sets and generators, so always send a list.
    return list(values)
|
|
867
|
+
|
|
868
|
+
def get_metric_history(
    self,
    run_id: str,
    metric_key: str
) -> List[Metric]:
    """
    Get the full history of a metric across all steps for a run.

    The metric key is percent-encoded as a single URL path segment. Keys
    with reserved characters (``/``, ``?``, ``#``, ``%``), such as the
    common hierarchical form ``"train/loss"``, would otherwise add route
    segments, start a query string, or be truncated as a URL fragment.

    Args:
        run_id: Run ID.
        metric_key: Metric key to retrieve history for.

    Returns:
        List of Metric objects built from the response's ``metrics`` list.
        The list is empty when the response has no ``metrics`` key. Order
        is as returned by the backend, which is expected to be by step.
    """
    from urllib.parse import quote

    # safe="" also escapes "/", keeping the key a single path segment.
    # NOTE(review): keys containing "/" still rely on the server routing on
    # the raw (encoded) path, or decoding the parameter after routing.
    # Confirm this against the backend router.
    encoded_key = quote(metric_key, safe="")
    data = self._request("GET", f"/runs/{run_id}/metrics/{encoded_key}/history")
    return [Metric.from_dict(m) for m in data.get("metrics", [])]
|
|
885
|
+
|
|
886
|
+
def download_artifact(
    self,
    run_id: str,
    artifact_path: str,
    local_path: str
) -> str:
    """
    Copy an artifact from the local artifact store to *local_path*.

    Artifacts are stored on the local / shared filesystem under
    ``~/.podstack/artifacts/<run_id>/``. The backend registry service
    does not expose an artifact-download endpoint, so this method reads
    from the same directory that :meth:`log_artifact` / :meth:`log_model`
    wrote to.

    The artifact keeps its relative layout: *artifact_path* is recreated
    beneath *local_path*, and any missing parent directories are created.
    A directory artifact is merged into an existing destination
    directory. A file artifact overwrites an existing file of the same
    name.

    Args:
        run_id: Run ID.
        artifact_path: Relative path within the run's artifact directory.
        local_path: Destination directory. It may be relative; the return
            value is absolute either way.

    Returns:
        Absolute path to the copied artifact, i.e.
        ``os.path.abspath(os.path.join(local_path, artifact_path))``.

    Raises:
        ArtifactNotFoundError: If the artifact is not on disk.
    """
    src = os.path.join(self._get_artifact_dir(run_id), artifact_path)

    if not os.path.exists(src):
        raise ArtifactNotFoundError(run_id, artifact_path)

    os.makedirs(local_path, exist_ok=True)
    dest = os.path.join(local_path, artifact_path)
    # artifact_path may be nested (e.g. "model/weights.bin"), so the
    # destination's parent must exist, not just local_path itself.
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.isdir(src):
        shutil.copytree(src, dest, dirs_exist_ok=True)
    else:
        shutil.copy2(src, dest)

    # Honour the documented contract. A relative local_path would otherwise
    # leak through and stop resolving once the caller changes directory.
    return os.path.abspath(dest)
|
|
927
|
+
|
|
928
|
+
def search_runs(
    self,
    experiment_id: str = None,
    status: str = None,
    max_results: int = 100,
    offset: int = 0
) -> List[Run]:
    """
    List runs, optionally narrowed by experiment and/or status.

    Issues ``GET /runs`` with ``limit`` / ``offset`` pagination. The
    *experiment_id* and *status* filters are added to the query string
    only when they are truthy.

    Args:
        experiment_id: Restrict results to this experiment ID.
        status: Restrict results to this status (e.g. running, completed,
            failed, cancelled).
        max_results: Page size, sent as ``limit`` (default 100).
        offset: Number of runs to skip, for pagination.

    Returns:
        Run objects built from the response's ``runs`` list, each bound to
        this client. The list is empty when the response has no ``runs``
        key.
    """
    query: Dict[str, Any] = {"limit": max_results, "offset": offset}

    optional_filters = (("experiment_id", experiment_id), ("status", status))
    for field, value in optional_filters:
        if value:
            query[field] = value

    response = self._request("GET", "/runs", params=query)
    run_dicts = response.get("runs", [])
    return [Run.from_dict(entry, client=self) for entry in run_dicts]
|