podstack 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,957 @@
1
+ """
2
+ Registry Client
3
+
4
+ Client for interacting with the Podstack Registry API.
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import tempfile
10
+ import shutil
11
+ from typing import Optional, Dict, Any, List
12
+ import requests
13
+
14
+ from .experiment import Experiment, Run, Metric, Param
15
+ from .model import RegisteredModel, ModelVersion, ModelAlias, StageTransition
16
+ from .exceptions import (
17
+ RegistryError,
18
+ ExperimentNotFoundError,
19
+ RunNotFoundError,
20
+ ModelNotFoundError,
21
+ NoActiveRunError,
22
+ NoExperimentSetError,
23
+ InvalidStageError,
24
+ ModelVersionNotFoundError,
25
+ ArtifactNotFoundError,
26
+ ModelSerializationError,
27
+ )
28
+ from ..exceptions import AuthenticationError
29
+
30
+
31
class RegistryClient:
    """
    Podstack Registry Client for experiment tracking and model management.

    Usage:
        client = RegistryClient(
            api_url="https://cloud.podstack.ai/registry",
            api_key="your-api-key",
            project_id="your-project-id"
        )

        # Set experiment and track runs
        client.set_experiment("my-experiment")
        with client.start_run(name="training") as run:
            client.log_params({"lr": 0.001})
            client.log_metrics({"loss": 0.5}, step=1)

        # Register and manage models
        model = client.register_model("my-model", run_id=run.id)
        client.set_model_stage("my-model", version=1, stage="production")

    Artifacts written by ``log_artifact`` / ``log_model`` live on the local
    filesystem under ``$PODSTACK_ARTIFACT_DIR/<run_id>/`` (default
    ``~/.podstack/artifacts/<run_id>/``); only references to them are sent
    to the registry API.
    """

    DEFAULT_API_URL = "https://cloud.podstack.ai/registry"

    # Stages accepted by the stage-transition endpoints.
    VALID_STAGES = ("development", "staging", "production", "archived")

    # RegistryError codes raised by _request when the server did not give a
    # usable answer about the requested resource (network failure, timeout,
    # undecodable body). Lookups that fall back to another strategy on a
    # failed request must re-raise these instead of treating them as
    # "resource not found", otherwise an outage is misreported.
    _TRANSPORT_ERROR_CODES = frozenset(
        {"timeout", "connection_error", "request_error", "invalid_response"}
    )

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        project_id: Optional[str] = None,
        timeout: int = 30
    ):
        """
        Initialize the registry client.

        Args:
            api_url: Registry API URL. Defaults to PODSTACK_REGISTRY_URL env var,
                then to DEFAULT_API_URL.
            api_key: API key (psk_* token). Defaults to PODSTACK_API_KEY env var.
            project_id: Project ID. Defaults to PODSTACK_PROJECT_ID env var.
            timeout: Request timeout in seconds.
        """
        self.api_url = api_url or os.getenv("PODSTACK_REGISTRY_URL", self.DEFAULT_API_URL)
        self.api_key = api_key or os.getenv("PODSTACK_API_KEY")
        self.project_id = project_id or os.getenv("PODSTACK_PROJECT_ID")
        self.timeout = timeout

        # Client-side state: the experiment selected by set_experiment and
        # the run opened by start_run (cleared by end_run).
        self._experiment_id: Optional[str] = None
        self._experiment: Optional["Experiment"] = None
        self._active_run: Optional["Run"] = None

    # ==================== HTTP helpers ====================

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with auth and project ID (when configured)."""
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        if self.project_id:
            headers["X-Project-ID"] = self.project_id
        return headers

    @staticmethod
    def _quote_path_segment(value: Any) -> str:
        """
        Percent-encode *value* so it is safe as ONE URL path segment.

        User-chosen strings (experiment/model names, metric keys, aliases)
        may contain ``/``, ``?``, ``#`` or spaces, which would otherwise
        change the route or truncate the URL.
        """
        from urllib.parse import quote

        return quote(str(value), safe="")

    @staticmethod
    def _decode_json(response: Any) -> Any:
        """
        Decode a response body as JSON.

        Returns:
            The decoded value, or ``None`` when the body is empty or is not
            valid JSON (e.g. an HTML error page from a proxy).
        """
        if not response.content:
            return None
        try:
            return response.json()
        except ValueError:
            # requests' JSONDecodeError subclasses ValueError in all versions.
            return None

    @classmethod
    def _is_transport_error(cls, error: Any) -> bool:
        """True if *error* is a RegistryError raised for a transport/decoding failure."""
        return getattr(error, "code", None) in cls._TRANSPORT_ERROR_CODES

    @classmethod
    def _validate_stage(cls, stage: str) -> None:
        """Raise InvalidStageError unless *stage* is one of VALID_STAGES."""
        if stage not in cls.VALID_STAGES:
            raise InvalidStageError(stage)

    def _request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        files: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Make an API request.

        Args:
            method: HTTP method
            endpoint: API endpoint path (appended to ``<api_url>/api/v1``)
            json: JSON body
            params: Query parameters
            files: Files to upload

        Returns:
            Response data as dict ({} for 204 / empty bodies)

        Raises:
            AuthenticationError: On HTTP 401.
            RegistryError: ``code="not_found"`` on 404; plain RegistryError on
                other 4xx/5xx; ``code="timeout"``, ``"connection_error"`` or
                ``"request_error"`` on transport failures;
                ``code="invalid_response"`` if a success body is not JSON.
        """
        url = f"{self.api_url}/api/v1{endpoint}"
        headers = self._get_headers()

        # Let requests set the multipart Content-Type (with boundary) itself.
        if files:
            headers.pop("Content-Type", None)

        try:
            response = requests.request(
                method=method,
                url=url,
                headers=headers,
                json=json,
                params=params,
                files=files,
                timeout=self.timeout
            )
        except requests.exceptions.Timeout as e:
            raise RegistryError("Request timed out", code="timeout") from e
        except requests.exceptions.ConnectionError as e:
            raise RegistryError(f"Connection error: {e}", code="connection_error") from e
        except requests.exceptions.RequestException as e:
            # Any other requests failure (invalid URL, redirect loop, broken
            # chunked body, ...) is surfaced as a RegistryError too.
            raise RegistryError(f"Request failed: {e}", code="request_error") from e

        status = response.status_code
        if status == 401:
            raise AuthenticationError("Invalid or expired API key")
        if status >= 400:
            # Error bodies are not guaranteed to be JSON objects.
            payload = self._decode_json(response)
            message = payload.get("error") if isinstance(payload, dict) else None
            if status == 404:
                raise RegistryError(message or "Not found", code="not_found")
            raise RegistryError(message or f"Request failed: {status}")

        if status == 204 or not response.content:
            return {}
        try:
            return response.json()
        except ValueError as e:
            raise RegistryError(
                f"Invalid JSON in response from {url}", code="invalid_response"
            ) from e

    # ==================== Experiment Methods ====================

    def set_experiment(self, name: str, description: Optional[str] = None) -> "Experiment":
        """
        Set the active experiment. Creates it if it doesn't exist.

        Args:
            name: Experiment name
            description: Optional description (used only when creating)

        Returns:
            Experiment object

        Raises:
            RegistryError: If the lookup fails for a transport reason (the
                experiment is NOT created in that case) or creation fails.
        """
        try:
            data = self._request(
                "GET", f"/experiments/name/{self._quote_path_segment(name)}"
            )
        except RegistryError as e:
            # A network/timeout failure says nothing about existence; creating
            # here could produce a duplicate, so propagate instead.
            if self._is_transport_error(e):
                raise
            data = self._request("POST", "/experiments", json={
                "name": name,
                "description": description or ""
            })
            # Create responses may wrap the object under "experiment".
            self._experiment = Experiment.from_dict(data.get("experiment", data))
        else:
            self._experiment = Experiment.from_dict(data)

        self._experiment_id = self._experiment.id
        return self._experiment

    def get_experiment(self, experiment_id: str) -> "Experiment":
        """
        Get an experiment by ID.

        Raises:
            ExperimentNotFoundError: If the API answers 404.
        """
        try:
            data = self._request("GET", f"/experiments/{experiment_id}")
        except RegistryError as e:
            if e.code == "not_found":
                raise ExperimentNotFoundError(experiment_id) from e
            raise
        return Experiment.from_dict(data)

    def list_experiments(self, limit: int = 20, offset: int = 0) -> List["Experiment"]:
        """List experiments in the current project (paginated)."""
        data = self._request("GET", "/experiments", params={
            "limit": limit,
            "offset": offset
        })
        return [Experiment.from_dict(e) for e in data.get("experiments", [])]

    def archive_experiment(self, experiment_id: str):
        """Archive an experiment (``DELETE /experiments/:id``)."""
        self._request("DELETE", f"/experiments/{experiment_id}")

    # ==================== Run Methods ====================

    def start_run(self, name: Optional[str] = None, tags: Optional[dict] = None) -> "Run":
        """
        Start a new run in the active experiment.

        Args:
            name: Optional run name (defaults to ``run-<unix seconds>``)
            tags: Optional tags dict

        Returns:
            Run object (can be used as context manager); it becomes the
            active run for log_* calls.

        Raises:
            NoExperimentSetError: If set_experiment has not been called.
        """
        if not self._experiment_id:
            raise NoExperimentSetError()

        data = self._request("POST", "/runs", json={
            "experiment_id": self._experiment_id,
            "name": name or f"run-{int(time.time())}",
            "tags": tags or {}
        })
        self._active_run = Run.from_dict(data.get("run", data), client=self)
        return self._active_run

    def get_run(self, run_id: str) -> "Run":
        """
        Get a run by ID.

        Raises:
            RunNotFoundError: If the API answers 404.
        """
        try:
            data = self._request("GET", f"/runs/{run_id}")
        except RegistryError as e:
            if e.code == "not_found":
                raise RunNotFoundError(run_id) from e
            raise
        return Run.from_dict(data, client=self)

    def list_runs(
        self,
        experiment_id: Optional[str] = None,
        status: Optional[str] = None,
        limit: int = 20,
        offset: int = 0
    ) -> List["Run"]:
        """
        List runs.

        Args:
            experiment_id: Filter by experiment ID
            status: Filter by status (running, completed, failed, killed)
            limit: Max results
            offset: Offset for pagination

        Returns:
            List of Run objects
        """
        params: Dict[str, Any] = {"limit": limit, "offset": offset}
        if experiment_id:
            params["experiment_id"] = experiment_id
        if status:
            params["status"] = status

        data = self._request("GET", "/runs", params=params)
        return [Run.from_dict(r, client=self) for r in data.get("runs", [])]

    def end_run(self, status: str = "completed"):
        """
        End the active run and clear it.

        Args:
            status: Run status (completed, failed, killed)

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        self._request("POST", f"/runs/{self._active_run.id}/end", json={
            "status": status
        })
        self._active_run.status = status
        self._active_run = None

    def log_params(self, params: Dict[str, Any]):
        """
        Log parameters for the active run.

        Args:
            params: Dict of parameter names to values (values are sent as str)

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        # The API stores params as strings.
        str_params = {k: str(v) for k, v in params.items()}
        self._request("POST", f"/runs/{self._active_run.id}/params", json={
            "params": str_params
        })

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """
        Log metrics for the active run.

        Args:
            metrics: Dict of metric names to values (coerced to float)
            step: Optional step number (sent as 0 when omitted)

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        # One shared timestamp (epoch milliseconds) for the whole batch.
        timestamp = int(time.time() * 1000)
        metrics_list = [
            {"key": k, "value": float(v), "step": step or 0, "timestamp": timestamp}
            for k, v in metrics.items()
        ]
        self._request("POST", f"/runs/{self._active_run.id}/metrics", json={
            "metrics": metrics_list
        })

    @staticmethod
    def _get_artifact_dir(run_id: str) -> str:
        """Return the local artifact directory for a run."""
        base = os.getenv(
            "PODSTACK_ARTIFACT_DIR",
            os.path.join(os.path.expanduser("~"), ".podstack", "artifacts"),
        )
        return os.path.join(base, run_id)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None):
        """
        Log an artifact file or directory for the active run.

        Copies *local_path* into the local artifact directory
        (``~/.podstack/artifacts/<run_id>/``) and records the destination as
        a ``_artifact.<artifact_path>`` param. The backend registry service
        does not expose an artifact-upload endpoint; artifact files live on
        the local / shared filesystem.

        Args:
            local_path: Local path to a file or a directory (directories are
                copied recursively and merged into an existing destination).
            artifact_path: Optional relative path within the artifact store
                (defaults to the last component of *local_path*).

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        # normpath drops a trailing separator so "ckpt/" maps to "ckpt", not "".
        artifact_path = artifact_path or os.path.basename(os.path.normpath(local_path))
        run_dir = self._get_artifact_dir(self._active_run.id)
        dest = os.path.join(run_dir, artifact_path)
        os.makedirs(os.path.dirname(dest), exist_ok=True)

        if os.path.isdir(local_path):
            shutil.copytree(local_path, dest, dirs_exist_ok=True)
        else:
            shutil.copy2(local_path, dest)

        # Record the artifact reference as a param
        self._request("POST", f"/runs/{self._active_run.id}/params", json={
            "params": {f"_artifact.{artifact_path}": dest}
        })

    def set_tag(self, key: str, value: str):
        """
        Set a tag on the active run.

        Tags are persisted as params with a ``_tag.`` prefix via
        ``POST /runs/:id/params`` (the backend has no dedicated tags endpoint).
        The in-memory ``run.tags`` dict is also updated for local access.

        Args:
            key: Tag key
            value: Tag value (sent to the API as str)

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        # Backend has no /runs/:id/tags endpoint; persist via params
        self._request("POST", f"/runs/{self._active_run.id}/params", json={
            "params": {f"_tag.{key}": str(value)}
        })

        tags = getattr(self._active_run, "tags", None)
        if tags is None:
            # A run deserialized without tags must not break set_tag.
            tags = {}
            self._active_run.tags = tags
        tags[key] = value

    def get_run_metrics(self, run_id: str) -> List["Metric"]:
        """Get all metrics for a run."""
        data = self._request("GET", f"/runs/{run_id}/metrics")
        return [Metric.from_dict(m) for m in data.get("metrics", [])]

    def get_run_params(self, run_id: str) -> List["Param"]:
        """Get all parameters for a run."""
        data = self._request("GET", f"/runs/{run_id}/params")
        return [Param.from_dict(p) for p in data.get("params", [])]

    # ==================== Model Registry Methods ====================

    def register_model(
        self,
        name: str,
        run_id: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[dict] = None
    ) -> "RegisteredModel":
        """
        Register a new model.

        If ``run_id`` is provided a first model version linked to that run is
        created automatically (the backend ``POST /models`` only accepts
        *name*, *description* and *tags* — ``run_id`` belongs on the version).
        The version's source is the run's local artifact directory when it
        exists, otherwise ``runs/<run_id>/artifacts/model``.

        Args:
            name: Model name
            run_id: Optional run ID — creates version 1 linked to this run
            description: Optional description
            tags: Optional tags dict

        Returns:
            RegisteredModel object
        """
        body: Dict[str, Any] = {"name": name}
        if description:
            body["description"] = description
        if tags:
            body["tags"] = tags

        data = self._request("POST", "/models", json=body)
        model = RegisteredModel.from_dict(data.get("model", data), client=self)

        if run_id:
            artifact_dir = self._get_artifact_dir(run_id)
            source = artifact_dir if os.path.isdir(artifact_dir) else f"runs/{run_id}/artifacts/model"
            self.create_model_version(model.id, run_id=run_id, source=source)

        return model

    def get_model(self, name_or_id: str) -> "RegisteredModel":
        """
        Get a model by ID, falling back to lookup by name.

        Args:
            name_or_id: Model name or ID

        Returns:
            RegisteredModel object

        Raises:
            ModelNotFoundError: If neither lookup finds the model.
            RegistryError: On transport failures (no name fallback is tried).
        """
        segment = self._quote_path_segment(name_or_id)

        # Try by ID first. Any HTTP-level failure (e.g. 404, or 400 for a
        # value that is not a valid ID) falls through to the name lookup;
        # transport failures do not, since retrying by name cannot help.
        try:
            data = self._request("GET", f"/models/{segment}")
        except RegistryError as e:
            if self._is_transport_error(e):
                raise
        else:
            return RegisteredModel.from_dict(data.get("model", data), client=self)

        try:
            data = self._request("GET", f"/models/name/{segment}")
        except RegistryError as e:
            if e.code == "not_found":
                raise ModelNotFoundError(name_or_id) from e
            raise
        return RegisteredModel.from_dict(data.get("model", data), client=self)

    def list_models(self, limit: int = 20, offset: int = 0) -> List["RegisteredModel"]:
        """List registered models (paginated)."""
        data = self._request("GET", "/models", params={
            "limit": limit,
            "offset": offset
        })
        return [RegisteredModel.from_dict(m, client=self) for m in data.get("models", [])]

    def delete_model(self, model_id: str):
        """Delete a model."""
        self._request("DELETE", f"/models/{model_id}")

    def create_model_version(
        self,
        model_id: str,
        run_id: Optional[str] = None,
        source: Optional[str] = None,
        description: Optional[str] = None
    ) -> "ModelVersion":
        """
        Create a new version of a model.

        Args:
            model_id: Model ID
            run_id: Optional run ID to link
            source: Optional source path
            description: Optional description

        Returns:
            ModelVersion object
        """
        body: Dict[str, Any] = {}
        if run_id:
            body["run_id"] = run_id
        if source:
            body["source"] = source
        if description:
            body["description"] = description

        data = self._request("POST", f"/models/{model_id}/versions", json=body)
        return ModelVersion.from_dict(data.get("version", data), client=self)

    def get_model_version(self, model_id: str, version: int) -> "ModelVersion":
        """
        Get a specific version of a model.

        Raises:
            ModelVersionNotFoundError: If the API answers 404.
        """
        try:
            data = self._request("GET", f"/models/{model_id}/versions/{version}")
        except RegistryError as e:
            if e.code == "not_found":
                raise ModelVersionNotFoundError(model_id, version) from e
            raise
        return ModelVersion.from_dict(data.get("version", data), client=self)

    def list_model_versions(
        self,
        model_id: str,
        limit: int = 20,
        offset: int = 0
    ) -> List["ModelVersion"]:
        """List versions of a model (paginated)."""
        data = self._request("GET", f"/models/{model_id}/versions", params={
            "limit": limit,
            "offset": offset
        })
        return [ModelVersion.from_dict(v, client=self) for v in data.get("versions", [])]

    def set_model_stage(
        self,
        model_name: str,
        version: int,
        stage: str,
        comment: Optional[str] = None
    ) -> "ModelVersion":
        """
        Transition a model version (looked up by model name) to a new stage.

        Args:
            model_name: Model name (an ID is also accepted, see get_model)
            version: Version number
            stage: Target stage (development, staging, production, archived)
            comment: Optional comment

        Returns:
            Updated ModelVersion

        Raises:
            InvalidStageError: If *stage* is not valid (checked before any request).
        """
        self._validate_stage(stage)
        model = self.get_model(model_name)
        return self.transition_model_stage(model.id, version, stage, comment)

    def transition_model_stage(
        self,
        model_id: str,
        version: int,
        stage: str,
        comment: Optional[str] = None
    ) -> "ModelVersion":
        """
        Transition a model version stage by model ID.

        Raises:
            InvalidStageError: If *stage* is not one of VALID_STAGES.
        """
        self._validate_stage(stage)

        body = {"stage": stage}
        if comment:
            body["comment"] = comment

        data = self._request(
            "PUT",
            f"/models/{model_id}/versions/{version}/stage",
            json=body
        )
        return ModelVersion.from_dict(data.get("version", data), client=self)

    def set_model_alias(
        self,
        model_name: str,
        alias: str,
        version: int
    ) -> "ModelAlias":
        """
        Set an alias for a model version.

        Args:
            model_name: Model name
            alias: Alias name (e.g., "champion", "challenger")
            version: Version number

        Returns:
            ModelAlias object
        """
        model = self.get_model(model_name)
        data = self._request("POST", f"/models/{model.id}/aliases", json={
            "alias": alias,
            "version": version
        })
        return ModelAlias.from_dict(data.get("alias", data))

    def get_model_aliases(self, model_id: str) -> List["ModelAlias"]:
        """Get all aliases for a model."""
        data = self._request("GET", f"/models/{model_id}/aliases")
        return [ModelAlias.from_dict(a) for a in data.get("aliases", [])]

    def delete_model_alias(self, model_id: str, alias: str):
        """Delete a model alias."""
        self._request("DELETE", f"/models/{model_id}/aliases/{self._quote_path_segment(alias)}")

    def get_model_by_alias(self, model_id: str, alias: str) -> "ModelVersion":
        """Get a model version by alias."""
        data = self._request("GET", f"/models/{model_id}/alias/{self._quote_path_segment(alias)}")
        return ModelVersion.from_dict(data.get("version", data), client=self)

    def get_stage_transitions(
        self,
        model_id: str,
        version: Optional[int] = None
    ) -> List["StageTransition"]:
        """
        Get stage transition history.

        Args:
            model_id: Model ID
            version: Optional version number to filter

        Returns:
            List of StageTransition objects
        """
        if version is not None:
            endpoint = f"/models/{model_id}/versions/{version}/transitions"
        else:
            endpoint = f"/models/{model_id}/transitions"

        data = self._request("GET", endpoint)
        return [StageTransition.from_dict(t) for t in data.get("transitions", [])]

    # ==================== Model Artifact Methods ====================

    def log_model(
        self,
        model: Any,
        artifact_path: str = "model",
        framework: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None
    ):
        """
        Serialize a model to the local artifact directory for the active run.

        The model is saved under ``~/.podstack/artifacts/<run_id>/<artifact_path>/``
        and its metadata (framework, path) is recorded as run params so it can
        be retrieved later with :meth:`load_model`.

        Args:
            model: The model object to save (PyTorch, TensorFlow, sklearn, HuggingFace, etc.)
            artifact_path: Sub-path inside the artifact dir (default: "model")
            framework: Framework name. Auto-detected if not provided.
            metadata: Optional metadata dict stored as ``_model.<key>`` run params.

        Raises:
            NoActiveRunError: If no run is active.
            ModelSerializationError: If model serialization fails.
        """
        from .model_utils import save_model, detect_framework

        if not self._active_run:
            raise NoActiveRunError()

        if framework is None:
            framework = detect_framework(model)

        # Serialize model into the persistent artifact directory
        model_dir = os.path.join(self._get_artifact_dir(self._active_run.id), artifact_path)
        save_model(model, model_dir, framework)

        # Record model metadata as params (backend supports POST /runs/:id/params)
        model_params = {
            "_model.framework": framework,
            "_model.artifact_path": artifact_path,
            "_model.local_dir": model_dir,
        }
        if metadata:
            model_params.update({f"_model.{k}": str(v) for k, v in metadata.items()})

        self.log_params(model_params)

    def get_model_version_by_stage(
        self,
        model_id: str,
        stage: str
    ) -> "ModelVersion":
        """
        Get the model version currently assigned to a stage.

        Uses the dedicated ``GET /models/:id/stage/:stage`` backend endpoint.

        Args:
            model_id: Model ID.
            stage: One of development, staging, production, archived.

        Returns:
            ModelVersion object.
        """
        data = self._request("GET", f"/models/{model_id}/stage/{self._quote_path_segment(stage)}")
        return ModelVersion.from_dict(data.get("version", data), client=self)

    def load_model(
        self,
        model_name: str,
        version: Optional[int] = None,
        stage: Optional[str] = None,
        framework: Optional[str] = None
    ) -> Any:
        """
        Load a previously-saved model from the local artifact directory.

        Resolves the model version (by stage via
        ``GET /models/:id/stage/:stage``, by explicit version number, or the
        latest version), reads framework metadata from the run's params, then
        deserializes the model from the local artifact directory.

        Args:
            model_name: Registered model name.
            version: Version number to load. Intended to be mutually exclusive
                with *stage*; if both are given, *stage* takes precedence.
            stage: Stage to load from (e.g. "production").
            framework: Framework name for deserialization.
                Read from run params if not provided (fallback: "pickle").

        Returns:
            The loaded model object.

        Raises:
            ModelNotFoundError: If the model, the stage assignment, or any
                version is not found.
            ModelVersionNotFoundError: If an explicit *version* is not found.
            ArtifactNotFoundError: If the local artifact directory is missing.
        """
        from .model_utils import load_model_from_path

        registered_model = self.get_model(model_name)

        if stage:
            try:
                model_version = self.get_model_version_by_stage(registered_model.id, stage)
            except RegistryError as e:
                if self._is_transport_error(e):
                    raise
                raise ModelNotFoundError(f"{model_name} (stage={stage})") from e
        elif version is not None:
            model_version = self.get_model_version(registered_model.id, version)
        else:
            latest = registered_model.latest_version
            if not latest:
                # Avoid requesting ".../versions/None" for a model with no versions.
                raise ModelNotFoundError(f"{model_name} (no versions registered)")
            model_version = self.get_model_version(registered_model.id, latest)

        # Read model metadata recorded by log_model from the run params.
        artifact_path = "model"
        if model_version.run_id:
            params_dict = {p.key: p.value for p in self.get_run_params(model_version.run_id)}
            if not framework:
                framework = params_dict.get("_model.framework", "pickle")
            artifact_path = params_dict.get("_model.artifact_path", "model")

        framework = framework or "pickle"

        # Artifacts are keyed by run ID; versions without a run fall back to
        # the version ID as the directory name.
        owner_id = model_version.run_id or model_version.id
        model_dir = os.path.join(self._get_artifact_dir(owner_id), artifact_path)

        if not os.path.exists(model_dir):
            raise ArtifactNotFoundError(owner_id, artifact_path)

        # NOTE(review): depending on the framework, deserialization may
        # unpickle data (arbitrary code execution); only load artifacts from
        # a trusted artifact directory.
        return load_model_from_path(model_dir, framework)

    def log_dataset(
        self,
        name: str,
        path: Optional[str] = None,
        version: Optional[str] = None,
        description: Optional[str] = None,
        digest: Optional[str] = None,
        num_rows: Optional[int] = None,
        num_features: Optional[int] = None
    ):
        """
        Log dataset metadata for the active run.

        All metadata is stored as run params via ``POST /runs/:id/params``
        using a ``dataset.`` prefix for easy retrieval. Empty strings are
        skipped; ``num_rows`` / ``num_features`` are recorded even when 0.

        Args:
            name: Dataset name.
            path: Dataset path or URI (e.g., "s3://bucket/data").
            version: Dataset version string.
            description: Dataset description.
            digest: Hash/digest of the dataset for reproducibility.
            num_rows: Number of rows/samples in the dataset.
            num_features: Number of features/columns.

        Raises:
            NoActiveRunError: If no run is active.
        """
        if not self._active_run:
            raise NoActiveRunError()

        params = {"dataset.name": name}
        if path:
            params["dataset.path"] = path
        if version:
            params["dataset.version"] = version
        if description:
            params["dataset.description"] = description
        if digest:
            params["dataset.digest"] = digest
        if num_rows is not None:
            params["dataset.num_rows"] = str(num_rows)
        if num_features is not None:
            params["dataset.num_features"] = str(num_features)

        self.log_params(params)

    def compare_runs(
        self,
        run_ids: List[str],
        metric_keys: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Compare multiple runs side by side.

        Args:
            run_ids: List of run IDs to compare.
            metric_keys: Optional list of metric keys to include.

        Returns:
            The raw comparison payload returned by ``POST /runs/compare``.
        """
        body: Dict[str, Any] = {"run_ids": run_ids}
        if metric_keys:
            body["metric_keys"] = metric_keys

        return self._request("POST", "/runs/compare", json=body)

    def get_metric_history(
        self,
        run_id: str,
        metric_key: str
    ) -> List["Metric"]:
        """
        Get the full history of a metric across all steps for a run.

        Args:
            run_id: Run ID.
            metric_key: Metric key (may contain "/", e.g. "train/loss").

        Returns:
            List of Metric objects in the order returned by the API.
        """
        key = self._quote_path_segment(metric_key)
        data = self._request("GET", f"/runs/{run_id}/metrics/{key}/history")
        return [Metric.from_dict(m) for m in data.get("metrics", [])]

    def download_artifact(
        self,
        run_id: str,
        artifact_path: str,
        local_path: str
    ) -> str:
        """
        Copy an artifact from the local artifact store to *local_path*.

        Artifacts are stored on the local / shared filesystem under
        ``~/.podstack/artifacts/<run_id>/``. The backend registry service
        does not expose an artifact-download endpoint, so this method reads
        from the same directory that :meth:`log_artifact` / :meth:`log_model`
        wrote to.

        Args:
            run_id: Run ID.
            artifact_path: Relative path within the run's artifact directory.
            local_path: Destination directory.

        Returns:
            Path of the copied artifact (``local_path/artifact_path``).

        Raises:
            ArtifactNotFoundError: If the artifact is not on disk.
        """
        src = os.path.join(self._get_artifact_dir(run_id), artifact_path)

        if not os.path.exists(src):
            raise ArtifactNotFoundError(run_id, artifact_path)

        os.makedirs(local_path, exist_ok=True)
        dest = os.path.join(local_path, artifact_path)
        os.makedirs(os.path.dirname(dest), exist_ok=True)

        if os.path.isdir(src):
            shutil.copytree(src, dest, dirs_exist_ok=True)
        else:
            shutil.copy2(src, dest)

        return dest

    def search_runs(
        self,
        experiment_id: Optional[str] = None,
        status: Optional[str] = None,
        max_results: int = 100,
        offset: int = 0
    ) -> List["Run"]:
        """
        Search / list runs.

        Same query as :meth:`list_runs` (``GET /runs`` filtered by
        *experiment_id* and *status*) with a larger default page size.

        Args:
            experiment_id: Filter by experiment ID.
            status: Filter by status (running, completed, failed, cancelled).
            max_results: Maximum number of results (default 100).
            offset: Pagination offset.

        Returns:
            List of matching Run objects.
        """
        return self.list_runs(
            experiment_id=experiment_id,
            status=status,
            limit=max_results,
            offset=offset,
        )