podstack 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,402 @@
1
+ """
2
+ Podstack Registry - Experiment Tracking and Model Management
3
+
4
+ Usage:
5
+ from podstack import registry
6
+
7
+ # Initialize
8
+ registry.init(api_key="your-api-key", project_id="your-project-id")
9
+
10
+ # Set experiment
11
+ registry.set_experiment("my-experiment")
12
+
13
+ # Track a run
14
+ with registry.start_run(name="training-v1") as run:
15
+ registry.log_params({"learning_rate": 0.001, "batch_size": 32})
16
+ for epoch in range(10):
17
+ loss = train_epoch()
18
+ registry.log_metrics({"loss": loss}, step=epoch)
19
+
20
+ # Register a model
21
+ registry.register_model(
22
+ name="my-model",
23
+ run_id=run.id,
24
+ description="My trained model"
25
+ )
26
+ """
27
+
28
from typing import Optional

from .client import RegistryClient
from .experiment import Experiment, Run
from .model import RegisteredModel, ModelVersion, ModelAlias, StageTransition
from .exceptions import (
    RegistryError,
    ExperimentNotFoundError,
    RunNotFoundError,
    ModelNotFoundError,
    NoActiveRunError,
    NoExperimentSetError,
    ArtifactNotFoundError,
    ModelSerializationError,
    FrameworkNotInstalledError,
)
42
+
43
# Public API of the registry facade, assembled group by group. The final
# list (contents and order) is what ``from podstack.registry import *`` exports.
__all__ = [
    # Client class and explicit initialisation
    "RegistryClient",
    "init",
]

# Experiment and run tracking helpers
__all__ += [
    "set_experiment",
    "get_experiment",
    "list_experiments",
    "start_run",
    "end_run",
    "log_params",
    "log_metrics",
    "log_artifact",
    "set_tag",
]

# Model registry helpers
__all__ += [
    "register_model",
    "get_model",
    "list_models",
    "set_model_stage",
    "set_model_alias",
]

# Model serialization, dataset logging and run analysis helpers
__all__ += [
    "log_model",
    "load_model",
    "log_dataset",
    "compare_runs",
    "get_metric_history",
    "download_artifact",
    "search_runs",
]

# Re-exported data classes
__all__ += [
    "Experiment",
    "Run",
    "RegisteredModel",
    "ModelVersion",
    "ModelAlias",
    "StageTransition",
]

# Re-exported exception types
__all__ += [
    "RegistryError",
    "ExperimentNotFoundError",
    "RunNotFoundError",
    "ModelNotFoundError",
    "NoActiveRunError",
    "NoExperimentSetError",
    "ArtifactNotFoundError",
    "ModelSerializationError",
    "FrameworkNotInstalledError",
]
87
+
88
# Process-wide client shared by every module-level helper below. It stays
# None until init() runs or _get_client() lazily creates a default client,
# hence the Optional annotation (the previous `RegistryClient = None` hint
# was incorrect for its initial value).
_client: Optional[RegistryClient] = None
90
+
91
+
92
def init(
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    project_id: Optional[str] = None,
) -> RegistryClient:
    """
    Initialize (or re-initialize) the module-level registry client.

    All module-level helpers in this package obtain their client through
    ``_get_client()``, which returns the instance created here. Calling
    ``init`` again replaces the previously stored client.

    Args:
        api_url: Registry service URL. When None, defaults to the
            PODSTACK_REGISTRY_URL env var (resolved by RegistryClient).
        api_key: API key for authentication. When None, defaults to the
            PODSTACK_API_KEY env var.
        project_id: Project ID for multi-tenant isolation. When None,
            defaults to the PODSTACK_PROJECT_ID env var.

    Returns:
        The newly created RegistryClient, also stored as the global client.
    """
    global _client
    _client = RegistryClient(
        api_url=api_url,
        api_key=api_key,
        project_id=project_id,
    )
    return _client
115
+
116
+
117
def _get_client() -> RegistryClient:
    """
    Return the shared client, creating a default one on first use.

    A client created here is constructed with no arguments, so its settings
    come entirely from RegistryClient's own defaults (presumably the
    PODSTACK_* environment variables described in ``init``).
    """
    global _client
    if _client is not None:
        return _client
    # No explicit init() yet: fall back to a default-configured client.
    _client = RegistryClient()
    return _client
