podstack 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- podstack/__init__.py +222 -0
- podstack/annotations.py +725 -0
- podstack/client.py +322 -0
- podstack/exceptions.py +125 -0
- podstack/execution.py +291 -0
- podstack/gpu_runner.py +1141 -0
- podstack/models.py +274 -0
- podstack/notebook.py +410 -0
- podstack/registry/__init__.py +402 -0
- podstack/registry/client.py +957 -0
- podstack/registry/exceptions.py +107 -0
- podstack/registry/experiment.py +227 -0
- podstack/registry/model.py +273 -0
- podstack/registry/model_utils.py +231 -0
- podstack-1.2.0.dist-info/METADATA +299 -0
- podstack-1.2.0.dist-info/RECORD +27 -0
- podstack-1.2.0.dist-info/WHEEL +5 -0
- podstack-1.2.0.dist-info/licenses/LICENSE +21 -0
- podstack-1.2.0.dist-info/top_level.txt +2 -0
- podstack_gpu/__init__.py +126 -0
- podstack_gpu/app.py +675 -0
- podstack_gpu/exceptions.py +35 -0
- podstack_gpu/image.py +325 -0
- podstack_gpu/runner.py +746 -0
- podstack_gpu/secret.py +189 -0
- podstack_gpu/utils.py +203 -0
- podstack_gpu/volume.py +198 -0
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Registry Exception Classes
|
|
3
|
+
|
|
4
|
+
Custom exceptions for the Podstack Registry SDK.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ..exceptions import PodstackError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RegistryError(PodstackError):
    """Root of the registry exception hierarchy.

    Args:
        message: Human-readable description of the failure.
        code: Machine-readable error code handed to ``PodstackError``;
            subclasses pass their own, otherwise ``"registry_error"``.
    """

    def __init__(self, message: str, code: str = "registry_error"):
        super().__init__(
            message,
            code=code,
        )
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ExperimentNotFoundError(RegistryError):
    """Raised when no experiment matches the requested identifier.

    Attributes:
        experiment_id: The identifier that could not be resolved.
    """

    def __init__(self, experiment_id: str):
        super().__init__(
            f"Experiment '{experiment_id}' not found",
            code="experiment_not_found",
        )
        self.experiment_id = experiment_id
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RunNotFoundError(RegistryError):
    """Raised when no run matches the requested identifier.

    Attributes:
        run_id: The identifier that could not be resolved.
    """

    def __init__(self, run_id: str):
        super().__init__(
            f"Run '{run_id}' not found",
            code="run_not_found",
        )
        self.run_id = run_id
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ModelNotFoundError(RegistryError):
    """Raised when no registered model matches the requested identifier.

    Attributes:
        model_id: The identifier that could not be resolved.
    """

    def __init__(self, model_id: str):
        super().__init__(
            f"Model '{model_id}' not found",
            code="model_not_found",
        )
        self.model_id = model_id
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class NoActiveRunError(RegistryError):
    """Raised when a logging call is made while no run is active.

    The message tells the caller to call ``start_run()`` first.
    """

    def __init__(self):
        super().__init__(
            "No active run. Call start_run() first.",
            code="no_active_run",
        )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class NoExperimentSetError(RegistryError):
    """Raised when a run is requested before an experiment has been chosen.

    The message tells the caller to call ``set_experiment()`` first.
    """

    def __init__(self):
        super().__init__(
            "No experiment set. Call set_experiment() first.",
            code="no_experiment_set",
        )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class InvalidStageError(RegistryError):
    """Raised when a stage name is not one of ``VALID_STAGES``.

    Attributes:
        stage: The rejected stage name.
    """

    # Stage names listed in the error message (read through ``self`` so a
    # subclass may override the list).
    VALID_STAGES = ["development", "staging", "production", "archived"]

    def __init__(self, stage: str):
        allowed = ", ".join(self.VALID_STAGES)
        super().__init__(
            f"Invalid stage '{stage}'. Must be one of: {allowed}",
            code="invalid_stage",
        )
        self.stage = stage
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ModelVersionNotFoundError(RegistryError):
    """Raised when a particular version of a named model does not exist.

    Attributes:
        model_name: Name of the model that was queried.
        version: The version number that could not be found.
    """

    def __init__(self, model_name: str, version: int):
        super().__init__(
            f"Version {version} of model '{model_name}' not found",
            code="model_version_not_found",
        )
        self.model_name = model_name
        self.version = version
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ArtifactNotFoundError(RegistryError):
    """Raised when an artifact path is missing from a run.

    Attributes:
        run_id: The run that was searched.
        artifact_path: The artifact path that was not found in it.
    """

    def __init__(self, run_id: str, artifact_path: str):
        super().__init__(
            f"Artifact '{artifact_path}' not found in run '{run_id}'",
            code="artifact_not_found",
        )
        self.run_id = run_id
        self.artifact_path = artifact_path
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class ModelSerializationError(RegistryError):
    """Raised when saving (serializing) or loading (deserializing) a model fails.

    Args:
        message: Caller-supplied description of what went wrong; forwarded
            unchanged with code ``"model_serialization_error"``.
    """

    def __init__(self, message: str):
        super().__init__(
            message,
            code="model_serialization_error",
        )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class FrameworkNotInstalledError(RegistryError):
    """Raised when an ML framework needed for an operation cannot be imported.

    The message includes a ``pip install podstack[<framework>]`` hint.

    Attributes:
        framework: Name of the missing framework (also used as the extra name
            in the install hint).
    """

    def __init__(self, framework: str):
        install_hint = f"Install it with: pip install podstack[{framework}]"
        super().__init__(
            f"Framework '{framework}' is not installed. " + install_hint,
            code="framework_not_installed",
        )
        self.framework = framework
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Experiment and Run Classes
|
|
3
|
+
|
|
4
|
+
Data classes for experiment tracking.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Optional, Dict, Any, List
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Experiment:
    """An experiment record from the registry.

    Build one from an API payload with ``from_dict``. Convert it back with
    ``to_dict``, which renders datetime fields as ISO-8601 strings.
    """

    id: str
    project_id: str
    name: str
    description: Optional[str] = None
    artifact_location: Optional[str] = None
    status: str = "active"
    tags: Dict[str, str] = field(default_factory=dict)
    run_count: int = 0
    last_run_at: Optional[datetime] = None
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Experiment":
        """Create an Experiment from a payload dict.

        Defaults for absent keys:
        - ``id``/``project_id``/``name`` become ``""``.
        - ``status`` becomes ``"active"``.
        - A missing or empty ``tags`` becomes ``{}``.
        - ``run_count`` becomes ``0``.

        Timestamp fields are converted with ``_parse_datetime``.
        """
        get = data.get
        return cls(
            id=get("id", ""),
            project_id=get("project_id", ""),
            name=get("name", ""),
            description=get("description"),
            artifact_location=get("artifact_location"),
            status=get("status", "active"),
            tags=get("tags") or {},
            run_count=get("run_count", 0),
            last_run_at=_parse_datetime(get("last_run_at")),
            created_by_user_id=get("created_by_user_id"),
            created_at=_parse_datetime(get("created_at")),
            updated_at=_parse_datetime(get("updated_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; unset datetimes become ``None``."""

        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "project_id": self.project_id,
            "name": self.name,
            "description": self.description,
            "artifact_location": self.artifact_location,
            "status": self.status,
            "tags": self.tags,
            "run_count": self.run_count,
            "last_run_at": iso(self.last_run_at),
            "created_by_user_id": self.created_by_user_id,
            "created_at": iso(self.created_at),
            "updated_at": iso(self.updated_at),
        }
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class Metric:
    """A single metric value logged to a run.

    Attributes:
        key: Metric name.
        value: Logged value.
        timestamp: Optional integer timestamp as received from the payload
            (unit not shown here — TODO confirm).
        step: Optional step index the value belongs to.
    """

    key: str
    value: float
    timestamp: Optional[int] = None
    step: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Metric":
        """Build a Metric from a payload dict (key defaults to ``""``, value to ``0.0``)."""
        key = data.get("key", "")
        value = data.get("value", 0.0)
        return cls(key, value, timestamp=data.get("timestamp"), step=data.get("step"))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class Param:
    """A single parameter logged to a run, as a key/value string pair."""

    key: str
    value: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Param":
        """Build a Param from a payload dict; absent key/value become ``""``."""
        key = data.get("key", "")
        value = data.get("value", "")
        return cls(key, value)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class Run:
    """
    Represents a run in an experiment.

    Can be used as a context manager:
        with registry.start_run(name="training") as run:
            registry.log_params({"lr": 0.001})
            registry.log_metrics({"loss": 0.5})

    On leaving the ``with`` block the run is ended via ``end()``: with status
    ``"failed"`` if the block raised, ``"completed"`` otherwise. The local
    ``status`` field is updated to match. Exceptions are never suppressed.

    The logging helpers forward to the attached client (``_client``). They
    silently do nothing when no client is attached.
    """

    id: str
    project_id: str
    experiment_id: str
    user_id: Optional[str] = None
    name: Optional[str] = None
    status: str = "running"
    source_type: Optional[str] = None
    source_name: Optional[str] = None
    artifact_uri: Optional[str] = None
    start_time: Optional[int] = None
    end_time: Optional[int] = None
    duration_ms: int = 0
    tags: Dict[str, str] = field(default_factory=dict)
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    # Reference to the registry client that logging/ending calls are
    # forwarded to. Excluded from equality so two Run objects with the same
    # data compare equal regardless of which client produced them.
    _client: Any = field(default=None, repr=False, compare=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "Run":
        """Create a Run from a payload dict, optionally bound to ``client``."""
        return cls(
            id=data.get("id", ""),
            project_id=data.get("project_id", ""),
            experiment_id=data.get("experiment_id", ""),
            user_id=data.get("user_id"),
            name=data.get("name"),
            status=data.get("status", "running"),
            source_type=data.get("source_type"),
            source_name=data.get("source_name"),
            artifact_uri=data.get("artifact_uri"),
            start_time=data.get("start_time"),
            end_time=data.get("end_time"),
            duration_ms=data.get("duration_ms", 0),
            tags=data.get("tags") or {},
            created_at=_parse_datetime(data.get("created_at")),
            updated_at=_parse_datetime(data.get("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict; datetimes become ISO-8601 strings or None."""
        return {
            "id": self.id,
            "project_id": self.project_id,
            "experiment_id": self.experiment_id,
            "user_id": self.user_id,
            "name": self.name,
            "status": self.status,
            "source_type": self.source_type,
            "source_name": self.source_name,
            "artifact_uri": self.artifact_uri,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "duration_ms": self.duration_ms,
            "tags": self.tags,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    def __enter__(self) -> "Run":
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Exit context manager, ending the run.

        Delegates to ``end()`` so the local ``status`` is kept in sync with
        what was reported to the client.
        """
        self.end("failed" if exc_type is not None else "completed")
        return False  # Don't suppress exceptions

    def log_params(self, params: Dict[str, Any]) -> None:
        """Log parameters for this run via the attached client."""
        if self._client:
            self._client.log_params(params)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics (optionally at ``step``) via the attached client."""
        if self._client:
            self._client.log_metrics(metrics, step)

    def log_artifact(self, local_path: str, artifact_path: Optional[str] = None) -> None:
        """Log the file at ``local_path`` as an artifact via the attached client."""
        if self._client:
            self._client.log_artifact(local_path, artifact_path)

    def set_tag(self, key: str, value: str) -> None:
        """Set a tag on this run, on the client (if any) and locally."""
        if self._client:
            self._client.set_tag(key, value)
        self.tags[key] = value

    def end(self, status: str = "completed") -> None:
        """End this run with ``status`` and record the status locally."""
        if self._client:
            # Passed by keyword, matching how the context-manager path has
            # always invoked ``end_run``.
            self._client.end_run(status=status)
        self.status = status
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _parse_datetime(value: Any) -> Optional[datetime]:
|
|
214
|
+
"""Parse a datetime from string or return None."""
|
|
215
|
+
if value is None:
|
|
216
|
+
return None
|
|
217
|
+
if isinstance(value, datetime):
|
|
218
|
+
return value
|
|
219
|
+
if isinstance(value, str):
|
|
220
|
+
try:
|
|
221
|
+
# Handle ISO format with Z suffix
|
|
222
|
+
if value.endswith("Z"):
|
|
223
|
+
value = value[:-1] + "+00:00"
|
|
224
|
+
return datetime.fromisoformat(value)
|
|
225
|
+
except ValueError:
|
|
226
|
+
return None
|
|
227
|
+
return None
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Registry Classes
|
|
3
|
+
|
|
4
|
+
Data classes for model registry.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Optional, Dict, Any, List
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class RegisteredModel:
    """Represents a registered model in the registry.

    Instances created through a client carry a reference to it (``_client``)
    so the helper methods below can forward calls. A model without a client
    raises ``RuntimeError("Model not connected to client")`` from those
    helpers.
    """

    id: str
    project_id: str
    name: str
    description: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    version_count: int = 0
    latest_version: int = 0
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    # Reserved for lazily loaded data (not populated within this class) and
    # the client handle. Excluded from equality so two instances describing
    # the same model compare equal regardless of client or cache state.
    _versions: Optional[List["ModelVersion"]] = field(default=None, repr=False, compare=False)
    _aliases: Optional[List["ModelAlias"]] = field(default=None, repr=False, compare=False)
    _client: Any = field(default=None, repr=False, compare=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "RegisteredModel":
        """Create a RegisteredModel from a payload dict, optionally bound to ``client``."""
        return cls(
            id=data.get("id", ""),
            project_id=data.get("project_id", ""),
            name=data.get("name", ""),
            description=data.get("description"),
            tags=data.get("tags") or {},
            version_count=data.get("version_count", 0),
            latest_version=data.get("latest_version", 0),
            created_by_user_id=data.get("created_by_user_id"),
            created_at=_parse_datetime(data.get("created_at")),
            updated_at=_parse_datetime(data.get("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict; datetimes become ISO-8601 strings or None."""
        return {
            "id": self.id,
            "project_id": self.project_id,
            "name": self.name,
            "description": self.description,
            "tags": self.tags,
            "version_count": self.version_count,
            "latest_version": self.latest_version,
            "created_by_user_id": self.created_by_user_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    def _require_client(self) -> Any:
        """Return the attached client or raise RuntimeError if there is none."""
        if not self._client:
            raise RuntimeError("Model not connected to client")
        return self._client

    def get_version(self, version: int) -> "ModelVersion":
        """Get a specific version of this model."""
        return self._require_client().get_model_version(self.id, version)

    def list_versions(self, limit: int = 20, offset: int = 0) -> List["ModelVersion"]:
        """List versions of this model (paged by ``limit``/``offset``)."""
        return self._require_client().list_model_versions(self.id, limit, offset)

    def create_version(
        self,
        run_id: Optional[str] = None,
        source: Optional[str] = None,
        description: Optional[str] = None,
    ) -> "ModelVersion":
        """Create a new version of this model."""
        return self._require_client().create_model_version(
            self.id, run_id=run_id, source=source, description=description
        )

    def set_alias(self, alias: str, version: int) -> "ModelAlias":
        """Set an alias for a version of this model."""
        # NOTE(review): forwards the model *name*, whereas the other helpers
        # pass ``self.id`` — confirm against client.set_model_alias.
        return self._require_client().set_model_alias(self.name, alias, version)

    def get_aliases(self) -> List["ModelAlias"]:
        """Get all aliases for this model."""
        return self._require_client().get_model_aliases(self.id)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@dataclass
class ModelVersion:
    """Represents a version of a registered model.

    When created through a client, ``transition_stage`` forwards to it.
    Without a client it raises ``RuntimeError``.
    """

    id: str
    model_id: str
    project_id: str
    version: int
    run_id: Optional[str] = None
    source: Optional[str] = None
    status: str = "ready"
    stage: str = "development"
    description: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    # Client handle used by ``transition_stage``. Excluded from equality so
    # versions compare by their data only.
    _client: Any = field(default=None, repr=False, compare=False)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], client: Any = None) -> "ModelVersion":
        """Create a ModelVersion from a payload dict, optionally bound to ``client``."""
        return cls(
            id=data.get("id", ""),
            model_id=data.get("model_id", ""),
            project_id=data.get("project_id", ""),
            version=data.get("version", 0),
            run_id=data.get("run_id"),
            source=data.get("source"),
            status=data.get("status", "ready"),
            stage=data.get("stage", "development"),
            description=data.get("description"),
            tags=data.get("tags") or {},
            created_by_user_id=data.get("created_by_user_id"),
            created_at=_parse_datetime(data.get("created_at")),
            updated_at=_parse_datetime(data.get("updated_at")),
            _client=client,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict; datetimes become ISO-8601 strings or None."""
        return {
            "id": self.id,
            "model_id": self.model_id,
            "project_id": self.project_id,
            "version": self.version,
            "run_id": self.run_id,
            "source": self.source,
            "status": self.status,
            "stage": self.stage,
            "description": self.description,
            "tags": self.tags,
            "created_by_user_id": self.created_by_user_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
        }

    def transition_stage(self, stage: str, comment: Optional[str] = None) -> "ModelVersion":
        """Transition this version to ``stage`` via the client.

        Returns whatever the client's ``transition_model_stage`` returns. This
        object's own ``stage`` field is not modified here.

        Raises:
            RuntimeError: if no client is attached.
        """
        if self._client:
            return self._client.transition_model_stage(
                self.model_id, self.version, stage, comment
            )
        raise RuntimeError("ModelVersion not connected to client")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@dataclass
class ModelAlias:
    """A named alias (e.g. a human-friendly label) pointing at one model version."""

    id: str
    model_id: str
    project_id: str
    alias: str
    model_version_id: str
    version: int
    created_by_user_id: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelAlias":
        """Build a ModelAlias from a payload dict.

        String identifiers default to ``""``, ``version`` defaults to ``0``,
        and timestamps are parsed with ``_parse_datetime``.
        """
        get = data.get
        return cls(
            id=get("id", ""),
            model_id=get("model_id", ""),
            project_id=get("project_id", ""),
            alias=get("alias", ""),
            model_version_id=get("model_version_id", ""),
            version=get("version", 0),
            created_by_user_id=get("created_by_user_id"),
            created_at=_parse_datetime(get("created_at")),
            updated_at=_parse_datetime(get("updated_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; unset datetimes become ``None``."""

        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "model_id": self.model_id,
            "project_id": self.project_id,
            "alias": self.alias,
            "model_version_id": self.model_version_id,
            "version": self.version,
            "created_by_user_id": self.created_by_user_id,
            "created_at": iso(self.created_at),
            "updated_at": iso(self.updated_at),
        }
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@dataclass
class StageTransition:
    """One history record of a model version moving between stages.

    ``from_stage`` may be ``None``; ``to_stage`` defaults to ``""``.
    """

    id: str
    model_id: str
    model_version_id: str
    project_id: str
    version: int
    from_stage: Optional[str] = None
    to_stage: str = ""
    transitioned_by: Optional[str] = None
    comment: Optional[str] = None
    created_at: Optional[datetime] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StageTransition":
        """Build a StageTransition from a payload dict.

        Identifiers and ``to_stage`` default to ``""``, ``version`` to ``0``,
        and ``created_at`` is parsed with ``_parse_datetime``.
        """
        get = data.get
        return cls(
            id=get("id", ""),
            model_id=get("model_id", ""),
            model_version_id=get("model_version_id", ""),
            project_id=get("project_id", ""),
            version=get("version", 0),
            from_stage=get("from_stage"),
            to_stage=get("to_stage", ""),
            transitioned_by=get("transitioned_by"),
            comment=get("comment"),
            created_at=_parse_datetime(get("created_at")),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; ``created_at`` becomes ISO-8601 or ``None``."""
        created = self.created_at.isoformat() if self.created_at else None
        return {
            "id": self.id,
            "model_id": self.model_id,
            "model_version_id": self.model_version_id,
            "project_id": self.project_id,
            "version": self.version,
            "from_stage": self.from_stage,
            "to_stage": self.to_stage,
            "transitioned_by": self.transitioned_by,
            "comment": self.comment,
            "created_at": created,
        }
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _parse_datetime(value: Any) -> Optional[datetime]:
|
|
261
|
+
"""Parse a datetime from string or return None."""
|
|
262
|
+
if value is None:
|
|
263
|
+
return None
|
|
264
|
+
if isinstance(value, datetime):
|
|
265
|
+
return value
|
|
266
|
+
if isinstance(value, str):
|
|
267
|
+
try:
|
|
268
|
+
if value.endswith("Z"):
|
|
269
|
+
value = value[:-1] + "+00:00"
|
|
270
|
+
return datetime.fromisoformat(value)
|
|
271
|
+
except ValueError:
|
|
272
|
+
return None
|
|
273
|
+
return None
|