flowyml 1.7.2__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flowyml/assets/base.py +15 -0
- flowyml/assets/metrics.py +5 -0
- flowyml/cli/main.py +709 -0
- flowyml/cli/stack_cli.py +138 -25
- flowyml/core/__init__.py +17 -0
- flowyml/core/executor.py +161 -26
- flowyml/core/image_builder.py +129 -0
- flowyml/core/log_streamer.py +227 -0
- flowyml/core/orchestrator.py +22 -2
- flowyml/core/pipeline.py +34 -10
- flowyml/core/routing.py +558 -0
- flowyml/core/step.py +9 -1
- flowyml/core/step_grouping.py +49 -35
- flowyml/core/types.py +407 -0
- flowyml/monitoring/alerts.py +10 -0
- flowyml/monitoring/notifications.py +104 -25
- flowyml/monitoring/slack_blocks.py +323 -0
- flowyml/plugins/__init__.py +251 -0
- flowyml/plugins/alerters/__init__.py +1 -0
- flowyml/plugins/alerters/slack.py +168 -0
- flowyml/plugins/base.py +752 -0
- flowyml/plugins/config.py +478 -0
- flowyml/plugins/deployers/__init__.py +22 -0
- flowyml/plugins/deployers/gcp_cloud_run.py +200 -0
- flowyml/plugins/deployers/sagemaker.py +306 -0
- flowyml/plugins/deployers/vertex.py +290 -0
- flowyml/plugins/integration.py +369 -0
- flowyml/plugins/manager.py +510 -0
- flowyml/plugins/model_registries/__init__.py +22 -0
- flowyml/plugins/model_registries/mlflow.py +159 -0
- flowyml/plugins/model_registries/sagemaker.py +489 -0
- flowyml/plugins/model_registries/vertex.py +386 -0
- flowyml/plugins/orchestrators/__init__.py +13 -0
- flowyml/plugins/orchestrators/sagemaker.py +443 -0
- flowyml/plugins/orchestrators/vertex_ai.py +461 -0
- flowyml/plugins/registries/__init__.py +13 -0
- flowyml/plugins/registries/ecr.py +321 -0
- flowyml/plugins/registries/gcr.py +313 -0
- flowyml/plugins/registry.py +454 -0
- flowyml/plugins/stack.py +494 -0
- flowyml/plugins/stack_config.py +537 -0
- flowyml/plugins/stores/__init__.py +13 -0
- flowyml/plugins/stores/gcs.py +460 -0
- flowyml/plugins/stores/s3.py +453 -0
- flowyml/plugins/trackers/__init__.py +11 -0
- flowyml/plugins/trackers/mlflow.py +316 -0
- flowyml/plugins/validators/__init__.py +3 -0
- flowyml/plugins/validators/deepchecks.py +119 -0
- flowyml/registry/__init__.py +2 -1
- flowyml/registry/model_environment.py +109 -0
- flowyml/registry/model_registry.py +241 -96
- flowyml/serving/__init__.py +17 -0
- flowyml/serving/model_server.py +628 -0
- flowyml/stacks/__init__.py +60 -0
- flowyml/stacks/aws.py +93 -0
- flowyml/stacks/base.py +62 -0
- flowyml/stacks/components.py +12 -0
- flowyml/stacks/gcp.py +44 -9
- flowyml/stacks/plugins.py +115 -0
- flowyml/stacks/registry.py +2 -1
- flowyml/storage/sql.py +401 -12
- flowyml/tracking/experiment.py +8 -5
- flowyml/ui/backend/Dockerfile +87 -16
- flowyml/ui/backend/auth.py +12 -2
- flowyml/ui/backend/main.py +149 -5
- flowyml/ui/backend/routers/ai_context.py +226 -0
- flowyml/ui/backend/routers/assets.py +23 -4
- flowyml/ui/backend/routers/auth.py +96 -0
- flowyml/ui/backend/routers/deployments.py +660 -0
- flowyml/ui/backend/routers/model_explorer.py +597 -0
- flowyml/ui/backend/routers/plugins.py +103 -51
- flowyml/ui/backend/routers/projects.py +91 -8
- flowyml/ui/backend/routers/runs.py +20 -1
- flowyml/ui/backend/routers/schedules.py +22 -17
- flowyml/ui/backend/routers/templates.py +319 -0
- flowyml/ui/backend/routers/websocket.py +2 -2
- flowyml/ui/frontend/Dockerfile +55 -6
- flowyml/ui/frontend/dist/assets/index-B5AsPTSz.css +1 -0
- flowyml/ui/frontend/dist/assets/index-dFbZ8wD8.js +753 -0
- flowyml/ui/frontend/dist/index.html +2 -2
- flowyml/ui/frontend/dist/logo.png +0 -0
- flowyml/ui/frontend/nginx.conf +65 -4
- flowyml/ui/frontend/package-lock.json +1404 -74
- flowyml/ui/frontend/package.json +3 -0
- flowyml/ui/frontend/public/logo.png +0 -0
- flowyml/ui/frontend/src/App.jsx +10 -7
- flowyml/ui/frontend/src/app/auth/Login.jsx +90 -0
- flowyml/ui/frontend/src/app/dashboard/page.jsx +8 -8
- flowyml/ui/frontend/src/app/deployments/page.jsx +786 -0
- flowyml/ui/frontend/src/app/model-explorer/page.jsx +1031 -0
- flowyml/ui/frontend/src/app/pipelines/page.jsx +12 -2
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectExperimentsList.jsx +19 -6
- flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +36 -24
- flowyml/ui/frontend/src/app/runs/page.jsx +8 -2
- flowyml/ui/frontend/src/app/settings/page.jsx +267 -253
- flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +29 -7
- flowyml/ui/frontend/src/components/Layout.jsx +6 -0
- flowyml/ui/frontend/src/components/PipelineGraph.jsx +79 -29
- flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +36 -6
- flowyml/ui/frontend/src/components/RunMetaPanel.jsx +113 -0
- flowyml/ui/frontend/src/components/ai/AIAssistantButton.jsx +71 -0
- flowyml/ui/frontend/src/components/ai/AIAssistantPanel.jsx +420 -0
- flowyml/ui/frontend/src/components/header/Header.jsx +22 -0
- flowyml/ui/frontend/src/components/plugins/PluginManager.jsx +4 -4
- flowyml/ui/frontend/src/components/plugins/{ZenMLIntegration.jsx → StackImport.jsx} +38 -12
- flowyml/ui/frontend/src/components/sidebar/Sidebar.jsx +36 -13
- flowyml/ui/frontend/src/contexts/AIAssistantContext.jsx +245 -0
- flowyml/ui/frontend/src/contexts/AuthContext.jsx +108 -0
- flowyml/ui/frontend/src/hooks/useAIContext.js +156 -0
- flowyml/ui/frontend/src/hooks/useWebGPU.js +54 -0
- flowyml/ui/frontend/src/layouts/MainLayout.jsx +6 -0
- flowyml/ui/frontend/src/router/index.jsx +47 -20
- flowyml/ui/frontend/src/services/pluginService.js +3 -1
- flowyml/ui/server_manager.py +5 -5
- flowyml/ui/utils.py +157 -39
- flowyml/utils/config.py +37 -15
- flowyml/utils/model_introspection.py +123 -0
- flowyml/utils/observability.py +30 -0
- flowyml-1.8.0.dist-info/METADATA +174 -0
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/RECORD +123 -65
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/WHEEL +1 -1
- flowyml/ui/frontend/dist/assets/index-B40RsQDq.css +0 -1
- flowyml/ui/frontend/dist/assets/index-CjI0zKCn.js +0 -685
- flowyml-1.7.2.dist-info/METADATA +0 -477
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/entry_points.txt +0 -0
- {flowyml-1.7.2.dist-info → flowyml-1.8.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""SageMaker Model Registry - Native FlowyML Plugin.
|
|
2
|
+
|
|
3
|
+
This plugin provides direct integration with AWS SageMaker Model Registry
|
|
4
|
+
for model versioning, cataloging, and deployment.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from flowyml.plugins import get_plugin
|
|
8
|
+
|
|
9
|
+
registry = get_plugin("sagemaker_model_registry", region="us-east-1")
|
|
10
|
+
registry.register_model(
|
|
11
|
+
name="my-model",
|
|
12
|
+
model_uri="s3://bucket/model/",
|
|
13
|
+
version="1.0.0",
|
|
14
|
+
metadata={"framework": "pytorch", "accuracy": 0.95}
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Deploy to endpoint
|
|
18
|
+
endpoint = registry.deploy_model(
|
|
19
|
+
model_name="my-model",
|
|
20
|
+
endpoint_name="my-endpoint",
|
|
21
|
+
instance_type="ml.m5.large"
|
|
22
|
+
)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import logging
|
|
26
|
+
from typing import Any
|
|
27
|
+
from datetime import datetime
|
|
28
|
+
|
|
29
|
+
from flowyml.plugins.base import ModelRegistryPlugin, PluginMetadata, PluginType
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class SageMakerModelRegistry(ModelRegistryPlugin):
|
|
35
|
+
"""Native SageMaker Model Registry plugin for FlowyML.
|
|
36
|
+
|
|
37
|
+
This plugin integrates directly with SageMaker Model Registry
|
|
38
|
+
for registering, versioning, and deploying ML models.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
region: AWS region.
|
|
42
|
+
role_arn: IAM role ARN for SageMaker.
|
|
43
|
+
model_package_group_name: Name of the model package group used to organize models.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
metadata = PluginMetadata(
|
|
47
|
+
name="sagemaker_model_registry",
|
|
48
|
+
version="1.0.0",
|
|
49
|
+
description="AWS SageMaker Model Registry",
|
|
50
|
+
author="FlowyML Team",
|
|
51
|
+
plugin_type=PluginType.CUSTOM,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
def __init__(
    self,
    region: str = None,
    role_arn: str = None,
    model_package_group_name: str = "flowyml-models",
    **kwargs,
):
    """Store the SageMaker settings; no AWS call is made until ``initialize``.

    Args:
        region: AWS region handed to ``boto3.Session``. ``None`` leaves the
            choice of region to boto3.
        role_arn: IAM role ARN used as the execution role when a model is
            created for deployment.
        model_package_group_name: Model package group that holds every
            model registered through this plugin.
        **kwargs: Forwarded unchanged to the base plugin constructor.
    """
    super().__init__(**kwargs)
    self.region = region
    self.role_arn = role_arn
    self.model_package_group_name = model_package_group_name
    # Handles created lazily by ``initialize``. ``_sm_client`` is declared
    # here too so that reading it before initialization yields ``None``
    # instead of raising AttributeError.
    self._sagemaker = None
    self._boto_session = None
    self._sm_client = None
    self._initialized = False
|
|
76
|
+
|
|
77
|
+
@property
def plugin_type(self) -> PluginType:
    """Report the plugin category; this registry is classified as ``CUSTOM``."""
    category = PluginType.CUSTOM
    return category
|
|
80
|
+
|
|
81
|
+
def initialize(self) -> None:
    """Create the boto3 session and SageMaker client on first use.

    Once initialization has succeeded, later calls return immediately.
    The method also makes sure the configured model package group exists.

    Raises:
        ImportError: If ``boto3`` or ``sagemaker`` cannot be imported. The
            original import failure is attached as ``__cause__``.
    """
    if self._initialized:
        return

    # Only the imports are guarded, so an ImportError raised later (for
    # example while creating the client) is not reported as a missing package.
    try:
        import boto3
        import sagemaker
    except ImportError as err:
        raise ImportError(
            "boto3 and sagemaker are required. Install with: pip install boto3 sagemaker",
        ) from err

    self._boto_session = boto3.Session(region_name=self.region)
    self._sm_client = self._boto_session.client("sagemaker")
    self._sagemaker = sagemaker

    # Ensure model package group exists
    self._ensure_model_package_group()

    self._initialized = True
    # Log the region the session resolved; ``self.region`` may be ``None``.
    logger.info(f"SageMaker Model Registry initialized in region {self._boto_session.region_name}")
|
|
103
|
+
|
|
104
|
+
def _ensure_initialized(self) -> None:
    """Run ``initialize`` unless the plugin is already marked initialized."""
    if self._initialized:
        return
    self.initialize()
|
|
108
|
+
|
|
109
|
+
def _ensure_model_package_group(self) -> None:
    """Create the configured model package group unless it already exists."""
    client = self._sm_client
    group_name = self.model_package_group_name
    try:
        # A successful describe call means there is nothing left to do.
        client.describe_model_package_group(ModelPackageGroupName=group_name)
    except client.exceptions.ResourceNotFoundException:
        client.create_model_package_group(
            ModelPackageGroupName=group_name,
            ModelPackageGroupDescription="FlowyML Model Registry",
        )
        logger.info(f"Created model package group: {group_name}")
|
|
121
|
+
|
|
122
|
+
def register_model(
    self,
    name: str,
    model_uri: str,
    version: str = None,
    metadata: dict = None,
    inference_image_uri: str = None,
    content_types: list[str] = None,
    response_types: list[str] = None,
    description: str = None,
    **kwargs,
) -> str:
    """Register a model as a new package in the configured model package group.

    Args:
        name: Model name, stored in the package's customer metadata.
        model_uri: S3 URI to model artifacts.
        version: Model version string. Falls back to ``metadata["version"]``
            and then to ``"1.0.0"``.
        metadata: Extra key/value pairs stored (stringified) as customer
            metadata. ``metadata["framework"]`` selects the default image.
        inference_image_uri: Docker image for inference. When omitted, an
            image is chosen from the framework name.
        content_types: Supported input content types
            (default ``["application/json"]``).
        response_types: Supported output content types
            (default ``["application/json"]``).
        description: Model description.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Model package ARN.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    self._ensure_initialized()

    try:
        metadata = metadata or {}
        content_types = content_types or ["application/json"]
        response_types = response_types or ["application/json"]

        # Determine inference image if not provided
        if not inference_image_uri:
            # ``or ""`` plus ``str`` keeps a ``None`` or non-string framework
            # value from raising on ``.lower()``.
            framework = str(metadata.get("framework") or "").lower()
            region = self.region or self._boto_session.region_name

            # NOTE(review): every image is resolved against this one registry
            # account, as before. Confirm the sklearn/xgboost repositories
            # exist there for the target region, or resolve the URI through
            # the SageMaker SDK instead.
            registry_host = f"763104351884.dkr.ecr.{region}.amazonaws.com"
            sklearn_image = "sagemaker-scikit-learn:1.2-1-cpu-py310"
            images = {
                "pytorch": "pytorch-inference:2.0.0-cpu-py310",
                "tensorflow": "tensorflow-inference:2.13-cpu",
                "sklearn": sklearn_image,
                "scikit-learn": sklearn_image,
                "xgboost": "sagemaker-xgboost:1.7-1-cpu-py310",
            }
            # Unknown or missing frameworks default to the sklearn image.
            inference_image_uri = f"{registry_host}/{images.get(framework, sklearn_image)}"

        customer_metadata = {str(key): str(value) for key, value in metadata.items()}
        # Reserved keys are written last so user metadata cannot overwrite
        # the name that ``get_model`` matches on.
        customer_metadata.update(
            {
                "name": name,
                "version": str(version or metadata.get("version") or "1.0.0"),
                "registered_at": datetime.now().isoformat(),
            },
        )

        # Build model package spec
        model_spec = {
            "ModelPackageGroupName": self.model_package_group_name,
            "ModelPackageDescription": description or f"{name} registered via FlowyML",
            "InferenceSpecification": {
                "Containers": [
                    {
                        "Image": inference_image_uri,
                        "ModelDataUrl": model_uri,
                    },
                ],
                "SupportedContentTypes": content_types,
                "SupportedResponseMIMETypes": response_types,
                "SupportedTransformInstanceTypes": ["ml.m5.large"],
                "SupportedRealtimeInferenceInstanceTypes": ["ml.m5.large"],
            },
            "ModelApprovalStatus": "PendingManualApproval",
            "CustomerMetadataProperties": customer_metadata,
        }

        response = self._sm_client.create_model_package(**model_spec)
        model_package_arn = response["ModelPackageArn"]

        logger.info(f"Registered model '{name}' with ARN: {model_package_arn}")
        return model_package_arn

    except Exception as e:
        logger.error(f"Failed to register model '{name}': {e}")
        raise
|
|
225
|
+
|
|
226
|
+
def get_model(
    self,
    name: str,
    version: str = None,
) -> Any:
    """Look up a model package by ARN or by registered name.

    Packages are scanned newest first, so without ``version`` the most
    recently created package registered under ``name`` is returned. Every
    result page is followed, not just the first one.

    Args:
        name: Model name or package ARN. A value starting with ``"arn:"``
            is described directly.
        version: Specific version to retrieve (ignored for ARNs).

    Returns:
        The ``describe_model_package`` response for the match, or ``None``
        when no package matches.

    Raises:
        Exception: Any client failure is logged and re-raised unchanged.
    """
    self._ensure_initialized()

    try:
        # If it's an ARN, get directly
        if name.startswith("arn:"):
            return self._sm_client.describe_model_package(
                ModelPackageName=name,
            )

        # Otherwise, list and filter by name
        request = {
            "ModelPackageGroupName": self.model_package_group_name,
            "SortBy": "CreationTime",
            "SortOrder": "Descending",
        }
        while True:
            page = self._sm_client.list_model_packages(**request)

            for pkg in page["ModelPackageSummaryList"]:
                # The name/version live in customer metadata, which only the
                # describe call returns.
                details = self._sm_client.describe_model_package(
                    ModelPackageName=pkg["ModelPackageArn"],
                )
                customer_meta = details.get("CustomerMetadataProperties", {})

                if customer_meta.get("name") != name:
                    continue
                if version and customer_meta.get("version") != version:
                    continue
                return details

            # Keep going while the service reports another page.
            next_token = page.get("NextToken")
            if not next_token:
                break
            request["NextToken"] = next_token

        logger.warning(f"Model '{name}' not found")
        return None

    except Exception as e:
        logger.error(f"Failed to get model '{name}': {e}")
        raise
|
|
273
|
+
|
|
274
|
+
def list_models(
    self,
    limit: int = 100,
) -> list[dict]:
    """List models in the registry, newest first.

    Result pages are followed until ``limit`` entries are collected or the
    group is exhausted, so limits above one page are honoured.

    Args:
        limit: Maximum number of models to return. Zero or a negative
            value yields an empty list without calling SageMaker.

    Returns:
        List of dictionaries with the keys ``name``, ``version``, ``arn``,
        ``status``, ``created`` and ``metadata``.

    Raises:
        Exception: Any client failure is logged and re-raised unchanged.
    """
    self._ensure_initialized()

    try:
        result = []
        request = {
            "ModelPackageGroupName": self.model_package_group_name,
            "SortBy": "CreationTime",
            "SortOrder": "Descending",
        }

        while len(result) < limit:
            remaining = limit - len(result)
            # The original request capped one page at 100 entries; keep that
            # cap and fetch further pages instead.
            request["MaxResults"] = min(remaining, 100)
            page = self._sm_client.list_model_packages(**request)

            for pkg in page["ModelPackageSummaryList"][:remaining]:
                details = self._sm_client.describe_model_package(
                    ModelPackageName=pkg["ModelPackageArn"],
                )
                customer_meta = details.get("CustomerMetadataProperties", {})

                result.append(
                    {
                        "name": customer_meta.get("name", "unknown"),
                        "version": customer_meta.get("version", "1.0.0"),
                        "arn": pkg["ModelPackageArn"],
                        "status": details.get("ModelApprovalStatus", "Unknown"),
                        "created": str(pkg.get("CreationTime", "")),
                        "metadata": customer_meta,
                    },
                )

            next_token = page.get("NextToken")
            if not next_token:
                break
            request["NextToken"] = next_token

        return result

    except Exception as e:
        logger.error(f"Failed to list models: {e}")
        raise
|
|
319
|
+
|
|
320
|
+
def transition_model_stage(
    self,
    name: str,
    stage: str,
    version: str = None,
) -> bool:
    """Set the approval status of a registered model package.

    Args:
        name: Model name.
        stage: Friendly alias ("production", "staging", "approved",
            "rejected", "pending", matched case-insensitively) or a
            SageMaker approval status, which is passed through as given.
        version: Specific version.

    Returns:
        True when the status was updated; False when the model was not
        found or any error occurred (the error is logged, not raised).
    """
    self._ensure_initialized()

    # Friendly aliases mapped onto SageMaker approval statuses.
    aliases = {
        "production": "Approved",
        "staging": "Approved",
        "approved": "Approved",
        "rejected": "Rejected",
        "pending": "PendingManualApproval",
    }
    normalized = stage.lower()
    approval_status = aliases[normalized] if normalized in aliases else stage

    try:
        model = self.get_model(name, version)
        if model:
            self._sm_client.update_model_package(
                ModelPackageArn=model["ModelPackageArn"],
                ModelApprovalStatus=approval_status,
            )
            logger.info(f"Transitioned model '{name}' to status '{approval_status}'")
            return True
        return False
    except Exception as e:
        logger.error(f"Failed to transition model: {e}")
        return False
|
|
364
|
+
|
|
365
|
+
def deploy_model(
    self,
    model_name: str,
    endpoint_name: str,
    instance_type: str = "ml.m5.large",
    instance_count: int = 1,
    **kwargs,
) -> str:
    """Deploy a registered model package to a SageMaker endpoint.

    Creates a SageMaker model named ``<endpoint_name>-model`` and an
    endpoint config named ``<endpoint_name>-config``, then creates the
    endpoint or, on ``ResourceInUse``, updates it.

    Args:
        model_name: Model name or package ARN.
        endpoint_name: Name for the endpoint.
        instance_type: Instance type for deployment.
        instance_count: Number of instances.
        **kwargs: Additional deployment arguments. An optional ``version``
            selects a specific registered version; other keys are ignored.

    Returns:
        Endpoint ARN.

    Raises:
        ValueError: If no ``role_arn`` is configured or the model cannot
            be found.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    self._ensure_initialized()

    # Fail early with a clear message instead of sending
    # ``ExecutionRoleArn=None`` to SageMaker.
    if not self.role_arn:
        raise ValueError("role_arn is required to deploy a model to a SageMaker endpoint")

    try:
        # Get model package
        model = self.get_model(model_name, kwargs.get("version"))
        if not model:
            raise ValueError(f"Model '{model_name}' not found")

        model_package_arn = model["ModelPackageArn"]

        # NOTE(review): both resource names derive only from endpoint_name,
        # so a redeploy that lands in the "already exists" branches below
        # reuses the earlier model/config. Confirm that is intended.
        sm_model_name = f"{endpoint_name}-model"
        config_name = f"{endpoint_name}-config"

        # Create model from package
        try:
            self._sm_client.create_model(
                ModelName=sm_model_name,
                PrimaryContainer={
                    "ModelPackageName": model_package_arn,
                },
                ExecutionRoleArn=self.role_arn,
            )
        except self._sm_client.exceptions.ResourceInUse:
            logger.info(f"Model {sm_model_name} already exists")

        # Create endpoint config
        try:
            self._sm_client.create_endpoint_config(
                EndpointConfigName=config_name,
                ProductionVariants=[
                    {
                        "VariantName": "primary",
                        "ModelName": sm_model_name,
                        "InstanceType": instance_type,
                        "InitialInstanceCount": instance_count,
                    },
                ],
            )
        except self._sm_client.exceptions.ResourceInUse:
            logger.info(f"Endpoint config {config_name} already exists")

        # Create or update endpoint
        try:
            self._sm_client.create_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=config_name,
            )
            logger.info(f"Creating endpoint: {endpoint_name}")
        except self._sm_client.exceptions.ResourceInUse:
            self._sm_client.update_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=config_name,
            )
            logger.info(f"Updating endpoint: {endpoint_name}")

        # Get endpoint ARN
        endpoint_desc = self._sm_client.describe_endpoint(
            EndpointName=endpoint_name,
        )

        logger.info(f"Deployed model '{model_name}' to endpoint '{endpoint_name}'")
        return endpoint_desc["EndpointArn"]

    except Exception as e:
        logger.error(f"Failed to deploy model: {e}")
        raise
|
|
452
|
+
|
|
453
|
+
def predict(
    self,
    endpoint_name: str,
    data: Any,
    content_type: str = "application/json",
) -> Any:
    """Invoke a deployed endpoint and return its JSON-decoded response.

    Args:
        endpoint_name: Endpoint name.
        data: Input payload. ``bytes``/``bytearray`` are sent unchanged;
            dicts, lists and tuples are JSON serialized; anything else is
            sent as ``str(data)`` encoded as UTF-8.
        content_type: Content type of the request.

    Returns:
        The response body parsed as JSON.

    Raises:
        Exception: Any failure (including a response body that is not
            valid JSON) is logged and re-raised unchanged.
    """
    self._ensure_initialized()

    try:
        import json

        runtime_client = self._boto_session.client("sagemaker-runtime")

        if isinstance(data, (bytes, bytearray)):
            # Already binary: ``str(data)`` would send the ``b'...'`` repr.
            payload = bytes(data)
        elif isinstance(data, (dict, list, tuple)):
            payload = json.dumps(data).encode("utf-8")
        else:
            payload = str(data).encode("utf-8")

        response = runtime_client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType=content_type,
            Body=payload,
        )

        result = json.loads(response["Body"].read().decode("utf-8"))
        return result

    except Exception as e:
        logger.error(f"Failed to make prediction: {e}")
        raise
|