flowyml 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flowyml/__init__.py +3 -0
- flowyml/assets/base.py +10 -0
- flowyml/assets/metrics.py +6 -0
- flowyml/cli/main.py +108 -2
- flowyml/cli/run.py +9 -2
- flowyml/core/execution_status.py +52 -0
- flowyml/core/hooks.py +106 -0
- flowyml/core/observability.py +210 -0
- flowyml/core/orchestrator.py +274 -0
- flowyml/core/pipeline.py +193 -231
- flowyml/core/project.py +34 -2
- flowyml/core/remote_orchestrator.py +109 -0
- flowyml/core/resources.py +22 -5
- flowyml/core/retry_policy.py +80 -0
- flowyml/core/step.py +18 -1
- flowyml/core/submission_result.py +53 -0
- flowyml/integrations/keras.py +95 -22
- flowyml/monitoring/alerts.py +2 -2
- flowyml/stacks/__init__.py +15 -0
- flowyml/stacks/aws.py +599 -0
- flowyml/stacks/azure.py +295 -0
- flowyml/stacks/components.py +24 -2
- flowyml/stacks/gcp.py +158 -11
- flowyml/stacks/local.py +5 -0
- flowyml/storage/artifacts.py +15 -5
- flowyml/storage/materializers/__init__.py +2 -0
- flowyml/storage/materializers/cloudpickle.py +74 -0
- flowyml/storage/metadata.py +166 -5
- flowyml/ui/backend/main.py +41 -1
- flowyml/ui/backend/routers/assets.py +356 -15
- flowyml/ui/backend/routers/client.py +46 -0
- flowyml/ui/backend/routers/execution.py +13 -2
- flowyml/ui/backend/routers/experiments.py +48 -12
- flowyml/ui/backend/routers/metrics.py +213 -0
- flowyml/ui/backend/routers/pipelines.py +63 -7
- flowyml/ui/backend/routers/projects.py +33 -7
- flowyml/ui/backend/routers/runs.py +150 -8
- flowyml/ui/frontend/dist/assets/index-DcYwrn2j.css +1 -0
- flowyml/ui/frontend/dist/assets/index-Dlz_ygOL.js +592 -0
- flowyml/ui/frontend/dist/index.html +2 -2
- flowyml/ui/frontend/src/App.jsx +4 -1
- flowyml/ui/frontend/src/app/assets/page.jsx +260 -230
- flowyml/ui/frontend/src/app/dashboard/page.jsx +38 -7
- flowyml/ui/frontend/src/app/experiments/page.jsx +61 -314
- flowyml/ui/frontend/src/app/observability/page.jsx +277 -0
- flowyml/ui/frontend/src/app/pipelines/page.jsx +79 -402
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectArtifactsList.jsx +151 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectExperimentsList.jsx +145 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectHeader.jsx +45 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectHierarchy.jsx +467 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectMetricsPanel.jsx +253 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectPipelinesList.jsx +105 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRelations.jsx +189 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRunsList.jsx +136 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectTabs.jsx +95 -0
- flowyml/ui/frontend/src/app/projects/[projectId]/page.jsx +326 -0
- flowyml/ui/frontend/src/app/projects/page.jsx +13 -3
- flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +79 -10
- flowyml/ui/frontend/src/app/runs/page.jsx +82 -424
- flowyml/ui/frontend/src/app/settings/page.jsx +1 -0
- flowyml/ui/frontend/src/app/tokens/page.jsx +62 -16
- flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +373 -0
- flowyml/ui/frontend/src/components/AssetLineageGraph.jsx +291 -0
- flowyml/ui/frontend/src/components/AssetStatsDashboard.jsx +302 -0
- flowyml/ui/frontend/src/components/AssetTreeHierarchy.jsx +477 -0
- flowyml/ui/frontend/src/components/ExperimentDetailsPanel.jsx +227 -0
- flowyml/ui/frontend/src/components/NavigationTree.jsx +401 -0
- flowyml/ui/frontend/src/components/PipelineDetailsPanel.jsx +239 -0
- flowyml/ui/frontend/src/components/PipelineGraph.jsx +67 -3
- flowyml/ui/frontend/src/components/ProjectSelector.jsx +115 -0
- flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +298 -0
- flowyml/ui/frontend/src/components/header/Header.jsx +48 -1
- flowyml/ui/frontend/src/components/plugins/ZenMLIntegration.jsx +106 -0
- flowyml/ui/frontend/src/components/sidebar/Sidebar.jsx +52 -26
- flowyml/ui/frontend/src/components/ui/DataView.jsx +35 -17
- flowyml/ui/frontend/src/components/ui/ErrorBoundary.jsx +118 -0
- flowyml/ui/frontend/src/contexts/ProjectContext.jsx +2 -2
- flowyml/ui/frontend/src/contexts/ToastContext.jsx +116 -0
- flowyml/ui/frontend/src/layouts/MainLayout.jsx +5 -1
- flowyml/ui/frontend/src/router/index.jsx +4 -0
- flowyml/ui/frontend/src/utils/date.js +10 -0
- flowyml/ui/frontend/src/utils/downloads.js +11 -0
- flowyml/utils/config.py +6 -0
- flowyml/utils/stack_config.py +45 -3
- {flowyml-1.2.0.dist-info → flowyml-1.3.0.dist-info}/METADATA +42 -4
- {flowyml-1.2.0.dist-info → flowyml-1.3.0.dist-info}/RECORD +89 -52
- {flowyml-1.2.0.dist-info → flowyml-1.3.0.dist-info}/licenses/LICENSE +1 -1
- flowyml/ui/frontend/dist/assets/index-DFNQnrUj.js +0 -448
- flowyml/ui/frontend/dist/assets/index-pWI271rZ.css +0 -1
- {flowyml-1.2.0.dist-info → flowyml-1.3.0.dist-info}/WHEEL +0 -0
- {flowyml-1.2.0.dist-info → flowyml-1.3.0.dist-info}/entry_points.txt +0 -0
flowyml/stacks/aws.py
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
1
|
+
"""AWS Stack Components and Preset Stack."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import subprocess
|
|
7
|
+
import uuid
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from flowyml.stacks.base import Stack
|
|
12
|
+
from flowyml.stacks.components import ArtifactStore, ContainerRegistry, ResourceConfig, DockerConfig
|
|
13
|
+
from flowyml.core.remote_orchestrator import RemoteOrchestrator
|
|
14
|
+
from flowyml.stacks.plugins import register_component
|
|
15
|
+
from flowyml.storage.metadata import SQLiteMetadataStore
|
|
16
|
+
from flowyml.core.submission_result import SubmissionResult
|
|
17
|
+
from flowyml.core.execution_status import ExecutionStatus
|
|
18
|
+
from flowyml.stacks.components import Orchestrator
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_component(name="s3")
class S3ArtifactStore(ArtifactStore):
    """Artifact store backed by Amazon S3.

    Objects are written beneath ``prefix`` inside ``bucket_name``. Paths are
    normalized so leading slashes never create empty key segments.
    """

    def __init__(
        self,
        name: str = "s3",
        bucket_name: str | None = None,
        prefix: str = "flowyml",
        region: str | None = None,
        session_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__(name)
        self.bucket_name = bucket_name
        # Strip surrounding slashes so the key joins below stay well-formed.
        self.prefix = prefix.strip("/")
        self.region = region
        self.session_kwargs = session_kwargs or {}

    def _client(self):
        # boto3 is imported lazily so the module loads without the AWS SDK.
        import boto3

        return boto3.client("s3", region_name=self.region, **self.session_kwargs)

    def _object_key(self, path: str) -> str:
        """Join *path* onto the configured prefix, dropping leading slashes."""
        trimmed = path.lstrip("/")
        if not self.prefix:
            return trimmed
        return f"{self.prefix}/{trimmed}"

    def validate(self) -> bool:
        """Check the bucket is configured and reachable.

        Raises:
            ValueError: If ``bucket_name`` is unset or a HEAD request on the
                bucket fails for any reason.
        """
        if not self.bucket_name:
            raise ValueError("bucket_name is required for S3ArtifactStore")
        try:
            self._client().head_bucket(Bucket=self.bucket_name)
        except Exception as exc:
            raise ValueError(f"Unable to access bucket '{self.bucket_name}': {exc}") from exc
        return True

    def save(self, artifact: Any, path: str) -> str:
        """Save artifact to S3. Accepts file paths, bytes, or strings."""
        s3 = self._client()
        key = self._object_key(path)

        points_at_file = isinstance(artifact, (str, Path)) and Path(artifact).exists()
        if points_at_file:
            s3.upload_file(str(Path(artifact)), self.bucket_name, key)
        else:
            # Anything that is not an existing local file is stored as raw bytes.
            payload = artifact if isinstance(artifact, bytes) else str(artifact).encode()
            s3.put_object(Bucket=self.bucket_name, Key=key, Body=payload)

        return f"s3://{self.bucket_name}/{key}"

    def load(self, path: str) -> bytes:
        """Fetch the object stored at *path* and return its raw bytes."""
        s3 = self._client()
        response = s3.get_object(Bucket=self.bucket_name, Key=self._object_key(path))
        return response["Body"].read()

    def exists(self, path: str) -> bool:
        """Return True when a HEAD request on the object succeeds."""
        s3 = self._client()
        try:
            s3.head_object(Bucket=self.bucket_name, Key=self._object_key(path))
        except Exception:
            return False
        return True

    def to_dict(self) -> dict[str, Any]:
        """Serialize this store's configuration for stack persistence."""
        return {
            "name": self.name,
            "type": "s3",
            "bucket_name": self.bucket_name,
            "prefix": self.prefix,
            "region": self.region,
        }
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@register_component(name="ecr")
class ECRContainerRegistry(ContainerRegistry):
    """Amazon Elastic Container Registry integration.

    Wraps ``docker tag``/``push``/``pull`` against an ECR registry,
    authenticating the local Docker daemon with a short-lived token obtained
    from the ECR API.
    """

    def __init__(
        self,
        name: str = "ecr",
        account_id: str | None = None,
        region: str = "us-east-1",
        registry_alias: str | None = None,
    ):
        super().__init__(name)
        self.account_id = account_id
        self.region = region
        # Optional override for the registry host (e.g. a public-gallery alias).
        self.registry_alias = registry_alias

    def _client(self):
        # Imported lazily so the module loads without boto3 installed.
        import boto3

        return boto3.client("ecr", region_name=self.region)

    def validate(self) -> bool:
        """Ensure the registry is minimally configured.

        Raises:
            ValueError: If ``account_id`` is missing.
        """
        if not self.account_id:
            raise ValueError("account_id is required for ECRContainerRegistry")
        return True

    def _login(self) -> None:
        """Authenticate the local Docker daemon against ECR.

        The ECR authorization token decodes to ``"user:password"``.
        Security fix: the password is fed to ``docker login`` via stdin
        (``--password-stdin``) instead of being placed on the command line,
        where it would be visible in the process table. ``split(":", 1)``
        keeps the full password even if it contains a colon.
        """
        auth = self._client().get_authorization_token()
        data = auth["authorizationData"][0]
        token = base64.b64decode(data["authorizationToken"]).decode()
        username, password = token.split(":", 1)
        endpoint = data["proxyEndpoint"]
        subprocess.run(
            ["docker", "login", "--username", username, "--password-stdin", endpoint],
            input=password.encode(),
            check=True,
        )

    def push_image(self, image_name: str, tag: str = "latest") -> str:
        """Tag a local image with its ECR URI and push it.

        Returns:
            The fully qualified image URI that was pushed.
        """
        full_uri = self.get_image_uri(image_name, tag)
        self._login()
        subprocess.run(["docker", "tag", f"{image_name}:{tag}", full_uri], check=True)
        subprocess.run(["docker", "push", full_uri], check=True)
        return full_uri

    def pull_image(self, image_name: str, tag: str = "latest") -> None:
        """Pull the given image from ECR into the local Docker daemon."""
        full_uri = self.get_image_uri(image_name, tag)
        self._login()
        subprocess.run(["docker", "pull", full_uri], check=True)

    def get_image_uri(self, image_name: str, tag: str = "latest") -> str:
        """Build the fully qualified image URI for this registry."""
        registry = self.registry_alias or f"{self.account_id}.dkr.ecr.{self.region}.amazonaws.com"
        return f"{registry}/{image_name}:{tag}"

    def to_dict(self) -> dict[str, Any]:
        """Serialize this registry's configuration for stack persistence."""
        return {
            "name": self.name,
            "type": "ecr",
            "account_id": self.account_id,
            "region": self.region,
            "registry_alias": self.registry_alias,
        }
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@register_component(name="aws_batch")
class AWSBatchOrchestrator(RemoteOrchestrator):
    """Submit flowyml pipeline runs as AWS Batch jobs."""

    def __init__(
        self,
        name: str = "aws_batch",
        region: str = "us-east-1",
        job_queue: str | None = None,
        job_definition: str | None = None,
        parameters: dict[str, str] | None = None,
    ):
        super().__init__(name)
        self.region = region
        self.job_queue = job_queue
        self.job_definition = job_definition
        # Static Batch job parameters forwarded verbatim to submit_job.
        self.parameters = parameters or {}

    def _client(self):
        # Imported lazily so the module loads without boto3 installed.
        import boto3

        return boto3.client("batch", region_name=self.region)

    def validate(self) -> bool:
        """Ensure the orchestrator is minimally configured.

        Raises:
            ValueError: If ``job_queue`` or ``job_definition`` is missing.
        """
        if not self.job_queue or not self.job_definition:
            raise ValueError("job_queue and job_definition are required for AWSBatchOrchestrator")
        return True

    def run_pipeline(
        self,
        pipeline: Any,
        run_id: str,
        resources: ResourceConfig | None = None,
        docker_config: DockerConfig | None = None,
        inputs: dict[str, Any] | None = None,
        context: dict[str, Any] | None = None,
        **kwargs,
    ) -> SubmissionResult:
        """Submit pipeline to AWS Batch.

        Args:
            pipeline: The pipeline object to execute remotely.
            run_id: Identifier for this run; used in the job name and
                exported to the container as ``FLOWYML_RUN_ID``.
            resources: Optional CPU/memory/GPU requirements.
            docker_config: Optional image and environment-variable overrides.
            inputs: Accepted for orchestrator interface parity; unused here.
            context: Accepted for orchestrator interface parity; unused here.
            **kwargs: ``job_name`` may override the generated job name.

        Returns:
            SubmissionResult with job ID and optional wait function.
        """
        from flowyml.core.submission_result import SubmissionResult
        import time

        client = self._client()

        # Fix: honor the caller-supplied run_id. The previous code ignored it
        # and evaluated two independent uuid4() fallbacks, so the job name and
        # the FLOWYML_RUN_ID env var could end up with different identifiers.
        effective_run_id = run_id or getattr(pipeline, "run_id", None) or uuid.uuid4().hex
        job_name = kwargs.get("job_name") or f"{pipeline.name}-{effective_run_id[:8]}"

        env = [
            {"name": "FLOWYML_PIPELINE_NAME", "value": pipeline.name},
            {"name": "FLOWYML_RUN_ID", "value": effective_run_id},
        ]
        if docker_config and docker_config.env_vars:
            for key, value in docker_config.env_vars.items():
                env.append({"name": key, "value": value})

        container_overrides: dict[str, Any] = {"environment": env}
        if docker_config and docker_config.image:
            container_overrides["command"] = ["python", "-m", "flowyml.cli.run"]

        if resources:
            # The Batch API requires resourceRequirements values to be
            # strings; coerce explicitly so an integer cpu count is accepted.
            container_overrides["resourceRequirements"] = [
                {"type": "VCPU", "value": str(resources.cpu)},
                {"type": "MEMORY", "value": str(resources.memory).replace("Gi", "")},
            ]
            if resources.gpu:
                container_overrides["resourceRequirements"].append(
                    {"type": "GPU", "value": str(resources.gpu_count or 1)},
                )

        response = client.submit_job(
            jobName=job_name,
            jobQueue=self.job_queue,
            jobDefinition=self.job_definition,
            containerOverrides=container_overrides,
            parameters=self.parameters,
        )
        job_id = response["jobId"]

        def wait_for_completion():
            """Poll job status until completion."""
            while True:
                status = self.get_run_status(job_id)
                if status.is_finished:
                    if not status.is_successful:
                        raise RuntimeError(f"AWS Batch job {job_id} failed with status: {status}")
                    break
                time.sleep(10)  # Poll every 10 seconds

        return SubmissionResult(
            job_id=job_id,
            wait_for_completion=wait_for_completion,
            metadata={
                "platform": "aws_batch",
                "region": self.region,
                "job_queue": self.job_queue,
                "job_name": job_name,
            },
        )

    def get_run_status(self, job_id: str) -> ExecutionStatus:
        """Get status of AWS Batch job.

        Args:
            job_id: The AWS Batch job ID.

        Returns:
            Current execution status. Unknown Batch states map to RUNNING;
            a missing job or API error maps to FAILED.
        """
        from flowyml.core.execution_status import ExecutionStatus

        client = self._client()
        try:
            response = client.describe_jobs(jobs=[job_id])
            if not response.get("jobs"):
                return ExecutionStatus.FAILED

            job = response["jobs"][0]
            status = job.get("status", "UNKNOWN")

            # Map AWS Batch status to ExecutionStatus
            status_map = {
                "SUBMITTED": ExecutionStatus.PROVISIONING,
                "PENDING": ExecutionStatus.PROVISIONING,
                "RUNNABLE": ExecutionStatus.PROVISIONING,
                "STARTING": ExecutionStatus.INITIALIZING,
                "RUNNING": ExecutionStatus.RUNNING,
                "SUCCEEDED": ExecutionStatus.COMPLETED,
                "FAILED": ExecutionStatus.FAILED,
            }
            return status_map.get(status, ExecutionStatus.RUNNING)
        except Exception as e:
            print(f"Error fetching job status: {e}")
            return ExecutionStatus.FAILED

    def stop_run(self, job_id: str, graceful: bool = True) -> None:
        """Stop an AWS Batch job.

        Args:
            job_id: The AWS Batch job ID.
            graceful: If True, allow job to finish current work. If False, terminate immediately.

        Raises:
            Exception: Re-raises whatever the Batch API raised after logging it.
        """
        client = self._client()
        reason = "Stopped by user"

        try:
            if graceful:
                # Cancel the job (allows cleanup)
                client.cancel_job(jobId=job_id, reason=reason)
            else:
                # Terminate immediately
                client.terminate_job(jobId=job_id, reason=reason)
        except Exception as e:
            print(f"Error stopping job {job_id}: {e}")
            raise

    def to_dict(self) -> dict[str, Any]:
        """Serialize this orchestrator's configuration for stack persistence."""
        return {
            "name": self.name,
            "type": "aws_batch",
            "region": self.region,
            "job_queue": self.job_queue,
            "job_definition": self.job_definition,
        }
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@register_component(name="sagemaker")
class SageMakerOrchestrator(RemoteOrchestrator):
    """Amazon SageMaker Training/Inference Orchestrator."""

    def __init__(
        self,
        name: str = "sagemaker",
        region: str = "us-east-1",
        role_arn: str | None = None,
        default_instance_type: str = "ml.m5.xlarge",
        default_instance_count: int = 1,
        volume_size_gb: int = 50,
        output_path: str | None = None,
    ):
        super().__init__(name)
        self.region = region
        self.role_arn = role_arn
        self.default_instance_type = default_instance_type
        self.default_instance_count = default_instance_count
        self.volume_size_gb = volume_size_gb
        # Default S3 output location; overridable per run via output_s3_uri.
        self.output_path = output_path

    def _client(self):
        # Imported lazily so the module loads without boto3 installed.
        import boto3

        return boto3.client("sagemaker", region_name=self.region)

    def validate(self) -> bool:
        """Ensure the orchestrator is minimally configured.

        Raises:
            ValueError: If ``role_arn`` is missing.
        """
        if not self.role_arn:
            raise ValueError("role_arn is required for SageMakerOrchestrator")
        return True

    def run_pipeline(
        self,
        pipeline: Any,
        run_id: str,
        resources: ResourceConfig | None = None,
        docker_config: DockerConfig | None = None,
        inputs: dict[str, Any] | None = None,
        context: dict[str, Any] | None = None,
        hyperparameters: dict[str, str] | None = None,
        **kwargs,
    ) -> SubmissionResult:
        """Submit the pipeline as a SageMaker training job.

        Args:
            pipeline: The pipeline object to execute remotely.
            run_id: Identifier for this run; used in the training job name.
            resources: Optional machine-type/GPU requirements.
            docker_config: Supplies the training image (required unless
                ``training_image`` is passed via kwargs).
            inputs: Accepted for orchestrator interface parity; unused here.
            context: Accepted for orchestrator interface parity; unused here.
            hyperparameters: Optional HyperParameters for the training job.
            **kwargs: ``job_name``, ``instance_type``, ``instance_count``,
                ``input_s3_uri``, ``output_s3_uri``, ``volume_size_gb``,
                ``max_runtime`` overrides.

        Returns:
            SubmissionResult with the job name as job ID and a wait function.

        Raises:
            ValueError: If no training image or no output path is available.
        """
        client = self._client()
        training_image = docker_config.image if docker_config and docker_config.image else kwargs.get("training_image")
        if not training_image:
            raise ValueError("A Docker image must be provided via DockerConfig for SageMaker training.")

        instance_type = (
            kwargs.get("instance_type") or (resources.machine_type if resources else None) or self.default_instance_type
        )
        # NOTE(review): falling back to resources.gpu_count for the *instance*
        # count looks suspicious (GPUs per instance != number of instances) —
        # preserved as-is, but worth confirming with the ResourceConfig owner.
        instance_count = (
            kwargs.get("instance_count")
            or (resources.gpu_count if resources and resources.gpu_count else None)
            or self.default_instance_count
        )

        # Fix: honor the caller-supplied run_id instead of ignoring it and
        # falling back to a fresh uuid4() when pipeline.run_id is absent.
        effective_run_id = run_id or getattr(pipeline, "run_id", None) or uuid.uuid4().hex
        job_name = kwargs.get("job_name") or f"{pipeline.name}-{effective_run_id[:8]}"
        training_input_s3 = kwargs.get("input_s3_uri")
        output_path = kwargs.get("output_s3_uri") or self.output_path
        if not output_path:
            raise ValueError("output_path must be provided for SageMaker training outputs.")

        create_kwargs = {
            "TrainingJobName": job_name,
            "AlgorithmSpecification": {
                "TrainingImage": training_image,
                "TrainingInputMode": "File",
            },
            "RoleArn": self.role_arn,
            "OutputDataConfig": {"S3OutputPath": output_path},
            "ResourceConfig": {
                "InstanceType": instance_type,
                "InstanceCount": instance_count,
                "VolumeSizeInGB": kwargs.get("volume_size_gb", self.volume_size_gb),
            },
            "StoppingCondition": {"MaxRuntimeInSeconds": kwargs.get("max_runtime", 3600 * 24)},
        }

        if training_input_s3:
            create_kwargs["InputDataConfig"] = [
                {
                    "ChannelName": "training",
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": training_input_s3,
                            "S3DataDistributionType": "FullyReplicated",
                        },
                    },
                },
            ]

        if hyperparameters:
            create_kwargs["HyperParameters"] = hyperparameters

        client.create_training_job(**create_kwargs)

        # Return Submission Result
        from flowyml.core.submission_result import SubmissionResult
        import time

        # SageMaker identifies training jobs by name.
        job_id = job_name

        def wait_for_completion():
            """Poll training job status until completion."""
            while True:
                status = self.get_run_status(job_id)
                if status.is_finished:
                    if not status.is_successful:
                        raise RuntimeError(f"SageMaker job {job_id} failed with status: {status}")
                    break
                time.sleep(15)  # Poll every 15 seconds

        return SubmissionResult(
            job_id=job_id,
            wait_for_completion=wait_for_completion,
            metadata={
                "platform": "sagemaker",
                "region": self.region,
                "instance_type": instance_type,
                "job_name": job_name,
            },
        )

    def deploy_model(
        self,
        model_artifact_s3_uri: str,
        inference_image: str,
        endpoint_name: str,
        instance_type: str = "ml.m5.large",
        instance_count: int = 1,
        wait: bool = True,
    ) -> str:
        """Create a SageMaker model, endpoint config, and endpoint.

        Args:
            model_artifact_s3_uri: S3 URI of the packaged model artifact.
            inference_image: Docker image serving the model.
            endpoint_name: Name for the endpoint (also prefixes model/config names).
            instance_type: Serving instance type.
            instance_count: Initial number of serving instances.
            wait: If True, block until the endpoint is InService.

        Returns:
            The endpoint name.
        """
        client = self._client()
        model_name = f"{endpoint_name}-model"
        client.create_model(
            ModelName=model_name,
            PrimaryContainer={
                "Image": inference_image,
                "ModelDataUrl": model_artifact_s3_uri,
            },
            ExecutionRoleArn=self.role_arn,
        )

        endpoint_config_name = f"{endpoint_name}-config"
        client.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=[
                {
                    "VariantName": "AllTraffic",
                    "ModelName": model_name,
                    "InstanceType": instance_type,
                    "InitialInstanceCount": instance_count,
                },
            ],
        )

        client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)

        if wait:
            waiter = client.get_waiter("endpoint_in_service")
            waiter.wait(EndpointName=endpoint_name)
        return endpoint_name

    def get_run_status(self, job_id: str) -> ExecutionStatus:
        """Get status of SageMaker training job.

        Args:
            job_id: The SageMaker training job name.

        Returns:
            Current execution status. Unknown statuses map to RUNNING;
            API errors map to FAILED.
        """
        from flowyml.core.execution_status import ExecutionStatus

        client = self._client()
        try:
            response = client.describe_training_job(TrainingJobName=job_id)
            status = response.get("TrainingJobStatus", "Unknown")

            # Map SageMaker status to ExecutionStatus
            status_map = {
                "InProgress": ExecutionStatus.RUNNING,
                "Completed": ExecutionStatus.COMPLETED,
                "Failed": ExecutionStatus.FAILED,
                "Stopping": ExecutionStatus.STOPPING,
                "Stopped": ExecutionStatus.STOPPED,
            }
            return status_map.get(status, ExecutionStatus.RUNNING)
        except Exception as e:
            print(f"Error fetching training job status: {e}")
            return ExecutionStatus.FAILED

    def stop_run(self, job_id: str, graceful: bool = True) -> None:
        """Stop a SageMaker training job.

        Args:
            job_id: The SageMaker training job name.
            graceful: Graceful shutdown (SageMaker always stops gracefully).

        Raises:
            Exception: Re-raises whatever the SageMaker API raised after logging it.
        """
        client = self._client()

        try:
            client.stop_training_job(TrainingJobName=job_id)
        except Exception as e:
            print(f"Error stopping training job {job_id}: {e}")
            raise

    def to_dict(self) -> dict[str, Any]:
        """Serialize this orchestrator's configuration for stack persistence."""
        return {
            "name": self.name,
            "type": "sagemaker",
            "region": self.region,
            "role_arn": self.role_arn,
            "default_instance_type": self.default_instance_type,
        }
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
class AWSStack(Stack):
    """Pre-built stack for AWS Batch, S3, and ECR.

    Bundles an orchestrator (Batch by default, SageMaker when requested),
    an S3 artifact store, an ECR container registry, and a metadata store
    into a ready-to-use stack.
    """

    def __init__(
        self,
        name: str = "aws",
        region: str = "us-east-1",
        bucket_name: str | None = None,
        account_id: str | None = None,
        job_queue: str | None = None,
        job_definition: str | None = None,
        registry_alias: str | None = None,
        orchestrator_type: str = "batch",
        role_arn: str | None = None,
        metadata_store: Any | None = None,
    ):
        # Any value other than "sagemaker" selects the Batch orchestrator.
        orchestrator: Orchestrator
        if orchestrator_type == "sagemaker":
            orchestrator = SageMakerOrchestrator(region=region, role_arn=role_arn)
        else:
            orchestrator = AWSBatchOrchestrator(region=region, job_queue=job_queue, job_definition=job_definition)

        store = S3ArtifactStore(bucket_name=bucket_name, region=region)
        registry = ECRContainerRegistry(account_id=account_id, region=region, registry_alias=registry_alias)
        # Only None triggers the SQLite default; a falsy store is respected.
        meta = SQLiteMetadataStore() if metadata_store is None else metadata_store

        super().__init__(
            name=name,
            executor=None,
            artifact_store=store,
            metadata_store=meta,
            container_registry=registry,
            orchestrator=orchestrator,
        )

        self.region = region
        self.bucket_name = bucket_name

    def validate(self) -> bool:
        """Validate every component; each raises ValueError when misconfigured."""
        for component in (self.orchestrator, self.artifact_store, self.container_registry):
            component.validate()
        return True

    def to_dict(self) -> dict[str, Any]:
        """Serialize the stack and all of its components."""
        summary: dict[str, Any] = {
            "name": self.name,
            "type": "aws",
            "region": self.region,
            "bucket_name": self.bucket_name,
        }
        summary["orchestrator"] = self.orchestrator.to_dict()
        summary["artifact_store"] = self.artifact_store.to_dict()
        summary["container_registry"] = self.container_registry.to_dict()
        return summary
|