123
+
124
+
125
def set_experiment(name: str, description: Optional[str] = None) -> Experiment:
    """
    Set the active experiment, creating it if it doesn't exist.

    Args:
        name: Experiment name.
        description: Optional description for the experiment.

    Returns:
        Experiment object.
    """
    return _get_client().set_experiment(name, description)
137
+
138
+
139
def get_experiment(experiment_id: str) -> Experiment:
    """
    Fetch a single experiment by its ID.

    Args:
        experiment_id: Identifier of the experiment.

    Returns:
        The Experiment returned by the shared client.
    """
    client = _get_client()
    return client.get_experiment(experiment_id)
142
+
143
+
144
def list_experiments(limit: int = 20, offset: int = 0) -> list:
    """
    List experiments in the current project, one page at a time.

    Args:
        limit: Maximum number of experiments to return.
        offset: Number of experiments to skip before the page starts.

    Returns:
        List of experiments as returned by the shared client.
    """
    client = _get_client()
    return client.list_experiments(limit, offset)
147
+
148
+
149
def start_run(name: Optional[str] = None, tags: Optional[dict] = None) -> Run:
    """
    Start a new run in the active experiment.

    Use as a context manager:
        with registry.start_run(name="training") as run:
            registry.log_params({"lr": 0.001})
            registry.log_metrics({"loss": 0.5})

    Args:
        name: Optional run name.
        tags: Optional dict of tags to attach to the run.

    Returns:
        Run object (usable as a context manager).
    """
    return _get_client().start_run(name, tags)
166
+
167
+
168
def end_run(status: str = "completed"):
    """
    Finish the active run with the given final status.

    Args:
        status: Final run status, e.g. completed, failed or killed.
    """
    client = _get_client()
    client.end_run(status)
176
+
177
+
178
def log_params(params: dict):
    """
    Record parameters on the active run.

    Args:
        params: Mapping of parameter name to value.
    """
    client = _get_client()
    client.log_params(params)
186
+
187
+
188
def log_metrics(metrics: dict, step: Optional[int] = None) -> None:
    """
    Log metrics for the active run.

    Args:
        metrics: Dict of metric names to values.
        step: Optional step number (e.g. epoch) the values belong to.
    """
    _get_client().log_metrics(metrics, step)
197
+
198
+
199
def log_artifact(local_path: str, artifact_path: Optional[str] = None) -> None:
    """
    Log an artifact file for the active run.

    Args:
        local_path: Local path to the file to upload.
        artifact_path: Optional path within the artifact store.
    """
    _get_client().log_artifact(local_path, artifact_path)
208
+
209
+
210
def set_tag(key: str, value: str):
    """
    Attach a single tag to the active run.

    Args:
        key: Name of the tag.
        value: Value to store under that name.
    """
    client = _get_client()
    client.set_tag(key, value)
219
+
220
+
221
def register_model(
    name: str,
    run_id: Optional[str] = None,
    description: Optional[str] = None,
    tags: Optional[dict] = None,
) -> RegisteredModel:
    """
    Register a new model.

    Args:
        name: Model name.
        run_id: Optional run ID to link the model to.
        description: Optional description.
        tags: Optional dict of tags.

    Returns:
        RegisteredModel object.
    """
    return _get_client().register_model(name, run_id, description, tags)
240
+
241
+
242
def get_model(name_or_id: str) -> RegisteredModel:
    """
    Look up a registered model.

    Args:
        name_or_id: The model's name or its ID.

    Returns:
        The RegisteredModel returned by the shared client.
    """
    client = _get_client()
    return client.get_model(name_or_id)
245
+
246
+
247
def list_models(limit: int = 20, offset: int = 0) -> list:
    """
    List registered models in the current project, one page at a time.

    Args:
        limit: Maximum number of models to return.
        offset: Number of models to skip before the page starts.

    Returns:
        List of models as returned by the shared client.
    """
    client = _get_client()
    return client.list_models(limit, offset)
