podstack 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,107 @@
1
+ """
2
+ Registry Exception Classes
3
+
4
+ Custom exceptions for the Podstack Registry SDK.
5
+ """
6
+
7
+ from ..exceptions import PodstackError
8
+
9
+
10
class RegistryError(PodstackError):
    """Root of the registry exception hierarchy.

    Forwards ``message`` and a ``code`` string (``"registry_error"`` unless
    overridden) to ``PodstackError``.
    """

    def __init__(self, message: str, code: str = "registry_error"):
        super().__init__(message, code=code)
15
+
16
+
17
class ExperimentNotFoundError(RegistryError):
    """Signals that no experiment exists for the given identifier.

    The identifier is kept on ``self.experiment_id``.
    """

    def __init__(self, experiment_id: str):
        super().__init__(
            f"Experiment '{experiment_id}' not found",
            code="experiment_not_found",
        )
        self.experiment_id = experiment_id
24
+
25
+
26
class RunNotFoundError(RegistryError):
    """Signals that no run exists for the given identifier.

    The identifier is kept on ``self.run_id``.
    """

    def __init__(self, run_id: str):
        super().__init__(
            f"Run '{run_id}' not found",
            code="run_not_found",
        )
        self.run_id = run_id
33
+
34
+
35
class ModelNotFoundError(RegistryError):
    """Signals that no model exists for the given identifier.

    The identifier is kept on ``self.model_id``.
    """

    def __init__(self, model_id: str):
        super().__init__(
            f"Model '{model_id}' not found",
            code="model_not_found",
        )
        self.model_id = model_id
42
+
43
+
44
class NoActiveRunError(RegistryError):
    """Signals an attempt to log while no run is active."""

    def __init__(self):
        super().__init__(
            "No active run. Call start_run() first.",
            code="no_active_run",
        )
50
+
51
+
52
class NoExperimentSetError(RegistryError):
    """Signals an attempt to create a run before an experiment was chosen."""

    def __init__(self):
        super().__init__(
            "No experiment set. Call set_experiment() first.",
            code="no_experiment_set",
        )
58
+
59
+
60
class InvalidStageError(RegistryError):
    """Signals that a stage name is not one of ``VALID_STAGES``.

    The rejected value is kept on ``self.stage``; the message lists every
    accepted stage name.
    """

    # Stage names listed in the error message, in display order.
    VALID_STAGES = ["development", "staging", "production", "archived"]

    def __init__(self, stage: str):
        detail = f"Invalid stage '{stage}'. Must be one of: {', '.join(self.VALID_STAGES)}"
        super().__init__(detail, code="invalid_stage")
        self.stage = stage
69
+
70
+
71
class ModelVersionNotFoundError(RegistryError):
    """Signals that a model has no version with the requested number.

    Both lookup keys are kept on the instance as ``model_name`` and
    ``version``.
    """

    def __init__(self, model_name: str, version: int):
        super().__init__(
            f"Version {version} of model '{model_name}' not found",
            code="model_version_not_found",
        )
        self.model_name = model_name
        self.version = version
79
+
80
+
81
class ArtifactNotFoundError(RegistryError):
    """Signals that a run does not contain the requested artifact.

    Both lookup keys are kept on the instance as ``run_id`` and
    ``artifact_path``.
    """

    def __init__(self, run_id: str, artifact_path: str):
        super().__init__(
            f"Artifact '{artifact_path}' not found in run '{run_id}'",
            code="artifact_not_found",
        )
        self.run_id = run_id
        self.artifact_path = artifact_path
89
+
90
+
91
class ModelSerializationError(RegistryError):
    """Signals a failure while serializing or deserializing a model.

    The caller supplies the full message; the code is always
    ``"model_serialization_error"``.
    """

    def __init__(self, message: str):
        super().__init__(message, code="model_serialization_error")
96
+
97
+
98
class FrameworkNotInstalledError(RegistryError):
    """Signals that an ML framework needed for the operation is missing.

    The message carries a ``pip install podstack[<framework>]`` hint and the
    framework name is kept on ``self.framework``.
    """

    def __init__(self, framework: str):
        install_hint = (
            f"Framework '{framework}' is not installed. "
            f"Install it with: pip install podstack[{framework}]"
        )
        super().__init__(install_hint, code="framework_not_installed")
        self.framework = framework
