zenml-nightly 0.73.0.dev20250123__py3-none-any.whl → 0.73.0.dev20250125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/analytics/context.py +2 -6
- zenml/cli/annotator.py +1 -1
- zenml/cli/login.py +17 -6
- zenml/cli/server.py +1 -0
- zenml/cli/service_connectors.py +5 -5
- zenml/cli/stack.py +2 -2
- zenml/cli/utils.py +2 -54
- zenml/config/pipeline_configurations.py +3 -2
- zenml/config/schedule.py +0 -24
- zenml/enums.py +1 -0
- zenml/event_hub/base_event_hub.py +3 -4
- zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +3 -4
- zenml/integrations/aws/__init__.py +2 -1
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +15 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +310 -70
- zenml/integrations/aws/service_connectors/aws_service_connector.py +8 -13
- zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -10
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +3 -3
- zenml/integrations/huggingface/__init__.py +1 -6
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +3 -3
- zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +6 -2
- zenml/integrations/whylogs/data_validators/whylogs_data_validator.py +2 -3
- zenml/logging/step_logging.py +7 -7
- zenml/login/credentials.py +6 -5
- zenml/login/credentials_store.py +4 -3
- zenml/models/v2/core/api_key.py +5 -2
- zenml/models/v2/core/schedule.py +19 -3
- zenml/orchestrators/publish_utils.py +4 -4
- zenml/orchestrators/step_launcher.py +3 -3
- zenml/orchestrators/step_run_utils.py +2 -2
- zenml/pipelines/run_utils.py +2 -2
- zenml/service_connectors/service_connector.py +7 -4
- zenml/stack/stack.py +5 -4
- zenml/stack/stack_component.py +10 -2
- zenml/stack_deployments/stack_deployment.py +2 -3
- zenml/utils/string_utils.py +2 -2
- zenml/utils/time_utils.py +138 -0
- zenml/zen_server/auth.py +8 -9
- zenml/zen_server/cloud_utils.py +4 -6
- zenml/zen_server/routers/devices_endpoints.py +2 -4
- zenml/zen_server/routers/workspaces_endpoints.py +2 -0
- zenml/zen_server/zen_server_api.py +9 -8
- zenml/zen_stores/migrations/versions/25155145c545_separate_actions_and_triggers.py +3 -2
- zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py +3 -3
- zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +3 -2
- zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +10 -7
- zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
- zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +5 -3
- zenml/zen_stores/schemas/action_schemas.py +2 -2
- zenml/zen_stores/schemas/api_key_schemas.py +5 -4
- zenml/zen_stores/schemas/artifact_schemas.py +3 -3
- zenml/zen_stores/schemas/base_schemas.py +5 -7
- zenml/zen_stores/schemas/code_repository_schemas.py +2 -2
- zenml/zen_stores/schemas/component_schemas.py +2 -2
- zenml/zen_stores/schemas/device_schemas.py +5 -4
- zenml/zen_stores/schemas/event_source_schemas.py +2 -2
- zenml/zen_stores/schemas/flavor_schemas.py +2 -2
- zenml/zen_stores/schemas/model_schemas.py +3 -3
- zenml/zen_stores/schemas/pipeline_run_schemas.py +11 -3
- zenml/zen_stores/schemas/pipeline_schemas.py +2 -2
- zenml/zen_stores/schemas/run_template_schemas.py +2 -2
- zenml/zen_stores/schemas/schedule_schema.py +20 -4
- zenml/zen_stores/schemas/secret_schemas.py +2 -2
- zenml/zen_stores/schemas/server_settings_schemas.py +6 -9
- zenml/zen_stores/schemas/service_connector_schemas.py +3 -2
- zenml/zen_stores/schemas/service_schemas.py +2 -2
- zenml/zen_stores/schemas/stack_schemas.py +2 -2
- zenml/zen_stores/schemas/step_run_schemas.py +3 -2
- zenml/zen_stores/schemas/tag_schemas.py +2 -2
- zenml/zen_stores/schemas/trigger_schemas.py +2 -2
- zenml/zen_stores/schemas/user_schemas.py +3 -3
- zenml/zen_stores/schemas/workspace_schemas.py +2 -2
- zenml/zen_stores/sql_zen_store.py +6 -14
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/METADATA +2 -2
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/RECORD +79 -78
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.73.0.dev20250123
|
1
|
+
0.73.0.dev20250125
|
zenml/analytics/context.py
CHANGED
@@ -17,7 +17,6 @@ This module is based on the 'analytics-python' package created by Segment.
|
|
17
17
|
The base functionalities are adapted to work with the ZenML analytics server.
|
18
18
|
"""
|
19
19
|
|
20
|
-
import datetime
|
21
20
|
import locale
|
22
21
|
from types import TracebackType
|
23
22
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
@@ -32,6 +31,7 @@ from zenml.constants import (
|
|
32
31
|
)
|
33
32
|
from zenml.environment import Environment, get_environment
|
34
33
|
from zenml.logger import get_logger
|
34
|
+
from zenml.utils.time_utils import utc_now_tz_aware
|
35
35
|
|
36
36
|
if TYPE_CHECKING:
|
37
37
|
from zenml.analytics.enums import AnalyticsEvent
|
@@ -284,11 +284,7 @@ class AnalyticsContext:
|
|
284
284
|
|
285
285
|
try:
|
286
286
|
# Timezone as tzdata
|
287
|
-
tz = (
|
288
|
-
datetime.datetime.now(datetime.timezone.utc)
|
289
|
-
.astimezone()
|
290
|
-
.tzname()
|
291
|
-
)
|
287
|
+
tz = utc_now_tz_aware().astimezone().tzname()
|
292
288
|
if tz is not None:
|
293
289
|
properties.update({"timezone": tz})
|
294
290
|
|
zenml/cli/annotator.py
CHANGED
@@ -182,7 +182,7 @@ def register_annotator_subcommands() -> None:
|
|
182
182
|
kwargs_dict = {}
|
183
183
|
for arg in kwargs:
|
184
184
|
if arg.startswith("--"):
|
185
|
-
key, value = arg.
|
185
|
+
key, value = arg.removeprefix("--").split("=", 1)
|
186
186
|
kwargs_dict[key] = value
|
187
187
|
|
188
188
|
if annotator.flavor == "prodigy":
|
zenml/cli/login.py
CHANGED
@@ -249,6 +249,7 @@ def connect_to_pro_server(
|
|
249
249
|
from zenml.login.pro.tenant.models import TenantStatus
|
250
250
|
|
251
251
|
pro_api_url = pro_api_url or ZENML_PRO_API_URL
|
252
|
+
pro_api_url = pro_api_url.rstrip("/")
|
252
253
|
|
253
254
|
server_id, server_url, server_name = None, None, None
|
254
255
|
login = False
|
@@ -434,6 +435,7 @@ def is_pro_server(
|
|
434
435
|
from zenml.login.credentials_store import get_credentials_store
|
435
436
|
from zenml.login.server_info import get_server_info
|
436
437
|
|
438
|
+
url = url.rstrip("/")
|
437
439
|
# First, check the credentials store
|
438
440
|
credentials_store = get_credentials_store()
|
439
441
|
credentials = credentials_store.get_credentials(url)
|
@@ -790,15 +792,16 @@ def login(
|
|
790
792
|
)
|
791
793
|
|
792
794
|
if server is not None:
|
793
|
-
if
|
794
|
-
# The server argument is a
|
795
|
-
|
796
|
-
|
795
|
+
if re.match(r"^mysql://", server):
|
796
|
+
# The server argument is a MySQL URL, we can directly connect to it
|
797
|
+
|
798
|
+
connect_to_server(
|
799
|
+
url=server,
|
797
800
|
api_key=api_key_value,
|
801
|
+
verify_ssl=verify_ssl,
|
798
802
|
refresh=refresh,
|
799
|
-
pro_api_url=pro_api_url,
|
800
803
|
)
|
801
|
-
|
804
|
+
elif re.match(r"^https?://", server):
|
802
805
|
# The server argument is a server URL
|
803
806
|
|
804
807
|
# First, try to discover if the server is a ZenML Pro server or not
|
@@ -819,6 +822,14 @@ def login(
|
|
819
822
|
verify_ssl=verify_ssl,
|
820
823
|
refresh=refresh,
|
821
824
|
)
|
825
|
+
else:
|
826
|
+
# The server argument is a ZenML Pro server name or UUID
|
827
|
+
connect_to_pro_server(
|
828
|
+
pro_server=server,
|
829
|
+
api_key=api_key_value,
|
830
|
+
refresh=refresh,
|
831
|
+
pro_api_url=pro_api_url,
|
832
|
+
)
|
822
833
|
|
823
834
|
elif current_non_local_server:
|
824
835
|
# The server argument is not provided, so we default to
|
zenml/cli/server.py
CHANGED
@@ -578,6 +578,7 @@ def server_list(
|
|
578
578
|
from zenml.login.pro.tenant.models import TenantRead, TenantStatus
|
579
579
|
|
580
580
|
pro_api_url = pro_api_url or ZENML_PRO_API_URL
|
581
|
+
pro_api_url = pro_api_url.rstrip("/")
|
581
582
|
|
582
583
|
credentials_store = get_credentials_store()
|
583
584
|
pro_token = credentials_store.get_pro_token(
|
zenml/cli/service_connectors.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Service connector CLI commands."""
|
15
15
|
|
16
|
-
from datetime import datetime
|
16
|
+
from datetime import datetime
|
17
17
|
from typing import Any, Dict, List, Optional, Union, cast
|
18
18
|
from uuid import UUID
|
19
19
|
|
@@ -25,7 +25,6 @@ from zenml.cli.utils import (
|
|
25
25
|
is_sorted_or_filtered,
|
26
26
|
list_options,
|
27
27
|
print_page_info,
|
28
|
-
seconds_to_human_readable,
|
29
28
|
)
|
30
29
|
from zenml.client import Client
|
31
30
|
from zenml.console import console
|
@@ -37,6 +36,7 @@ from zenml.models import (
|
|
37
36
|
ServiceConnectorResourcesModel,
|
38
37
|
ServiceConnectorResponse,
|
39
38
|
)
|
39
|
+
from zenml.utils.time_utils import seconds_to_human_readable, utc_now
|
40
40
|
|
41
41
|
|
42
42
|
# Service connectors
|
@@ -292,7 +292,7 @@ def prompt_expires_at(
|
|
292
292
|
default_str = ""
|
293
293
|
if default is not None:
|
294
294
|
seconds = int(
|
295
|
-
(default -
|
295
|
+
(default - utc_now(tz_aware=default)).total_seconds()
|
296
296
|
)
|
297
297
|
default_str = (
|
298
298
|
f" [{str(default)} i.e. in "
|
@@ -309,7 +309,7 @@ def prompt_expires_at(
|
|
309
309
|
|
310
310
|
assert expires_at is not None
|
311
311
|
assert isinstance(expires_at, datetime)
|
312
|
-
if expires_at <
|
312
|
+
if expires_at < utc_now(tz_aware=expires_at):
|
313
313
|
cli_utils.warning(
|
314
314
|
"The expiration time must be in the future. Please enter a "
|
315
315
|
"later date and time."
|
@@ -317,7 +317,7 @@ def prompt_expires_at(
|
|
317
317
|
continue
|
318
318
|
|
319
319
|
seconds = int(
|
320
|
-
(expires_at -
|
320
|
+
(expires_at - utc_now(tz_aware=expires_at)).total_seconds()
|
321
321
|
)
|
322
322
|
|
323
323
|
confirm = click.confirm(
|
zenml/cli/stack.py
CHANGED
@@ -17,7 +17,6 @@ import getpass
|
|
17
17
|
import re
|
18
18
|
import time
|
19
19
|
import webbrowser
|
20
|
-
from datetime import datetime, timezone
|
21
20
|
from typing import (
|
22
21
|
TYPE_CHECKING,
|
23
22
|
Any,
|
@@ -77,6 +76,7 @@ from zenml.service_connectors.service_connector_utils import (
|
|
77
76
|
)
|
78
77
|
from zenml.utils import requirements_utils
|
79
78
|
from zenml.utils.dashboard_utils import get_component_url, get_stack_url
|
79
|
+
from zenml.utils.time_utils import utc_now_tz_aware
|
80
80
|
from zenml.utils.yaml_utils import read_yaml, write_yaml
|
81
81
|
|
82
82
|
if TYPE_CHECKING:
|
@@ -1575,7 +1575,7 @@ def deploy(
|
|
1575
1575
|
):
|
1576
1576
|
raise click.Abort()
|
1577
1577
|
|
1578
|
-
date_start =
|
1578
|
+
date_start = utc_now_tz_aware()
|
1579
1579
|
|
1580
1580
|
webbrowser.open(deployment_config.deployment_url)
|
1581
1581
|
console.print(
|
zenml/cli/utils.py
CHANGED
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Utility functions for the CLI."""
|
15
15
|
|
16
16
|
import contextlib
|
17
|
-
import datetime
|
18
17
|
import json
|
19
18
|
import os
|
20
19
|
import platform
|
@@ -79,6 +78,7 @@ from zenml.services import BaseService, ServiceState
|
|
79
78
|
from zenml.stack import StackComponent
|
80
79
|
from zenml.stack.stack_component import StackComponentConfig
|
81
80
|
from zenml.utils import secret_utils
|
81
|
+
from zenml.utils.time_utils import expires_in
|
82
82
|
|
83
83
|
if TYPE_CHECKING:
|
84
84
|
from uuid import UUID
|
@@ -1581,58 +1581,6 @@ def print_components_table(
|
|
1581
1581
|
print_table(configurations)
|
1582
1582
|
|
1583
1583
|
|
1584
|
-
def seconds_to_human_readable(time_seconds: int) -> str:
|
1585
|
-
"""Converts seconds to human-readable format.
|
1586
|
-
|
1587
|
-
Args:
|
1588
|
-
time_seconds: Seconds to convert.
|
1589
|
-
|
1590
|
-
Returns:
|
1591
|
-
Human readable string.
|
1592
|
-
"""
|
1593
|
-
seconds = time_seconds % 60
|
1594
|
-
minutes = (time_seconds // 60) % 60
|
1595
|
-
hours = (time_seconds // 3600) % 24
|
1596
|
-
days = time_seconds // 86400
|
1597
|
-
tokens = []
|
1598
|
-
if days:
|
1599
|
-
tokens.append(f"{days}d")
|
1600
|
-
if hours:
|
1601
|
-
tokens.append(f"{hours}h")
|
1602
|
-
if minutes:
|
1603
|
-
tokens.append(f"{minutes}m")
|
1604
|
-
if seconds:
|
1605
|
-
tokens.append(f"{seconds}s")
|
1606
|
-
|
1607
|
-
return "".join(tokens)
|
1608
|
-
|
1609
|
-
|
1610
|
-
def expires_in(
|
1611
|
-
expires_at: datetime.datetime,
|
1612
|
-
expired_str: str,
|
1613
|
-
skew_tolerance: Optional[int] = None,
|
1614
|
-
) -> str:
|
1615
|
-
"""Returns a human-readable string of the time until the token expires.
|
1616
|
-
|
1617
|
-
Args:
|
1618
|
-
expires_at: Datetime object of the token expiration.
|
1619
|
-
expired_str: String to return if the token is expired.
|
1620
|
-
skew_tolerance: Seconds of skew tolerance to subtract from the
|
1621
|
-
expiration time. If the token expires within this time, it will be
|
1622
|
-
considered expired.
|
1623
|
-
|
1624
|
-
Returns:
|
1625
|
-
Human readable string.
|
1626
|
-
"""
|
1627
|
-
now = datetime.datetime.now(datetime.timezone.utc)
|
1628
|
-
expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
|
1629
|
-
if skew_tolerance:
|
1630
|
-
expires_at -= datetime.timedelta(seconds=skew_tolerance)
|
1631
|
-
if expires_at < now:
|
1632
|
-
return expired_str
|
1633
|
-
return seconds_to_human_readable(int((expires_at - now).total_seconds()))
|
1634
|
-
|
1635
|
-
|
1636
1584
|
def print_service_connectors_table(
|
1637
1585
|
client: "Client",
|
1638
1586
|
connectors: Sequence["ServiceConnectorResponse"],
|
@@ -2660,7 +2608,7 @@ def print_model_url(url: Optional[str]) -> None:
|
|
2660
2608
|
warning(
|
2661
2609
|
"You can display various ZenML entities including pipelines, "
|
2662
2610
|
"runs, stacks and much more on the ZenML Dashboard. "
|
2663
|
-
"You can try it locally, by running `zenml
|
2611
|
+
"You can try it locally, by running `zenml login --local`, or "
|
2664
2612
|
"remotely, by deploying ZenML on the infrastructure of your choice."
|
2665
2613
|
)
|
2666
2614
|
|
zenml/config/pipeline_configurations.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Pipeline configuration classes."""
|
15
15
|
|
16
|
-
from datetime import datetime
|
16
|
+
from datetime import datetime
|
17
17
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
18
18
|
|
19
19
|
from pydantic import SerializeAsAny, field_validator
|
@@ -23,6 +23,7 @@ from zenml.config.retry_config import StepRetryConfig
|
|
23
23
|
from zenml.config.source import SourceWithValidator
|
24
24
|
from zenml.config.strict_base_model import StrictBaseModel
|
25
25
|
from zenml.model.model import Model
|
26
|
+
from zenml.utils.time_utils import utc_now
|
26
27
|
|
27
28
|
if TYPE_CHECKING:
|
28
29
|
from zenml.config import DockerSettings
|
@@ -61,7 +62,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
|
|
61
62
|
The full substitutions dict including date and time.
|
62
63
|
"""
|
63
64
|
if start_time is None:
|
64
|
-
start_time =
|
65
|
+
start_time = utc_now()
|
65
66
|
ret = self.substitutions.copy()
|
66
67
|
ret.setdefault("date", start_time.strftime("%Y_%m_%d"))
|
67
68
|
ret.setdefault("time", start_time.strftime("%H_%M_%S_%f"))
|
zenml/config/schedule.py
CHANGED
@@ -98,27 +98,3 @@ class Schedule(BaseModel):
|
|
98
98
|
"or a run once start time "
|
99
99
|
"need to be set for a valid schedule."
|
100
100
|
)
|
101
|
-
|
102
|
-
@property
|
103
|
-
def utc_start_time(self) -> Optional[str]:
|
104
|
-
"""Optional ISO-formatted string of the UTC start time.
|
105
|
-
|
106
|
-
Returns:
|
107
|
-
Optional ISO-formatted string of the UTC start time.
|
108
|
-
"""
|
109
|
-
if not self.start_time:
|
110
|
-
return None
|
111
|
-
|
112
|
-
return self.start_time.astimezone(datetime.timezone.utc).isoformat()
|
113
|
-
|
114
|
-
@property
|
115
|
-
def utc_end_time(self) -> Optional[str]:
|
116
|
-
"""Optional ISO-formatted string of the UTC end time.
|
117
|
-
|
118
|
-
Returns:
|
119
|
-
Optional ISO-formatted string of the UTC end time.
|
120
|
-
"""
|
121
|
-
if not self.end_time:
|
122
|
-
return None
|
123
|
-
|
124
|
-
return self.end_time.astimezone(datetime.timezone.utc).isoformat()
|
zenml/enums.py
CHANGED
zenml/event_hub/base_event_hub.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14
14
|
"""Base class for event hub implementations."""
|
15
15
|
|
16
16
|
from abc import ABC, abstractmethod
|
17
|
-
from datetime import datetime, timedelta
|
17
|
+
from datetime import datetime, timedelta
|
18
18
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
|
19
19
|
|
20
20
|
from zenml import EventSourceResponse
|
@@ -28,6 +28,7 @@ from zenml.models import (
|
|
28
28
|
TriggerExecutionResponse,
|
29
29
|
TriggerResponse,
|
30
30
|
)
|
31
|
+
from zenml.utils.time_utils import utc_now
|
31
32
|
from zenml.zen_server.auth import AuthContext
|
32
33
|
from zenml.zen_server.jwt import JWTToken
|
33
34
|
|
@@ -134,9 +135,7 @@ class BaseEventHub(ABC):
|
|
134
135
|
)
|
135
136
|
expires: Optional[datetime] = None
|
136
137
|
if trigger.action.auth_window:
|
137
|
-
expires =
|
138
|
-
minutes=trigger.action.auth_window
|
139
|
-
)
|
138
|
+
expires = utc_now() + timedelta(minutes=trigger.action.auth_window)
|
140
139
|
encoded_token = token.encode(expires=expires)
|
141
140
|
auth_context = AuthContext(
|
142
141
|
user=trigger.action.service_account,
|
zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
CHANGED
@@ -42,6 +42,7 @@ from zenml.orchestrators import ContainerizedOrchestrator
|
|
42
42
|
from zenml.orchestrators.utils import get_orchestrator_run_name
|
43
43
|
from zenml.stack import StackValidator
|
44
44
|
from zenml.utils import io_utils
|
45
|
+
from zenml.utils.time_utils import utc_now
|
45
46
|
|
46
47
|
if TYPE_CHECKING:
|
47
48
|
from zenml.config import ResourceSettings
|
@@ -408,8 +409,7 @@ class AirflowOrchestrator(ContainerizedOrchestrator):
|
|
408
409
|
if schedule:
|
409
410
|
if schedule.cron_expression:
|
410
411
|
start_time = schedule.start_time or (
|
411
|
-
datetime.
|
412
|
-
- datetime.timedelta(7)
|
412
|
+
utc_now() - datetime.timedelta(7)
|
413
413
|
)
|
414
414
|
return {
|
415
415
|
"schedule": schedule.cron_expression,
|
@@ -429,7 +429,6 @@ class AirflowOrchestrator(ContainerizedOrchestrator):
|
|
429
429
|
"schedule": "@once",
|
430
430
|
# set a start time in the past and disable catchup so airflow
|
431
431
|
# runs the dag immediately
|
432
|
-
"start_date": datetime.
|
433
|
-
- datetime.timedelta(7),
|
432
|
+
"start_date": utc_now() - datetime.timedelta(7),
|
434
433
|
"catchup": False,
|
435
434
|
}
|
zenml/integrations/aws/__init__.py
CHANGED
@@ -35,12 +35,13 @@ AWS_RESOURCE_TYPE = "aws-generic"
|
|
35
35
|
S3_RESOURCE_TYPE = "s3-bucket"
|
36
36
|
AWS_IMAGE_BUILDER_FLAVOR = "aws"
|
37
37
|
|
38
|
+
|
38
39
|
class AWSIntegration(Integration):
|
39
40
|
"""Definition of AWS integration for ZenML."""
|
40
41
|
|
41
42
|
NAME = AWS
|
42
43
|
REQUIREMENTS = [
|
43
|
-
"sagemaker>=2.
|
44
|
+
"sagemaker>=2.199.0",
|
44
45
|
"kubernetes",
|
45
46
|
"aws-profile-manager",
|
46
47
|
]
|
zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
CHANGED
@@ -85,6 +85,7 @@ class SagemakerOrchestratorSettings(BaseSettings):
|
|
85
85
|
to the container is configured with input_data_s3_mode. Two possible
|
86
86
|
input types:
|
87
87
|
- str: S3 location where training data is saved.
|
88
|
+
- Dict[str, str]: (ChannelName, S3Location) which represent
|
88
89
|
- Dict[str, str]: (ChannelName, S3Location) which represent
|
89
90
|
channels (e.g. training, validation, testing) where
|
90
91
|
specific parts of the data are saved in S3.
|
@@ -184,6 +185,10 @@ class SagemakerOrchestratorConfig(
|
|
184
185
|
|
185
186
|
Attributes:
|
186
187
|
execution_role: The IAM role ARN to use for the pipeline.
|
188
|
+
scheduler_role: The ARN of the IAM role that will be assumed by
|
189
|
+
the EventBridge service to launch Sagemaker pipelines
|
190
|
+
(For more details regarding the required permissions, please check:
|
191
|
+
https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules)
|
187
192
|
aws_access_key_id: The AWS access key ID to use to authenticate to AWS.
|
188
193
|
If not provided, the value from the default AWS config will be used.
|
189
194
|
aws_secret_access_key: The AWS secret access key to use to authenticate
|
@@ -203,6 +208,7 @@ class SagemakerOrchestratorConfig(
|
|
203
208
|
"""
|
204
209
|
|
205
210
|
execution_role: str
|
211
|
+
scheduler_role: Optional[str] = None
|
206
212
|
aws_access_key_id: Optional[str] = SecretField(default=None)
|
207
213
|
aws_secret_access_key: Optional[str] = SecretField(default=None)
|
208
214
|
aws_profile: Optional[str] = None
|
@@ -232,6 +238,15 @@ class SagemakerOrchestratorConfig(
|
|
232
238
|
"""
|
233
239
|
return self.synchronous
|
234
240
|
|
241
|
+
@property
|
242
|
+
def is_schedulable(self) -> bool:
|
243
|
+
"""Whether the orchestrator is schedulable or not.
|
244
|
+
|
245
|
+
Returns:
|
246
|
+
Whether the orchestrator is schedulable or not.
|
247
|
+
"""
|
248
|
+
return True
|
249
|
+
|
235
250
|
|
236
251
|
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
|
237
252
|
"""Flavor for the Sagemaker orchestrator."""
|