250
+
251
+
252
def set_model_stage(
    model_name: str,
    version: int,
    stage: str,
    comment: Optional[str] = None,
) -> ModelVersion:
    """
    Transition a model version to a new stage.

    Args:
        model_name: Model name.
        version: Version number.
        stage: Target stage (development, staging, production, archived).
        comment: Optional comment recorded with the transition.

    Returns:
        Updated ModelVersion.
    """
    return _get_client().set_model_stage(model_name, version, stage, comment)
271
+
272
+
273
def set_model_alias(model_name: str, alias: str, version: int) -> ModelAlias:
    """
    Point a named alias at a specific version of a model.

    Args:
        model_name: Name of the registered model.
        alias: Alias to assign, such as "champion" or "challenger".
        version: Version number the alias should refer to.

    Returns:
        The ModelAlias returned by the shared client.
    """
    client = _get_client()
    return client.set_model_alias(model_name, alias, version)
286
+
287
+
288
def log_model(
    model,
    artifact_path: str = "model",
    framework: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> None:
    """
    Serialize and upload a model as an artifact for the active run.

    Args:
        model: The model object to save.
        artifact_path: Path within the artifact store (default: "model").
        framework: Framework name. Auto-detected if not provided.
        metadata: Optional metadata dict.
    """
    _get_client().log_model(model, artifact_path, framework, metadata)
299
+
300
+
301
def load_model(
    model_name: str,
    version: Optional[int] = None,
    stage: Optional[str] = None,
    framework: Optional[str] = None,
):
    """
    Download and deserialize a model from the registry.

    Args:
        model_name: Registered model name.
        version: Version number to load.
        stage: Stage to load from (e.g., "production").
        framework: Framework name for deserialization.

    Returns:
        The loaded model object.

    NOTE(review): how the client resolves version vs. stage when both or
    neither are given is not visible here; confirm against RegistryClient.
    """
    return _get_client().load_model(model_name, version, stage, framework)
315
+
316
+
317
def log_dataset(
    name: str,
    path: Optional[str] = None,
    version: Optional[str] = None,
    description: Optional[str] = None,
    digest: Optional[str] = None,
    num_rows: Optional[int] = None,
    num_features: Optional[int] = None,
) -> None:
    """
    Log dataset metadata for the active run.

    Args:
        name: Dataset name.
        path: Dataset path or URI.
        version: Dataset version string.
        description: Dataset description.
        digest: Hash/digest of the dataset.
        num_rows: Number of rows/samples.
        num_features: Number of features/columns.
    """
    _get_client().log_dataset(
        name, path, version, description, digest, num_rows, num_features
    )
339
+
340
+
341
def compare_runs(run_ids: list, metric_keys: Optional[list] = None) -> dict:
    """
    Compare multiple runs side by side.

    Args:
        run_ids: List of run IDs to compare.
        metric_keys: Optional list of metric keys to include.

    Returns:
        Dict with comparison data.
    """
    return _get_client().compare_runs(run_ids, metric_keys)
353
+
354
+
355
def get_metric_history(run_id: str, metric_key: str) -> list:
    """
    Retrieve every logged value of one metric for a run, across all steps.

    Args:
        run_id: Identifier of the run.
        metric_key: Name of the metric.

    Returns:
        List of Metric objects as returned by the shared client.
    """
    client = _get_client()
    return client.get_metric_history(run_id, metric_key)
367
+
368
+
369
def download_artifact(run_id: str, artifact_path: str, local_path: str) -> str:
    """
    Fetch one artifact file belonging to a run onto the local filesystem.

    Args:
        run_id: Identifier of the run that owns the artifact.
        artifact_path: Location of the artifact within the run.
        local_path: Local directory to save into.

    Returns:
        Local path of the downloaded file.
    """
    client = _get_client()
    return client.download_artifact(run_id, artifact_path, local_path)
382
+
383
+
384
def search_runs(
    experiment_id: Optional[str] = None,
    status: Optional[str] = None,
    max_results: int = 100,
    offset: int = 0,
) -> list:
    """
    Search / list runs.

    Args:
        experiment_id: Filter by experiment ID.
        status: Filter by status (running, completed, failed, cancelled).
            NOTE(review): end_run documents "killed" rather than
            "cancelled" as a terminal status; confirm which value the
            service actually stores.
        max_results: Maximum number of results.
        offset: Pagination offset.

    Returns:
        List of matching Run objects.
    """
    return _get_client().search_runs(experiment_id, status, max_results, offset)