@@ -0,0 +1,227 @@
1
+ """
2
+ Experiment and Run Classes
3
+
4
+ Data classes for experiment tracking.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional, Dict, Any, List
9
+ from datetime import datetime
10
+
11
+
12
@dataclass
class Experiment:
    """An experiment record held by the registry.

    ``from_dict`` builds an instance from a plain mapping and ``to_dict``
    turns it back into one, rendering datetimes as ISO 8601 strings.
    """

    id: str
    project_id: str
    name: str
    description: Optional[str] = None
    artifact_location: Optional[str] = None
    status: str = "active"
    tags: Dict[str, str] = field(default_factory=dict)
    run_count: int = 0
    last_run_at: Optional[datetime] = None
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Experiment":
        """Build an Experiment from a mapping, filling gaps with defaults."""
        fetch = data.get
        return cls(
            id=fetch("id", ""),
            project_id=fetch("project_id", ""),
            name=fetch("name", ""),
            description=fetch("description"),
            artifact_location=fetch("artifact_location"),
            status=fetch("status", "active"),
            # A missing or falsy "tags" entry becomes a fresh empty dict.
            tags=fetch("tags") or {},
            run_count=fetch("run_count", 0),
            last_run_at=_parse_datetime(fetch("last_run_at")),
            created_by_user_id=fetch("created_by_user_id"),
            created_at=_parse_datetime(fetch("created_at")),
            updated_at=_parse_datetime(fetch("updated_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the experiment as a plain dict with ISO 8601 timestamps."""

        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return dict(
            id=self.id,
            project_id=self.project_id,
            name=self.name,
            description=self.description,
            artifact_location=self.artifact_location,
            status=self.status,
            tags=self.tags,
            run_count=self.run_count,
            last_run_at=_iso(self.last_run_at),
            created_by_user_id=self.created_by_user_id,
            created_at=_iso(self.created_at),
            updated_at=_iso(self.updated_at),
        )
63
+
64
+
65
@dataclass
class Metric:
    """A single logged metric value, with optional timestamp and step."""

    key: str
    value: float
    timestamp: Optional[int] = None
    step: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Metric":
        """Build a Metric from a mapping; key defaults to "" and value to 0.0."""
        attrs = {
            "key": data.get("key", ""),
            "value": data.get("value", 0.0),
            "timestamp": data.get("timestamp"),
            "step": data.get("step"),
        }
        return cls(**attrs)
82
+
83
+
84
@dataclass
class Param:
    """A single logged parameter as a key/value pair."""

    key: str
    value: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Param":
        """Build a Param from a mapping; both fields default to ""."""
        param_key = data.get("key", "")
        param_value = data.get("value", "")
        return cls(key=param_key, value=param_value)
