flowyml 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104) hide show
  1. flowyml/__init__.py +3 -0
  2. flowyml/assets/base.py +10 -0
  3. flowyml/assets/metrics.py +6 -0
  4. flowyml/cli/main.py +108 -2
  5. flowyml/cli/run.py +9 -2
  6. flowyml/core/execution_status.py +52 -0
  7. flowyml/core/hooks.py +106 -0
  8. flowyml/core/observability.py +210 -0
  9. flowyml/core/orchestrator.py +274 -0
  10. flowyml/core/pipeline.py +193 -231
  11. flowyml/core/project.py +34 -2
  12. flowyml/core/remote_orchestrator.py +109 -0
  13. flowyml/core/resources.py +34 -17
  14. flowyml/core/retry_policy.py +80 -0
  15. flowyml/core/scheduler.py +9 -9
  16. flowyml/core/scheduler_config.py +2 -3
  17. flowyml/core/step.py +18 -1
  18. flowyml/core/submission_result.py +53 -0
  19. flowyml/integrations/keras.py +95 -22
  20. flowyml/monitoring/alerts.py +2 -2
  21. flowyml/stacks/__init__.py +15 -0
  22. flowyml/stacks/aws.py +599 -0
  23. flowyml/stacks/azure.py +295 -0
  24. flowyml/stacks/bridge.py +9 -9
  25. flowyml/stacks/components.py +24 -2
  26. flowyml/stacks/gcp.py +158 -11
  27. flowyml/stacks/local.py +5 -0
  28. flowyml/stacks/plugins.py +2 -2
  29. flowyml/stacks/registry.py +21 -0
  30. flowyml/storage/artifacts.py +15 -5
  31. flowyml/storage/materializers/__init__.py +2 -0
  32. flowyml/storage/materializers/base.py +33 -0
  33. flowyml/storage/materializers/cloudpickle.py +74 -0
  34. flowyml/storage/metadata.py +3 -881
  35. flowyml/storage/remote.py +590 -0
  36. flowyml/storage/sql.py +911 -0
  37. flowyml/ui/backend/dependencies.py +28 -0
  38. flowyml/ui/backend/main.py +43 -80
  39. flowyml/ui/backend/routers/assets.py +483 -17
  40. flowyml/ui/backend/routers/client.py +46 -0
  41. flowyml/ui/backend/routers/execution.py +13 -2
  42. flowyml/ui/backend/routers/experiments.py +97 -14
  43. flowyml/ui/backend/routers/metrics.py +168 -0
  44. flowyml/ui/backend/routers/pipelines.py +77 -12
  45. flowyml/ui/backend/routers/projects.py +33 -7
  46. flowyml/ui/backend/routers/runs.py +221 -12
  47. flowyml/ui/backend/routers/schedules.py +5 -21
  48. flowyml/ui/backend/routers/stats.py +14 -0
  49. flowyml/ui/backend/routers/traces.py +37 -53
  50. flowyml/ui/frontend/dist/assets/index-DcYwrn2j.css +1 -0
  51. flowyml/ui/frontend/dist/assets/index-Dlz_ygOL.js +592 -0
  52. flowyml/ui/frontend/dist/index.html +2 -2
  53. flowyml/ui/frontend/src/App.jsx +4 -1
  54. flowyml/ui/frontend/src/app/assets/page.jsx +260 -230
  55. flowyml/ui/frontend/src/app/dashboard/page.jsx +38 -7
  56. flowyml/ui/frontend/src/app/experiments/page.jsx +61 -314
  57. flowyml/ui/frontend/src/app/observability/page.jsx +277 -0
  58. flowyml/ui/frontend/src/app/pipelines/page.jsx +79 -402
  59. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectArtifactsList.jsx +151 -0
  60. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectExperimentsList.jsx +145 -0
  61. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectHeader.jsx +45 -0
  62. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectHierarchy.jsx +467 -0
  63. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectMetricsPanel.jsx +253 -0
  64. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectPipelinesList.jsx +105 -0
  65. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRelations.jsx +189 -0
  66. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectRunsList.jsx +136 -0
  67. flowyml/ui/frontend/src/app/projects/[projectId]/_components/ProjectTabs.jsx +95 -0
  68. flowyml/ui/frontend/src/app/projects/[projectId]/page.jsx +326 -0
  69. flowyml/ui/frontend/src/app/projects/page.jsx +13 -3
  70. flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +79 -10
  71. flowyml/ui/frontend/src/app/runs/page.jsx +82 -424
  72. flowyml/ui/frontend/src/app/settings/page.jsx +1 -0
  73. flowyml/ui/frontend/src/app/tokens/page.jsx +62 -16
  74. flowyml/ui/frontend/src/components/AssetDetailsPanel.jsx +373 -0
  75. flowyml/ui/frontend/src/components/AssetLineageGraph.jsx +291 -0
  76. flowyml/ui/frontend/src/components/AssetStatsDashboard.jsx +302 -0
  77. flowyml/ui/frontend/src/components/AssetTreeHierarchy.jsx +477 -0
  78. flowyml/ui/frontend/src/components/ExperimentDetailsPanel.jsx +227 -0
  79. flowyml/ui/frontend/src/components/NavigationTree.jsx +401 -0
  80. flowyml/ui/frontend/src/components/PipelineDetailsPanel.jsx +239 -0
  81. flowyml/ui/frontend/src/components/PipelineGraph.jsx +67 -3
  82. flowyml/ui/frontend/src/components/ProjectSelector.jsx +115 -0
  83. flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +298 -0
  84. flowyml/ui/frontend/src/components/header/Header.jsx +48 -1
  85. flowyml/ui/frontend/src/components/plugins/ZenMLIntegration.jsx +106 -0
  86. flowyml/ui/frontend/src/components/sidebar/Sidebar.jsx +52 -26
  87. flowyml/ui/frontend/src/components/ui/DataView.jsx +35 -17
  88. flowyml/ui/frontend/src/components/ui/ErrorBoundary.jsx +118 -0
  89. flowyml/ui/frontend/src/contexts/ProjectContext.jsx +2 -2
  90. flowyml/ui/frontend/src/contexts/ToastContext.jsx +116 -0
  91. flowyml/ui/frontend/src/layouts/MainLayout.jsx +5 -1
  92. flowyml/ui/frontend/src/router/index.jsx +4 -0
  93. flowyml/ui/frontend/src/utils/date.js +10 -0
  94. flowyml/ui/frontend/src/utils/downloads.js +11 -0
  95. flowyml/utils/config.py +6 -0
  96. flowyml/utils/stack_config.py +45 -3
  97. {flowyml-1.2.0.dist-info → flowyml-1.4.0.dist-info}/METADATA +44 -4
  98. flowyml-1.4.0.dist-info/RECORD +200 -0
  99. {flowyml-1.2.0.dist-info → flowyml-1.4.0.dist-info}/licenses/LICENSE +1 -1
  100. flowyml/ui/frontend/dist/assets/index-DFNQnrUj.js +0 -448
  101. flowyml/ui/frontend/dist/assets/index-pWI271rZ.css +0 -1
  102. flowyml-1.2.0.dist-info/RECORD +0 -159
  103. {flowyml-1.2.0.dist-info → flowyml-1.4.0.dist-info}/WHEEL +0 -0
  104. {flowyml-1.2.0.dist-info → flowyml-1.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass, field
2
2
  from enum import Enum
3
- from typing import Any, Never
3
+ from typing import Any, NoReturn
4
4
  from datetime import datetime
5
5
  import logging
6
6
 
@@ -24,7 +24,7 @@ class Alert:
24
24
 
25
25
 
26
26
  class AlertHandler:
27
- def handle(self, alert: Alert) -> Never:
27
+ def handle(self, alert: Alert) -> NoReturn:
28
28
  raise NotImplementedError
29
29
 
30
30
 
@@ -2,6 +2,9 @@
2
2
 
3
3
  from flowyml.stacks.base import Stack, StackConfig
4
4
  from flowyml.stacks.local import LocalStack
5
+ from flowyml.stacks.gcp import GCPStack, VertexAIOrchestrator, GCSArtifactStore, GCRContainerRegistry
6
+ from flowyml.stacks.aws import AWSStack, AWSBatchOrchestrator, S3ArtifactStore, ECRContainerRegistry
7
+ from flowyml.stacks.azure import AzureMLStack, AzureMLOrchestrator, AzureBlobArtifactStore, ACRContainerRegistry
5
8
  from flowyml.stacks.components import (
6
9
  ResourceConfig,
7
10
  DockerConfig,
@@ -15,6 +18,18 @@ __all__ = [
15
18
  "Stack",
16
19
  "StackConfig",
17
20
  "LocalStack",
21
+ "GCPStack",
22
+ "AWSStack",
23
+ "AzureMLStack",
24
+ "VertexAIOrchestrator",
25
+ "AWSBatchOrchestrator",
26
+ "AzureMLOrchestrator",
27
+ "GCSArtifactStore",
28
+ "S3ArtifactStore",
29
+ "AzureBlobArtifactStore",
30
+ "GCRContainerRegistry",
31
+ "ECRContainerRegistry",
32
+ "ACRContainerRegistry",
18
33
  "ResourceConfig",
19
34
  "DockerConfig",
20
35
  "Orchestrator",
flowyml/stacks/aws.py ADDED
@@ -0,0 +1,599 @@
1
+ """AWS Stack Components and Preset Stack."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import subprocess
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from flowyml.stacks.base import Stack
12
+ from flowyml.stacks.components import ArtifactStore, ContainerRegistry, ResourceConfig, DockerConfig
13
+ from flowyml.core.remote_orchestrator import RemoteOrchestrator
14
+ from flowyml.stacks.plugins import register_component
15
+ from flowyml.storage.metadata import SQLiteMetadataStore
16
+ from flowyml.core.submission_result import SubmissionResult
17
+ from flowyml.core.execution_status import ExecutionStatus
18
+ from flowyml.stacks.components import Orchestrator
19
+
20
+
21
@register_component(name="s3")
class S3ArtifactStore(ArtifactStore):
    """Artifact store backed by Amazon S3.

    Objects are written under ``s3://{bucket_name}/{prefix}/...``; the
    boto3 client is created lazily so importing this module does not
    require boto3 to be installed.
    """

    def __init__(
        self,
        name: str = "s3",
        bucket_name: str | None = None,
        prefix: str = "flowyml",
        region: str | None = None,
        session_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__(name)
        self.bucket_name = bucket_name
        self.prefix = prefix.strip("/")
        self.region = region
        self.session_kwargs = session_kwargs or {}

    def _client(self):
        # Lazy import keeps boto3 an optional dependency.
        import boto3

        return boto3.client("s3", region_name=self.region, **self.session_kwargs)

    def _object_key(self, path: str) -> str:
        """Join the configured prefix with a slash-stripped relative path."""
        relative = path.lstrip("/")
        if not self.prefix:
            return relative
        return f"{self.prefix}/{relative}"

    def validate(self) -> bool:
        """Check configuration and bucket reachability; raises ValueError on failure."""
        if not self.bucket_name:
            raise ValueError("bucket_name is required for S3ArtifactStore")
        try:
            self._client().head_bucket(Bucket=self.bucket_name)
        except Exception as exc:
            raise ValueError(f"Unable to access bucket '{self.bucket_name}': {exc}") from exc
        return True

    def save(self, artifact: Any, path: str) -> str:
        """Save artifact to S3. Accepts file paths, bytes, or strings.

        A str/Path that names an existing local file is uploaded as a file;
        everything else is written as an object body (bytes as-is, anything
        else via ``str(...).encode()``). Returns the ``s3://`` URI.
        """
        key = self._object_key(path)
        client = self._client()

        is_local_file = isinstance(artifact, (str, Path)) and Path(artifact).exists()
        if is_local_file:
            client.upload_file(str(Path(artifact)), self.bucket_name, key)
        else:
            if isinstance(artifact, bytes):
                body = artifact
            else:
                body = str(artifact).encode()
            client.put_object(Bucket=self.bucket_name, Key=key, Body=body)

        return f"s3://{self.bucket_name}/{key}"

    def load(self, path: str) -> bytes:
        """Read the object stored at *path* and return its raw bytes."""
        client = self._client()
        response = client.get_object(Bucket=self.bucket_name, Key=self._object_key(path))
        return response["Body"].read()

    def exists(self, path: str) -> bool:
        """Return True when the object exists; any S3 error counts as missing."""
        try:
            self._client().head_object(Bucket=self.bucket_name, Key=self._object_key(path))
        except Exception:
            return False
        return True

    def to_dict(self) -> dict[str, Any]:
        """Serializable description of this component."""
        return {
            "name": self.name,
            "type": "s3",
            "bucket_name": self.bucket_name,
            "prefix": self.prefix,
            "region": self.region,
        }
93
+
94
+
95
@register_component(name="ecr")
class ECRContainerRegistry(ContainerRegistry):
    """Amazon Elastic Container Registry integration.

    Fixes over the previous implementation:
    - the ECR auth token is split only on the first ``:`` so passwords
      containing ``:`` survive;
    - the password is piped to ``docker login --password-stdin`` instead of
      being exposed on the process command line;
    - ``validate`` accepts ``registry_alias`` as an alternative to
      ``account_id``, matching what ``get_image_uri`` actually needs.
    """

    def __init__(
        self,
        name: str = "ecr",
        account_id: str | None = None,
        region: str = "us-east-1",
        registry_alias: str | None = None,
    ):
        super().__init__(name)
        self.account_id = account_id
        self.region = region
        self.registry_alias = registry_alias

    def _client(self):
        # Lazy import keeps boto3 an optional dependency.
        import boto3

        return boto3.client("ecr", region_name=self.region)

    def validate(self) -> bool:
        """Check that a registry host can be derived; raises ValueError otherwise."""
        # get_image_uri can build the registry host from either value.
        if not self.account_id and not self.registry_alias:
            raise ValueError("account_id or registry_alias is required for ECRContainerRegistry")
        return True

    def _login(self) -> None:
        """Authenticate the local docker daemon against this ECR registry."""
        auth = self._client().get_authorization_token()
        data = auth["authorizationData"][0]
        token = base64.b64decode(data["authorizationToken"]).decode()
        # Token has the form "<username>:<password>"; split once so a ':'
        # inside the password is preserved.
        username, password = token.split(":", 1)
        endpoint = data["proxyEndpoint"]
        # Feed the password via stdin so it never appears in `ps` output.
        subprocess.run(
            ["docker", "login", "--username", username, "--password-stdin", endpoint],
            input=password.encode(),
            check=True,
        )

    def push_image(self, image_name: str, tag: str = "latest") -> str:
        """Tag a local image with its ECR URI, push it, and return the URI."""
        full_uri = self.get_image_uri(image_name, tag)
        self._login()
        subprocess.run(["docker", "tag", f"{image_name}:{tag}", full_uri], check=True)
        subprocess.run(["docker", "push", full_uri], check=True)
        return full_uri

    def pull_image(self, image_name: str, tag: str = "latest") -> None:
        """Pull an image from ECR into the local docker daemon."""
        full_uri = self.get_image_uri(image_name, tag)
        self._login()
        subprocess.run(["docker", "pull", full_uri], check=True)

    def get_image_uri(self, image_name: str, tag: str = "latest") -> str:
        """Return the fully-qualified image URI for this registry."""
        registry = self.registry_alias or f"{self.account_id}.dkr.ecr.{self.region}.amazonaws.com"
        return f"{registry}/{image_name}:{tag}"

    def to_dict(self) -> dict[str, Any]:
        """Serializable description of this component."""
        return {
            "name": self.name,
            "type": "ecr",
            "account_id": self.account_id,
            "region": self.region,
            "registry_alias": self.registry_alias,
        }
154
+
155
+
156
@register_component(name="aws_batch")
class AWSBatchOrchestrator(RemoteOrchestrator):
    """Submit flowyml jobs to AWS Batch."""

    def __init__(
        self,
        name: str = "aws_batch",
        region: str = "us-east-1",
        job_queue: str | None = None,
        job_definition: str | None = None,
        parameters: dict[str, str] | None = None,
    ):
        super().__init__(name)
        self.region = region
        self.job_queue = job_queue
        self.job_definition = job_definition
        self.parameters = parameters or {}

    def _client(self):
        # Lazy import keeps boto3 an optional dependency.
        import boto3

        return boto3.client("batch", region_name=self.region)

    def validate(self) -> bool:
        """Check required submission settings; raises ValueError if missing."""
        if not self.job_queue or not self.job_definition:
            raise ValueError("job_queue and job_definition are required for AWSBatchOrchestrator")
        return True

    def run_pipeline(
        self,
        pipeline: Any,
        run_id: str,
        resources: ResourceConfig | None = None,
        docker_config: DockerConfig | None = None,
        inputs: dict[str, Any] | None = None,
        context: dict[str, Any] | None = None,
        **kwargs,
    ) -> SubmissionResult:
        """Submit pipeline to AWS Batch.

        Args:
            pipeline: Pipeline object; only its ``name`` is read here.
            run_id: Run identifier, propagated to the container via the
                ``FLOWYML_RUN_ID`` environment variable. (Previously this
                parameter was ignored in favor of freshly generated UUIDs,
                so the submitted job never correlated with the tracked run.)
            resources: Optional CPU/memory/GPU requirements.
            docker_config: Optional image and environment configuration.
            inputs: Accepted for interface compatibility; unused here.
            context: Accepted for interface compatibility; unused here.
            **kwargs: ``job_name`` may override the generated job name.

        Returns:
            SubmissionResult with the Batch job ID and a blocking wait function.
        """
        import time

        client = self._client()
        # Derive the job name from the caller-supplied run_id so the Batch
        # job is traceable back to this run.
        job_name = kwargs.get("job_name") or f"{pipeline.name}-{run_id[:8]}"

        env = [
            {"name": "FLOWYML_PIPELINE_NAME", "value": pipeline.name},
            {"name": "FLOWYML_RUN_ID", "value": run_id},
        ]
        if docker_config and docker_config.env_vars:
            for key, value in docker_config.env_vars.items():
                env.append({"name": key, "value": value})

        container_overrides: dict[str, Any] = {"environment": env}
        if docker_config and docker_config.image:
            container_overrides["command"] = ["python", "-m", "flowyml.cli.run"]

        if resources:
            # The AWS Batch API requires resourceRequirements values to be
            # strings, so coerce explicitly.
            container_overrides["resourceRequirements"] = [
                {"type": "VCPU", "value": str(resources.cpu)},
                {"type": "MEMORY", "value": str(resources.memory).replace("Gi", "")},
            ]
            if resources.gpu:
                container_overrides["resourceRequirements"].append(
                    {"type": "GPU", "value": str(resources.gpu_count or 1)},
                )

        response = client.submit_job(
            jobName=job_name,
            jobQueue=self.job_queue,
            jobDefinition=self.job_definition,
            containerOverrides=container_overrides,
            parameters=self.parameters,
        )
        job_id = response["jobId"]

        def wait_for_completion():
            """Poll job status until completion; raise on failure."""
            while True:
                status = self.get_run_status(job_id)
                if status.is_finished:
                    if not status.is_successful:
                        raise RuntimeError(f"AWS Batch job {job_id} failed with status: {status}")
                    break
                time.sleep(10)  # Poll every 10 seconds

        return SubmissionResult(
            job_id=job_id,
            wait_for_completion=wait_for_completion,
            metadata={
                "platform": "aws_batch",
                "region": self.region,
                "job_queue": self.job_queue,
                "job_name": job_name,
            },
        )

    def get_run_status(self, job_id: str) -> ExecutionStatus:
        """Get status of an AWS Batch job.

        Args:
            job_id: The AWS Batch job ID.

        Returns:
            Current execution status; unknown AWS states map to RUNNING and
            lookup failures map to FAILED.
        """
        client = self._client()
        try:
            response = client.describe_jobs(jobs=[job_id])
            if not response.get("jobs"):
                return ExecutionStatus.FAILED

            job = response["jobs"][0]
            status = job.get("status", "UNKNOWN")

            # Map AWS Batch status to ExecutionStatus.
            status_map = {
                "SUBMITTED": ExecutionStatus.PROVISIONING,
                "PENDING": ExecutionStatus.PROVISIONING,
                "RUNNABLE": ExecutionStatus.PROVISIONING,
                "STARTING": ExecutionStatus.INITIALIZING,
                "RUNNING": ExecutionStatus.RUNNING,
                "SUCCEEDED": ExecutionStatus.COMPLETED,
                "FAILED": ExecutionStatus.FAILED,
            }
            return status_map.get(status, ExecutionStatus.RUNNING)
        except Exception as e:
            print(f"Error fetching job status: {e}")
            return ExecutionStatus.FAILED

    def stop_run(self, job_id: str, graceful: bool = True) -> None:
        """Stop an AWS Batch job.

        Args:
            job_id: The AWS Batch job ID.
            graceful: If True, cancel the job (allows cleanup). If False,
                terminate it immediately.
        """
        client = self._client()
        reason = "Stopped by user"

        try:
            if graceful:
                client.cancel_job(jobId=job_id, reason=reason)
            else:
                client.terminate_job(jobId=job_id, reason=reason)
        except Exception as e:
            print(f"Error stopping job {job_id}: {e}")
            raise

    def to_dict(self) -> dict[str, Any]:
        """Serializable description of this component."""
        return {
            "name": self.name,
            "type": "aws_batch",
            "region": self.region,
            "job_queue": self.job_queue,
            "job_definition": self.job_definition,
        }
322
+
323
+
324
@register_component(name="sagemaker")
class SageMakerOrchestrator(RemoteOrchestrator):
    """Amazon SageMaker Training/Inference Orchestrator."""

    def __init__(
        self,
        name: str = "sagemaker",
        region: str = "us-east-1",
        role_arn: str | None = None,
        default_instance_type: str = "ml.m5.xlarge",
        default_instance_count: int = 1,
        volume_size_gb: int = 50,
        output_path: str | None = None,
    ):
        super().__init__(name)
        self.region = region
        self.role_arn = role_arn
        self.default_instance_type = default_instance_type
        self.default_instance_count = default_instance_count
        self.volume_size_gb = volume_size_gb
        self.output_path = output_path

    def _client(self):
        # Lazy import keeps boto3 an optional dependency.
        import boto3

        return boto3.client("sagemaker", region_name=self.region)

    def validate(self) -> bool:
        """Check required settings; raises ValueError if role_arn is missing."""
        if not self.role_arn:
            raise ValueError("role_arn is required for SageMakerOrchestrator")
        return True

    def run_pipeline(
        self,
        pipeline: Any,
        run_id: str,
        resources: ResourceConfig | None = None,
        docker_config: DockerConfig | None = None,
        inputs: dict[str, Any] | None = None,
        context: dict[str, Any] | None = None,
        hyperparameters: dict[str, str] | None = None,
        **kwargs,
    ) -> SubmissionResult:
        """Create a SageMaker training job for this pipeline.

        Args:
            pipeline: Pipeline object; only its ``name`` is read here.
            run_id: Run identifier used to derive the training job name.
                (Previously ignored in favor of a fresh UUID — fixed so the
                job name correlates with the tracked run.)
            resources: Optional resources; ``machine_type`` maps to the
                instance type.
            docker_config: Must carry the training image (or pass
                ``training_image`` via kwargs).
            inputs: Accepted for interface compatibility; unused here.
            context: Accepted for interface compatibility; unused here.
            hyperparameters: Optional SageMaker HyperParameters mapping.
            **kwargs: ``job_name``, ``instance_type``, ``instance_count``,
                ``input_s3_uri``, ``output_s3_uri``, ``volume_size_gb``,
                ``max_runtime`` overrides.

        Returns:
            SubmissionResult with the training job name as job_id and a
            blocking wait function.

        Raises:
            ValueError: If no training image or output path is configured.
        """
        import time

        client = self._client()
        training_image = docker_config.image if docker_config and docker_config.image else kwargs.get("training_image")
        if not training_image:
            raise ValueError("A Docker image must be provided via DockerConfig for SageMaker training.")

        instance_type = (
            kwargs.get("instance_type") or (resources.machine_type if resources else None) or self.default_instance_type
        )
        # NOTE(review): instance_count falls back to resources.gpu_count,
        # which conflates "GPUs per instance" with "number of instances" —
        # preserved for backward compatibility, but confirm the intent.
        instance_count = (
            kwargs.get("instance_count")
            or (resources.gpu_count if resources and resources.gpu_count else None)
            or self.default_instance_count
        )

        # Derive the job name from the caller-supplied run_id so the
        # training job is traceable back to this run.
        job_name = kwargs.get("job_name") or f"{pipeline.name}-{run_id[:8]}"
        training_input_s3 = kwargs.get("input_s3_uri")
        output_path = kwargs.get("output_s3_uri") or self.output_path
        if not output_path:
            raise ValueError("output_path must be provided for SageMaker training outputs.")

        create_kwargs = {
            "TrainingJobName": job_name,
            "AlgorithmSpecification": {
                "TrainingImage": training_image,
                "TrainingInputMode": "File",
            },
            "RoleArn": self.role_arn,
            "OutputDataConfig": {"S3OutputPath": output_path},
            "ResourceConfig": {
                "InstanceType": instance_type,
                "InstanceCount": instance_count,
                "VolumeSizeInGB": kwargs.get("volume_size_gb", self.volume_size_gb),
            },
            "StoppingCondition": {"MaxRuntimeInSeconds": kwargs.get("max_runtime", 3600 * 24)},
        }

        if training_input_s3:
            create_kwargs["InputDataConfig"] = [
                {
                    "ChannelName": "training",
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": training_input_s3,
                            "S3DataDistributionType": "FullyReplicated",
                        },
                    },
                },
            ]

        if hyperparameters:
            create_kwargs["HyperParameters"] = hyperparameters

        client.create_training_job(**create_kwargs)

        # SageMaker identifies training jobs by name.
        job_id = job_name

        def wait_for_completion():
            """Poll training job status until completion; raise on failure."""
            while True:
                status = self.get_run_status(job_id)
                if status.is_finished:
                    if not status.is_successful:
                        raise RuntimeError(f"SageMaker job {job_id} failed with status: {status}")
                    break
                time.sleep(15)  # Poll every 15 seconds

        return SubmissionResult(
            job_id=job_id,
            wait_for_completion=wait_for_completion,
            metadata={
                "platform": "sagemaker",
                "region": self.region,
                "instance_type": instance_type,
                "job_name": job_name,
            },
        )

    def deploy_model(
        self,
        model_artifact_s3_uri: str,
        inference_image: str,
        endpoint_name: str,
        instance_type: str = "ml.m5.large",
        instance_count: int = 1,
        wait: bool = True,
    ) -> str:
        """Deploy a trained model behind a SageMaker endpoint.

        Creates a model, an endpoint config, and the endpoint itself; when
        *wait* is True, blocks until the endpoint is in service. Returns the
        endpoint name.
        """
        client = self._client()
        model_name = f"{endpoint_name}-model"
        client.create_model(
            ModelName=model_name,
            PrimaryContainer={
                "Image": inference_image,
                "ModelDataUrl": model_artifact_s3_uri,
            },
            ExecutionRoleArn=self.role_arn,
        )

        endpoint_config_name = f"{endpoint_name}-config"
        client.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=[
                {
                    "VariantName": "AllTraffic",
                    "ModelName": model_name,
                    "InstanceType": instance_type,
                    "InitialInstanceCount": instance_count,
                },
            ],
        )

        client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)

        if wait:
            waiter = client.get_waiter("endpoint_in_service")
            waiter.wait(EndpointName=endpoint_name)
        return endpoint_name

    def get_run_status(self, job_id: str) -> ExecutionStatus:
        """Get status of a SageMaker training job.

        Args:
            job_id: The SageMaker training job name.

        Returns:
            Current execution status; unknown states map to RUNNING and
            lookup failures map to FAILED.
        """
        client = self._client()
        try:
            response = client.describe_training_job(TrainingJobName=job_id)
            status = response.get("TrainingJobStatus", "Unknown")

            # Map SageMaker status to ExecutionStatus.
            status_map = {
                "InProgress": ExecutionStatus.RUNNING,
                "Completed": ExecutionStatus.COMPLETED,
                "Failed": ExecutionStatus.FAILED,
                "Stopping": ExecutionStatus.STOPPING,
                "Stopped": ExecutionStatus.STOPPED,
            }
            return status_map.get(status, ExecutionStatus.RUNNING)
        except Exception as e:
            print(f"Error fetching training job status: {e}")
            return ExecutionStatus.FAILED

    def stop_run(self, job_id: str, graceful: bool = True) -> None:
        """Stop a SageMaker training job.

        Args:
            job_id: The SageMaker training job name.
            graceful: Accepted for interface compatibility; SageMaker always
                stops training jobs gracefully.
        """
        client = self._client()

        try:
            client.stop_training_job(TrainingJobName=job_id)
        except Exception as e:
            print(f"Error stopping training job {job_id}: {e}")
            raise

    def to_dict(self) -> dict[str, Any]:
        """Serializable description of this component."""
        return {
            "name": self.name,
            "type": "sagemaker",
            "region": self.region,
            "role_arn": self.role_arn,
            "default_instance_type": self.default_instance_type,
        }
542
+
543
+
544
class AWSStack(Stack):
    """Pre-built stack wiring AWS Batch/SageMaker, S3, and ECR together."""

    def __init__(
        self,
        name: str = "aws",
        region: str = "us-east-1",
        bucket_name: str | None = None,
        account_id: str | None = None,
        job_queue: str | None = None,
        job_definition: str | None = None,
        registry_alias: str | None = None,
        orchestrator_type: str = "batch",
        role_arn: str | None = None,
        metadata_store: Any | None = None,
    ):
        # Choose the orchestrator implementation; anything other than
        # "sagemaker" falls back to AWS Batch.
        orchestrator: Orchestrator
        if orchestrator_type == "sagemaker":
            orchestrator = SageMakerOrchestrator(region=region, role_arn=role_arn)
        else:
            orchestrator = AWSBatchOrchestrator(region=region, job_queue=job_queue, job_definition=job_definition)

        # Default the metadata store to a local SQLite store when the caller
        # does not supply one.
        store = metadata_store if metadata_store is not None else SQLiteMetadataStore()

        super().__init__(
            name=name,
            executor=None,
            artifact_store=S3ArtifactStore(bucket_name=bucket_name, region=region),
            metadata_store=store,
            container_registry=ECRContainerRegistry(
                account_id=account_id, region=region, registry_alias=registry_alias
            ),
            orchestrator=orchestrator,
        )

        self.region = region
        self.bucket_name = bucket_name

    def validate(self) -> bool:
        """Validate every configured component; raises on the first failure."""
        for component in (self.orchestrator, self.artifact_store, self.container_registry):
            component.validate()
        return True

    def to_dict(self) -> dict[str, Any]:
        """Serializable description of the stack and all its components."""
        return {
            "name": self.name,
            "type": "aws",
            "region": self.region,
            "bucket_name": self.bucket_name,
            "orchestrator": self.orchestrator.to_dict(),
            "artifact_store": self.artifact_store.to_dict(),
            "container_registry": self.container_registry.to_dict(),
        }