97
+
98
+
99
@dataclass
class Run:
    """
    Represents a run in an experiment.

    Can be used as a context manager:
        with registry.start_run(name="training") as run:
            registry.log_params({"lr": 0.001})
            registry.log_metrics({"loss": 0.5})

    Leaving the ``with`` block ends the run through the attached client:
    ``"failed"`` when an exception escaped the block, ``"completed"``
    otherwise. The logging helpers forward to the client and do nothing when
    no client is attached.
    """

    id: str
    project_id: str
    experiment_id: str
    user_id: Optional[str] = None
    name: Optional[str] = None
    status: str = "running"
    source_type: Optional[str] = None
    source_name: Optional[str] = None
    artifact_uri: Optional[str] = None
    start_time: Optional[int] = None
    end_time: Optional[int] = None
    duration_ms: int = 0
    tags: Dict[str, str] = field(default_factory=dict)
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    # Reference to client for context manager
    _client: Any = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "Run":
        """Create a Run from a dict.

        Args:
            data: Mapping of run fields; missing keys fall back to the
                dataclass defaults.
            client: Optional client the run forwards its operations to.
        """
        return cls(
            id=data.get("id", ""),
            project_id=data.get("project_id", ""),
            experiment_id=data.get("experiment_id", ""),
            user_id=data.get("user_id"),
            name=data.get("name"),
            status=data.get("status", "running"),
            source_type=data.get("source_type"),
            source_name=data.get("source_name"),
            artifact_uri=data.get("artifact_uri"),
            start_time=data.get("start_time"),
            end_time=data.get("end_time"),
            duration_ms=data.get("duration_ms", 0),
            # A missing or falsy "tags" entry becomes a fresh empty dict.
            tags=data.get("tags") or {},
            created_at=_parse_datetime(data.get("created_at")),
            updated_at=_parse_datetime(data.get("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dict; the client reference is not included."""
        return {
            "id": self.id,
            "project_id": self.project_id,
            "experiment_id": self.experiment_id,
            "user_id": self.user_id,
            "name": self.name,
            "status": self.status,
            "source_type": self.source_type,
            "source_name": self.source_name,
            "artifact_uri": self.artifact_uri,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "duration_ms": self.duration_ms,
            "tags": self.tags,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    def __enter__(self) -> "Run":
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Exit context manager, ending the run.

        Returns:
            Always ``False`` so an exception raised inside the block
            propagates to the caller.
        """
        if self._client:
            # An exception escaping the block marks the run as failed.
            final_status = "failed" if exc_type is not None else "completed"
            self._client.end_run(status=final_status)
            # Keep the local record in step with what was reported, the same
            # way end() does; otherwise status would still read "running".
            self.status = final_status
        return False  # Don't suppress exceptions

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters for this run."""
        if self._client:
            self._client.log_params(params)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics for this run, optionally tagged with a step."""
        if self._client:
            self._client.log_metrics(metrics, step)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
        """Log an artifact for this run."""
        if self._client:
            self._client.log_artifact(local_path, artifact_path)

    def set_tag(self, key: str, value: str) -> None:
        """Set a tag on this run and mirror it in ``self.tags``."""
        if self._client:
            self._client.set_tag(key, value)
        self.tags[key] = value

    def end(self, status: str = "completed") -> None:
        """End this run and record ``status`` on the instance."""
        if self._client:
            self._client.end_run(status)
        self.status = status
211
+
212
+
213
def _parse_datetime(value: Any) -> Optional[datetime]:
    """Parse an ISO 8601 timestamp into a ``datetime``.

    Args:
        value: ``None``, a ``datetime`` (returned unchanged), or an ISO 8601
            string. A trailing ``Z`` is read as ``+00:00``.

    Returns:
        The parsed ``datetime``, or ``None`` when ``value`` is ``None``, is
        neither a string nor a ``datetime``, or cannot be parsed.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        return value
    if not isinstance(value, str):
        return None

    # Handle ISO format with Z suffix
    text = value[:-1] + "+00:00" if value.endswith("Z") else value
    try:
        return datetime.fromisoformat(text)
    except ValueError:
        pass

    # Before Python 3.11, fromisoformat() only accepts exactly 3 or 6
    # fractional-second digits, so e.g. nanosecond-precision timestamps were
    # silently dropped. Retry with the fraction padded/truncated to 6 digits.
    head, dot, rest = text.partition(".")
    digit_count = len(rest) - len(rest.lstrip("0123456789"))
    if not dot or digit_count == 0:
        return None
    fraction = rest[:digit_count][:6].ljust(6, "0")
    try:
        return datetime.fromisoformat(f"{head}.{fraction}{rest[digit_count:]}")
    except ValueError:
        return None
@@ -0,0 +1,273 @@
1
+ """
2
+ Model Registry Classes
3
+
4
+ Data classes for model registry.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional, Dict, Any, List
9
+ from datetime import datetime
10
+
11
+
12
@dataclass
class RegisteredModel:
    """A named model tracked by the registry.

    When built with a client, the helper methods forward to it; without one
    they raise ``RuntimeError``.
    """

    id: str
    project_id: str
    name: str
    description: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    version_count: int = 0
    latest_version: int = 0
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    # Loaded lazily
    _versions: Optional[List["ModelVersion"]] = field(default=None, repr=False)
    _aliases: Optional[List["ModelAlias"]] = field(default=None, repr=False)
    _client: Any = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "RegisteredModel":
        """Build a RegisteredModel from a mapping, optionally bound to ``client``."""
        fetch = data.get
        return cls(
            id=fetch("id", ""),
            project_id=fetch("project_id", ""),
            name=fetch("name", ""),
            description=fetch("description"),
            # A missing or falsy "tags" entry becomes a fresh empty dict.
            tags=fetch("tags") or {},
            version_count=fetch("version_count", 0),
            latest_version=fetch("latest_version", 0),
            created_by_user_id=fetch("created_by_user_id"),
            created_at=_parse_datetime(fetch("created_at")),
            updated_at=_parse_datetime(fetch("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the public fields as a plain dict with ISO 8601 timestamps."""

        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return dict(
            id=self.id,
            project_id=self.project_id,
            name=self.name,
            description=self.description,
            tags=self.tags,
            version_count=self.version_count,
            latest_version=self.latest_version,
            created_by_user_id=self.created_by_user_id,
            created_at=_iso(self.created_at),
            updated_at=_iso(self.updated_at),
        )

    def get_version(self, version: int) -> "ModelVersion":
        """Fetch one version of this model through the client."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        return self._client.get_model_version(self.id, version)

    def list_versions(self, limit: int = 20, offset: int = 0) -> List["ModelVersion"]:
        """Fetch a page of this model's versions through the client."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        return self._client.list_model_versions(self.id, limit, offset)

    def create_version(
        self,
        run_id: str = None,
        source: str = None,
        description: str = None
    ) -> "ModelVersion":
        """Register a new version of this model through the client."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        return self._client.create_model_version(
            self.id, run_id=run_id, source=source, description=description
        )

    def set_alias(self, alias: str, version: int) -> "ModelAlias":
        """Point ``alias`` at ``version`` of this model through the client."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        # NOTE(review): this call passes self.name while the sibling helpers
        # pass self.id; confirm set_model_alias really expects the name.
        return self._client.set_model_alias(self.name, alias, version)

    def get_aliases(self) -> List["ModelAlias"]:
        """Fetch every alias defined for this model through the client."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        return self._client.get_model_aliases(self.id)
100
+
101
+
102
@dataclass
class ModelVersion:
    """One numbered version of a registered model.

    When built with a client, ``transition_stage`` forwards to it; without
    one it raises ``RuntimeError``.
    """

    id: str
    model_id: str
    project_id: str
    version: int
    run_id: Optional[str] = None
    source: Optional[str] = None
    status: str = "ready"
    stage: str = "development"
    description: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    _client: Any = field(default=None, repr=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "ModelVersion":
        """Build a ModelVersion from a mapping, optionally bound to ``client``."""
        fetch = data.get
        return cls(
            id=fetch("id", ""),
            model_id=fetch("model_id", ""),
            project_id=fetch("project_id", ""),
            version=fetch("version", 0),
            run_id=fetch("run_id"),
            source=fetch("source"),
            status=fetch("status", "ready"),
            stage=fetch("stage", "development"),
            description=fetch("description"),
            # A missing or falsy "tags" entry becomes a fresh empty dict.
            tags=fetch("tags") or {},
            created_by_user_id=fetch("created_by_user_id"),
            created_at=_parse_datetime(fetch("created_at")),
            updated_at=_parse_datetime(fetch("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the public fields as a plain dict with ISO 8601 timestamps."""
        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None
        return dict(
            id=self.id,
            model_id=self.model_id,
            project_id=self.project_id,
            version=self.version,
            run_id=self.run_id,
            source=self.source,
            status=self.status,
            stage=self.stage,
            description=self.description,
            tags=self.tags,
            created_by_user_id=self.created_by_user_id,
            created_at=created,
            updated_at=updated,
        )

    def transition_stage(self, stage: str, comment: str = None) -> "ModelVersion":
        """Ask the client to move this version to ``stage``."""
        if not self._client:
            raise RuntimeError("ModelVersion not connected to client")
        return self._client.transition_model_stage(
            self.model_id, self.version, stage, comment
        )
167
+
168
+
169
@dataclass
class ModelAlias:
    """A named pointer from a model to one of its versions."""

    id: str
    model_id: str
    project_id: str
    alias: str
    model_version_id: str
    version: int
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelAlias":
        """Build a ModelAlias from a mapping, filling gaps with defaults."""
        fetch = data.get
        return cls(
            id=fetch("id", ""),
            model_id=fetch("model_id", ""),
            project_id=fetch("project_id", ""),
            alias=fetch("alias", ""),
            model_version_id=fetch("model_version_id", ""),
            version=fetch("version", 0),
            created_by_user_id=fetch("created_by_user_id"),
            created_at=_parse_datetime(fetch("created_at")),
            updated_at=_parse_datetime(fetch("updated_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the alias as a plain dict with ISO 8601 timestamps."""
        created = self.created_at.isoformat() if self.created_at else None
        updated = self.updated_at.isoformat() if self.updated_at else None
        return dict(
            id=self.id,
            model_id=self.model_id,
            project_id=self.project_id,
            alias=self.alias,
            model_version_id=self.model_version_id,
            version=self.version,
            created_by_user_id=self.created_by_user_id,
            created_at=created,
            updated_at=updated,
        )
211
+
212
+
213
@dataclass
class StageTransition:
    """One entry in a model version's stage-transition history."""

    id: str
    model_id: str
    model_version_id: str
    project_id: str
    version: int
    from_stage: Optional[str] = None
    to_stage: str = ""
    transitioned_by: Optional[str] = None
    comment: Optional[str] = None
    created_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StageTransition":
        """Build a StageTransition from a mapping, filling gaps with defaults."""
        fetch = data.get
        return cls(
            id=fetch("id", ""),
            model_id=fetch("model_id", ""),
            model_version_id=fetch("model_version_id", ""),
            project_id=fetch("project_id", ""),
            version=fetch("version", 0),
            from_stage=fetch("from_stage"),
            to_stage=fetch("to_stage", ""),
            transitioned_by=fetch("transitioned_by"),
            comment=fetch("comment"),
            created_at=_parse_datetime(fetch("created_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict with an ISO 8601 timestamp."""
        created = self.created_at.isoformat() if self.created_at else None
        return dict(
            id=self.id,
            model_id=self.model_id,
            model_version_id=self.model_version_id,
            project_id=self.project_id,
            version=self.version,
            from_stage=self.from_stage,
            to_stage=self.to_stage,
            transitioned_by=self.transitioned_by,
            comment=self.comment,
            created_at=created,
        )
258
+
259
+
260
def _parse_datetime(value: Any) -> Optional[datetime]:
    """Parse an ISO 8601 timestamp into a ``datetime``.

    Args:
        value: ``None``, a ``datetime`` (returned unchanged), or an ISO 8601
            string. A trailing ``Z`` is read as ``+00:00``.

    Returns:
        The parsed ``datetime``, or ``None`` when ``value`` is ``None``, is
        neither a string nor a ``datetime``, or cannot be parsed.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        return value
    if not isinstance(value, str):
        return None

    # Treat a trailing "Z" as an explicit UTC offset.
    text = value[:-1] + "+00:00" if value.endswith("Z") else value
    try:
        return datetime.fromisoformat(text)
    except ValueError:
        pass

    # Before Python 3.11, fromisoformat() only accepts exactly 3 or 6
    # fractional-second digits, so e.g. nanosecond-precision timestamps were
    # silently dropped. Retry with the fraction padded/truncated to 6 digits.
    head, dot, rest = text.partition(".")
    digit_count = len(rest) - len(rest.lstrip("0123456789"))
    if not dot or digit_count == 0:
        return None
    fraction = rest[:digit_count][:6].ljust(6, "0")
    try:
        return datetime.fromisoformat(f"{head}.{fraction}{rest[digit_count:]}")
    except ValueError:
        return None