zenml-nightly 0.73.0.dev20250123__py3-none-any.whl → 0.73.0.dev20250125__py3-none-any.whl
- zenml/VERSION +1 -1
- zenml/analytics/context.py +2 -6
- zenml/cli/annotator.py +1 -1
- zenml/cli/login.py +17 -6
- zenml/cli/server.py +1 -0
- zenml/cli/service_connectors.py +5 -5
- zenml/cli/stack.py +2 -2
- zenml/cli/utils.py +2 -54
- zenml/config/pipeline_configurations.py +3 -2
- zenml/config/schedule.py +0 -24
- zenml/enums.py +1 -0
- zenml/event_hub/base_event_hub.py +3 -4
- zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +3 -4
- zenml/integrations/aws/__init__.py +2 -1
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +15 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +310 -70
- zenml/integrations/aws/service_connectors/aws_service_connector.py +8 -13
- zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -10
- zenml/integrations/gcp/service_connectors/gcp_service_connector.py +3 -3
- zenml/integrations/huggingface/__init__.py +1 -6
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +3 -3
- zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +6 -2
- zenml/integrations/whylogs/data_validators/whylogs_data_validator.py +2 -3
- zenml/logging/step_logging.py +7 -7
- zenml/login/credentials.py +6 -5
- zenml/login/credentials_store.py +4 -3
- zenml/models/v2/core/api_key.py +5 -2
- zenml/models/v2/core/schedule.py +19 -3
- zenml/orchestrators/publish_utils.py +4 -4
- zenml/orchestrators/step_launcher.py +3 -3
- zenml/orchestrators/step_run_utils.py +2 -2
- zenml/pipelines/run_utils.py +2 -2
- zenml/service_connectors/service_connector.py +7 -4
- zenml/stack/stack.py +5 -4
- zenml/stack/stack_component.py +10 -2
- zenml/stack_deployments/stack_deployment.py +2 -3
- zenml/utils/string_utils.py +2 -2
- zenml/utils/time_utils.py +138 -0
- zenml/zen_server/auth.py +8 -9
- zenml/zen_server/cloud_utils.py +4 -6
- zenml/zen_server/routers/devices_endpoints.py +2 -4
- zenml/zen_server/routers/workspaces_endpoints.py +2 -0
- zenml/zen_server/zen_server_api.py +9 -8
- zenml/zen_stores/migrations/versions/25155145c545_separate_actions_and_triggers.py +3 -2
- zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py +3 -3
- zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +3 -2
- zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +10 -7
- zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
- zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +5 -3
- zenml/zen_stores/schemas/action_schemas.py +2 -2
- zenml/zen_stores/schemas/api_key_schemas.py +5 -4
- zenml/zen_stores/schemas/artifact_schemas.py +3 -3
- zenml/zen_stores/schemas/base_schemas.py +5 -7
- zenml/zen_stores/schemas/code_repository_schemas.py +2 -2
- zenml/zen_stores/schemas/component_schemas.py +2 -2
- zenml/zen_stores/schemas/device_schemas.py +5 -4
- zenml/zen_stores/schemas/event_source_schemas.py +2 -2
- zenml/zen_stores/schemas/flavor_schemas.py +2 -2
- zenml/zen_stores/schemas/model_schemas.py +3 -3
- zenml/zen_stores/schemas/pipeline_run_schemas.py +11 -3
- zenml/zen_stores/schemas/pipeline_schemas.py +2 -2
- zenml/zen_stores/schemas/run_template_schemas.py +2 -2
- zenml/zen_stores/schemas/schedule_schema.py +20 -4
- zenml/zen_stores/schemas/secret_schemas.py +2 -2
- zenml/zen_stores/schemas/server_settings_schemas.py +6 -9
- zenml/zen_stores/schemas/service_connector_schemas.py +3 -2
- zenml/zen_stores/schemas/service_schemas.py +2 -2
- zenml/zen_stores/schemas/stack_schemas.py +2 -2
- zenml/zen_stores/schemas/step_run_schemas.py +3 -2
- zenml/zen_stores/schemas/tag_schemas.py +2 -2
- zenml/zen_stores/schemas/trigger_schemas.py +2 -2
- zenml/zen_stores/schemas/user_schemas.py +3 -3
- zenml/zen_stores/schemas/workspace_schemas.py +2 -2
- zenml/zen_stores/sql_zen_store.py +6 -14
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/METADATA +2 -2
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/RECORD +79 -78
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/entry_points.txt +0 -0
zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py

@@ -35,14 +35,20 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.steps import ProcessingStep, TrainingStep
+from sagemaker.workflow.triggers import PipelineSchedule

+from zenml.client import Client
 from zenml.config.base_settings import BaseSettings
 from zenml.constants import (
     METADATA_ORCHESTRATOR_LOGS_URL,
     METADATA_ORCHESTRATOR_RUN_ID,
     METADATA_ORCHESTRATOR_URL,
 )
-from zenml.enums import
+from zenml.enums import (
+    ExecutionStatus,
+    MetadataResourceTypes,
+    StackComponentType,
+)
 from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
     SagemakerOrchestratorConfig,
     SagemakerOrchestratorSettings,
@@ -57,6 +63,7 @@ from zenml.orchestrators import ContainerizedOrchestrator
 from zenml.orchestrators.utils import get_orchestrator_run_name
 from zenml.stack import StackValidator
 from zenml.utils.env_utils import split_environment_variables
+from zenml.utils.time_utils import to_utc_timezone, utc_now_tz_aware

 if TYPE_CHECKING:
     from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
@@ -69,6 +76,36 @@ POLLING_DELAY = 30
 logger = get_logger(__name__)


+def dissect_schedule_arn(
+    schedule_arn: str,
+) -> Tuple[Optional[str], Optional[str]]:
+    """Extracts the region and the name from an EventBridge schedule ARN.
+
+    Args:
+        schedule_arn: The ARN of the EventBridge schedule.
+
+    Returns:
+        Region Name, Schedule Name (including the group name)
+
+    Raises:
+        ValueError: If the input is not a properly formatted ARN.
+    """
+    # Split the ARN into parts
+    arn_parts = schedule_arn.split(":")
+
+    # Validate ARN structure
+    if len(arn_parts) < 6 or not arn_parts[5].startswith("schedule/"):
+        raise ValueError("Invalid EventBridge schedule ARN format.")
+
+    # Extract the region
+    region = arn_parts[3]
+
+    # Extract the group name and schedule name
+    name = arn_parts[5].split("schedule/")[1]
+
+    return region, name
+
+
 def dissect_pipeline_execution_arn(
     pipeline_execution_arn: str,
 ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
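For illustration, a minimal sketch of what the new module-level `dissect_schedule_arn` helper returns. The import path follows the file shown above; the ARN value itself is hypothetical and not taken from the release:

# Hypothetical usage of the helper added in this release.
from zenml.integrations.aws.orchestrators.sagemaker_orchestrator import (
    dissect_schedule_arn,
)

arn = "arn:aws:scheduler:us-east-1:123456789012:schedule/default/zenml-nightly-run"
region, name = dissect_schedule_arn(arn)
# region == "us-east-1"
# name == "default/zenml-nightly-run"  (the schedule group stays in the name)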
@@ -237,21 +274,15 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
             environment.

         Raises:
-            RuntimeError: If
-
+            RuntimeError: If there is an error creating or scheduling the
+                pipeline.
             TypeError: If the network_config passed is not compatible with the
                 AWS SageMaker NetworkConfig class.
+            ValueError: If the schedule is not valid.

         Yields:
             A dictionary of metadata related to the pipeline run.
         """
-        if deployment.schedule:
-            logger.warning(
-                "The Sagemaker Orchestrator currently does not support the "
-                "use of schedules. The `schedule` will be ignored "
-                "and the pipeline will be run immediately."
-            )
-
         # sagemaker requires pipelineName to use alphanum and hyphens only
         unsanitized_orchestrator_run_name = get_orchestrator_run_name(
             pipeline_name=deployment.pipeline_configuration.name
@@ -459,7 +490,7 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):

             sagemaker_steps.append(sagemaker_step)

-        #
+        # Create the pipeline
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
@@ -479,39 +510,211 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
            if settings.pipeline_tags
            else None,
        )
-        execution = pipeline.start()
-        logger.warning(
-            "Steps can take 5-15 minutes to start running "
-            "when using the Sagemaker Orchestrator."
-        )

-        #
-
-
-
+        # Handle scheduling if specified
+        if deployment.schedule:
+            if settings.synchronous:
+                logger.warning(
+                    "The 'synchronous' setting is ignored for scheduled "
+                    "pipelines since they run independently of the "
+                    "deployment process."
+                )

-
-
-
-
-
-
+            schedule_name = orchestrator_run_name
+            next_execution = None
+            start_date = (
+                to_utc_timezone(deployment.schedule.start_time)
+                if deployment.schedule.start_time
+                else None
            )
+
+            # Create PipelineSchedule based on schedule type
+            if deployment.schedule.cron_expression:
+                cron_exp = self._validate_cron_expression(
+                    deployment.schedule.cron_expression
+                )
+                schedule = PipelineSchedule(
+                    name=schedule_name,
+                    cron=cron_exp,
+                    start_date=start_date,
+                    enabled=True,
+                )
+            elif deployment.schedule.interval_second:
+                # This is necessary because SageMaker's PipelineSchedule rate
+                # expressions require minutes as the minimum time unit.
+                # Even if a user specifies an interval of less than 60 seconds,
+                # it will be rounded up to 1 minute.
+                minutes = max(
+                    1,
+                    int(
+                        deployment.schedule.interval_second.total_seconds()
+                        / 60
+                    ),
+                )
+                schedule = PipelineSchedule(
+                    name=schedule_name,
+                    rate=(minutes, "minutes"),
+                    start_date=start_date,
+                    enabled=True,
+                )
+                next_execution = (
+                    deployment.schedule.start_time or utc_now_tz_aware()
+                ) + deployment.schedule.interval_second
+            else:
+                # One-time schedule
+                execution_time = (
+                    deployment.schedule.run_once_start_time
+                    or deployment.schedule.start_time
+                )
+                if not execution_time:
+                    raise ValueError(
+                        "A start time must be specified for one-time "
+                        "schedule execution"
+                    )
+                schedule = PipelineSchedule(
+                    name=schedule_name,
+                    at=to_utc_timezone(execution_time),
+                    enabled=True,
+                )
+                next_execution = execution_time
+
+            # Get the current role ARN if not explicitly configured
+            if self.config.scheduler_role is None:
+                logger.info(
+                    "No scheduler_role configured. Trying to extract it from "
+                    "the client side authentication."
+                )
+                sts = session.boto_session.client("sts")
+                try:
+                    scheduler_role_arn = sts.get_caller_identity()["Arn"]
+                    # If this is a user ARN, try to get the role ARN
+                    if ":user/" in scheduler_role_arn:
+                        logger.warning(
+                            f"Using IAM user credentials "
+                            f"({scheduler_role_arn}). For production "
+                            "environments, it's recommended to use IAM roles "
+                            "instead."
+                        )
+                    # If this is an assumed role, extract the role ARN
+                    elif ":assumed-role/" in scheduler_role_arn:
+                        # Convert assumed-role ARN format to role ARN format
+                        # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name
+                        # To: arn:aws:iam::123456789012:role/role-name
+                        scheduler_role_arn = re.sub(
+                            r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*",
+                            r"arn:aws:iam::\1:role/\2",
+                            scheduler_role_arn,
+                        )
+                    elif ":role/" not in scheduler_role_arn:
+                        raise RuntimeError(
+                            f"Unexpected credential type "
+                            f"({scheduler_role_arn}). Please use IAM "
+                            f"roles for SageMaker pipeline scheduling."
+                        )
+                    else:
+                        raise RuntimeError(
+                            "The ARN of the caller identity "
+                            f"`{scheduler_role_arn}` does not "
+                            "include a user or a proper role."
+                        )
+                except Exception:
+                    raise RuntimeError(
+                        "Failed to get current role ARN. This means the "
+                        "your client side credentials that you are "
+                        "is not configured correctly to schedule sagemaker "
+                        "pipelines. For more information, please check:"
+                        "https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules"
+                    )
+            else:
+                scheduler_role_arn = self.config.scheduler_role
+
+            # Attach schedule to pipeline
+            triggers = pipeline.put_triggers(
+                triggers=[schedule],
+                role_arn=scheduler_role_arn,
+            )
+            logger.info(f"The schedule ARN is: {triggers[0]}")
+
            try:
-
-
+                from zenml.models import RunMetadataResource
+
+                schedule_metadata = self.generate_schedule_metadata(
+                    schedule_arn=triggers[0]
                )
-
-
-
-
-
-
-
-
-
+
+                Client().create_run_metadata(
+                    metadata=schedule_metadata,  # type: ignore[arg-type]
+                    resources=[
+                        RunMetadataResource(
+                            id=deployment.schedule.id,
+                            type=MetadataResourceTypes.SCHEDULE,
+                        )
+                    ],
+                )
+            except Exception as e:
+                logger.debug(
+                    "There was an error attaching metadata to the "
+                    f"schedule: {e}"
                )

+            logger.info(
+                f"Successfully scheduled pipeline with name: {schedule_name}\n"
+                + (
+                    f"First execution will occur at: "
+                    f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}"
+                    if next_execution
+                    else f"Using cron expression: "
+                    f"{deployment.schedule.cron_expression}"
+                )
+                + (
+                    f" (and every {minutes} minutes after)"
+                    if deployment.schedule.interval_second
+                    else ""
+                )
+            )
+            logger.info(
+                "\n\nIn order to cancel the schedule, you can use execute "
+                "the following command:\n"
+            )
+            logger.info(
+                f"`aws scheduler delete-schedule --name {schedule_name}`"
+            )
+        else:
+            # Execute the pipeline immediately if no schedule is specified
+            execution = pipeline.start()
+            logger.warning(
+                "Steps can take 5-15 minutes to start running "
+                "when using the Sagemaker Orchestrator."
+            )
+
+            # Yield metadata based on the generated execution object
+            yield from self.compute_metadata(
+                execution_arn=execution.arn, settings=settings
+            )
+
+            # mainly for testing purposes, we wait for the pipeline to finish
+            if settings.synchronous:
+                logger.info(
+                    "Executing synchronously. Waiting for pipeline to "
+                    "finish... \n"
+                    "At this point you can `Ctrl-C` out without cancelling the "
+                    "execution."
+                )
+                try:
+                    execution.wait(
+                        delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
+                    )
+                    logger.info("Pipeline completed successfully.")
+                except WaiterError:
+                    raise RuntimeError(
+                        "Timed out while waiting for pipeline execution to "
+                        "finish. For long-running pipelines we recommend "
+                        "configuring your orchestrator for asynchronous "
+                        "execution. The following command does this for you: \n"
+                        f"`zenml orchestrator update {self.name} "
+                        f"--synchronous=False`"
+                    )
+
    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
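With this change, a ZenML `Schedule` attached to a deployment is translated into an EventBridge-backed SageMaker pipeline schedule (cron, rate, or one-time execution) instead of being ignored. A minimal client-side sketch, assuming the usual ZenML scheduling API and an active stack whose orchestrator is the SageMaker orchestrator (pipeline and step names below are illustrative):

from zenml import pipeline, step
from zenml.config.schedule import Schedule

@step
def train() -> None:
    ...

@pipeline
def training_pipeline() -> None:
    train()

# AWS cron expressions need 6 (or 7) fields, as enforced by the new
# _validate_cron_expression helper: minute hour day-of-month month
# day-of-week [year]. A 5-field Unix-style expression would be rejected.
training_pipeline.with_options(
    schedule=Schedule(cron_expression="0 9 * * ? *")
)()

If `scheduler_role` is not set on the orchestrator configuration, the role ARN used for the EventBridge schedule is derived from the caller identity as shown in the hunk above.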
@@ -523,10 +726,20 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
        Returns:
            A dictionary of metadata.
        """
-
-
-
-
+        execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
+
+        run_metadata: Dict[str, "MetadataType"] = {}
+
+        settings = cast(
+            SagemakerOrchestratorSettings,
+            self.get_settings(Client().get_pipeline_run(run_id)),
+        )
+
+        for metadata in self.compute_metadata(
+            execution_arn=execution_arn,
+            settings=settings,
+        ):
+            run_metadata.update(metadata)

        return run_metadata

@@ -588,56 +801,57 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):

    def compute_metadata(
        self,
-
+        execution_arn: str,
        settings: SagemakerOrchestratorSettings,
    ) -> Iterator[Dict[str, MetadataType]]:
        """Generate run metadata based on the generated Sagemaker Execution.

        Args:
-
+            execution_arn: The ARN of the pipeline execution.
            settings: The Sagemaker orchestrator settings.

        Yields:
            A dictionary of metadata related to the pipeline run.
        """
-        # Metadata
-        metadata: Dict[str, MetadataType] = {}
-
        # Orchestrator Run ID
-
-
+        metadata: Dict[str, MetadataType] = {
+            "pipeline_execution_arn": execution_arn,
+            METADATA_ORCHESTRATOR_RUN_ID: execution_arn,
+        }

        # URL to the Sagemaker's pipeline view
-        if orchestrator_url := self._compute_orchestrator_url(
+        if orchestrator_url := self._compute_orchestrator_url(
+            execution_arn=execution_arn
+        ):
            metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)

        # URL to the corresponding CloudWatch page
        if logs_url := self._compute_orchestrator_logs_url(
-
+            execution_arn=execution_arn, settings=settings
        ):
            metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

        yield metadata

-    @staticmethod
    def _compute_orchestrator_url(
-
+        self,
+        execution_arn: Any,
    ) -> Optional[str]:
        """Generate the Orchestrator Dashboard URL upon pipeline execution.

        Args:
-
+            execution_arn: The ARN of the pipeline execution.

        Returns:
            the URL to the dashboard view in SageMaker.
        """
        try:
            region_name, pipeline_name, execution_id = (
-                dissect_pipeline_execution_arn(
+                dissect_pipeline_execution_arn(execution_arn)
            )

            # Get the Sagemaker session
-            session =
+            session = self._get_sagemaker_session()

            # List the Studio domains and get the Studio Domain ID
            domains_response = session.sagemaker_client.list_domains()
|
|
657
871
|
|
658
872
|
@staticmethod
|
659
873
|
def _compute_orchestrator_logs_url(
|
660
|
-
|
874
|
+
execution_arn: Any,
|
661
875
|
settings: SagemakerOrchestratorSettings,
|
662
876
|
) -> Optional[str]:
|
663
877
|
"""Generate the CloudWatch URL upon pipeline execution.
|
664
878
|
|
665
879
|
Args:
|
666
|
-
|
880
|
+
execution_arn: The ARN of the pipeline execution.
|
667
881
|
settings: The Sagemaker orchestrator settings.
|
668
882
|
|
669
883
|
Returns:
|
@@ -671,7 +885,7 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
        """
        try:
            region_name, _, execution_id = dissect_pipeline_execution_arn(
-
+                execution_arn
            )

            use_training_jobs = True
@@ -693,22 +907,48 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
            return None

    @staticmethod
-    def
-
-    ) -> Optional[str]:
-        """Fetch the Orchestrator Run ID upon pipeline execution.
+    def generate_schedule_metadata(schedule_arn: str) -> Dict[str, str]:
+        """Attaches metadata to the ZenML Schedules.

        Args:
-
+            schedule_arn: The trigger ARNs that is generated on the AWS side.

        Returns:
-
+            a dictionary containing metadata related to the schedule.
        """
-
-        return str(pipeline_execution.arn)
+        region, name = dissect_schedule_arn(schedule_arn=schedule_arn)

-
-
-            f"
+        return {
+            "trigger_url": (
+                f"https://{region}.console.aws.amazon.com/scheduler/home"
+                f"?region={region}#schedules/{name}"
+            ),
+        }
+
+    @staticmethod
+    def _validate_cron_expression(cron_expression: str) -> str:
+        """Validates and formats a cron expression for SageMaker schedules.
+
+        Args:
+            cron_expression: The cron expression to validate
+
+        Returns:
+            The formatted cron expression
+
+        Raises:
+            ValueError: If the cron expression is invalid
+        """
+        # Strip any "cron(" prefix if it exists
+        cron_exp = cron_expression.replace("cron(", "").replace(")", "")
+
+        # Split into components
+        parts = cron_exp.split()
+        if len(parts) not in [6, 7]:  # AWS cron requires 6 or 7 fields
+            raise ValueError(
+                f"Invalid cron expression: {cron_expression}. AWS cron "
+                "expressions must have 6 or 7 fields: minute hour day-of-month "
+                "month day-of-week year(optional). Example: '15 10 ? * 6L "
+                "2022-2023'"
            )
-
+
+        return cron_exp
zenml/integrations/aws/service_connectors/aws_service_connector.py

@@ -66,6 +66,7 @@ from zenml.service_connectors.service_connector import (
 )
 from zenml.utils.enum_utils import StrEnum
 from zenml.utils.secret_utils import PlainSerializedSecretStr
+from zenml.utils.time_utils import utc_now_tz_aware

 logger = get_logger(__name__)

@@ -711,7 +712,7 @@ class AWSServiceConnector(ServiceConnector):
                return session, None

            # Refresh expired sessions
-            now =
+            now = utc_now_tz_aware()
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
            # check if the token expires in the near future
            if expires_at > now + datetime.timedelta(
@@ -959,9 +960,7 @@ class AWSServiceConnector(ServiceConnector):
            # determine the expiration time of the temporary credentials
            # from the boto3 session, so we assume the default IAM role
            # expiration date is used
-            expiration_time = datetime.
-                tz=datetime.timezone.utc
-            ) + datetime.timedelta(
+            expiration_time = utc_now_tz_aware() + datetime.timedelta(
                seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
            )
            return session, expiration_time
@@ -1673,9 +1672,7 @@ class AWSServiceConnector(ServiceConnector):
                # expiration time of the temporary credentials from the
                # boto3 session, so we assume the default IAM role
                # expiration period is used
-                expires_at = datetime.
-                    tz=datetime.timezone.utc
-                ) + datetime.timedelta(
+                expires_at = utc_now_tz_aware() + datetime.timedelta(
                    seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                )

@@ -1720,9 +1717,7 @@ class AWSServiceConnector(ServiceConnector):
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )
-        expires_at = datetime.
-            tz=datetime.timezone.utc
-        ) + datetime.timedelta(
+        expires_at = utc_now_tz_aware() + datetime.timedelta(
            seconds=DEFAULT_STS_TOKEN_EXPIRATION
        )

@@ -2130,9 +2125,9 @@ class AWSServiceConnector(ServiceConnector):
        # Kubernetes authentication tokens issued by AWS EKS have a fixed
        # expiration time of 15 minutes
        # source: https://aws.github.io/aws-eks-best-practices/security/docs/iam/#controlling-access-to-eks-clusters
-        expires_at = datetime.
-
-        )
+        expires_at = utc_now_tz_aware() + datetime.timedelta(
+            minutes=EKS_KUBE_API_TOKEN_EXPIRATION
+        )

        # get cluster details
        cluster_arn = cluster["cluster"]["arn"]
zenml/integrations/azure/service_connectors/azure_service_connector.py

@@ -58,6 +58,7 @@ from zenml.service_connectors.service_connector import (
 )
 from zenml.utils.enum_utils import StrEnum
 from zenml.utils.secret_utils import PlainSerializedSecretStr
+from zenml.utils.time_utils import to_local_tz, to_utc_timezone, utc_now

 # Configure the logging level for azure.identity
 logging.getLogger("azure.identity").setLevel(logging.WARNING)
@@ -171,12 +172,7 @@ class ZenMLAzureTokenCredential(TokenCredential):
        self.token = token

        # Convert the expiration time from UTC to local time
-        expires_at.
-        expires_at = expires_at.astimezone(
-            datetime.datetime.now().astimezone().tzinfo
-        )
-
-        self.expires_on = int(expires_at.timestamp())
+        self.expires_on = int(to_local_tz(expires_at).timestamp())

    def get_token(self, *scopes: str, **kwargs: Any) -> Any:
        """Get token.
@@ -604,11 +600,9 @@ class AzureServiceConnector(ServiceConnector):
            return session, None

        # Refresh expired sessions
-        now = datetime.datetime.now(datetime.timezone.utc)
-        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)

        # check if the token expires in the near future
-        if expires_at >
+        if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
            minutes=AZURE_SESSION_EXPIRATION_BUFFER
        ):
            return session, expires_at
@@ -1137,7 +1131,7 @@ class AzureServiceConnector(ServiceConnector):
        # format.
        expires_at = datetime.datetime.fromtimestamp(token.expires_on)
        # Convert the expiration timestamp from local time to UTC time.
-        expires_at = expires_at
+        expires_at = to_utc_timezone(expires_at)

        auth_config = AzureAccessTokenConfig(
            token=token.token,
zenml/integrations/gcp/service_connectors/gcp_service_connector.py

@@ -78,6 +78,7 @@ from zenml.service_connectors.service_connector import (
 from zenml.utils.enum_utils import StrEnum
 from zenml.utils.pydantic_utils import before_validator_handler
 from zenml.utils.secret_utils import PlainSerializedSecretStr
+from zenml.utils.time_utils import utc_now

 logger = get_logger(__name__)

@@ -1124,10 +1125,9 @@ class GCPServiceConnector(ServiceConnector):
            return session, None

        # Refresh expired sessions
-
-        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
+
        # check if the token expires in the near future
-        if expires_at >
+        if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
            minutes=GCP_SESSION_EXPIRATION_BUFFER
        ):
            return session, expires_at
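The connector hunks above (AWS, Azure, GCP) all swap hand-rolled `datetime.datetime.now(datetime.timezone.utc)` handling for helpers from the new `zenml/utils/time_utils.py` module listed at the top of this diff (+138 lines). That module's body is not included in this section, so the sketch below is only an assumption about the helpers' behavior, inferred from how they are called in these hunks:

import datetime

def utc_now_tz_aware() -> datetime.datetime:
    # Assumed: timezone-aware "now" in UTC.
    return datetime.datetime.now(datetime.timezone.utc)

def utc_now(tz_aware: object = False) -> datetime.datetime:
    # Assumed: current UTC time; `tz_aware` may be a bool or a reference
    # datetime (as in `utc_now(tz_aware=expires_at)`), deciding whether the
    # result carries tzinfo so naive and aware values compare safely.
    now = datetime.datetime.now(datetime.timezone.utc)
    if isinstance(tz_aware, datetime.datetime):
        tz_aware = tz_aware.tzinfo is not None
    return now if tz_aware else now.replace(tzinfo=None)

def to_utc_timezone(dt: datetime.datetime) -> datetime.datetime:
    # Assumed: convert an aware datetime to UTC, or label a naive one as UTC.
    if dt.tzinfo is None:
        return dt.replace(tzinfo=datetime.timezone.utc)
    return dt.astimezone(datetime.timezone.utc)

def to_local_tz(dt: datetime.datetime) -> datetime.datetime:
    # Assumed: convert to the local timezone (used by the Azure connector to
    # compute `expires_on` timestamps).
    return dt.astimezone()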
zenml/integrations/huggingface/__init__.py

@@ -47,16 +47,11 @@ class HuggingfaceIntegration(Integration):
            A list of requirements.
        """
        requirements = [
-            "datasets",
+            "datasets>=2.16.0",
            "huggingface_hub>0.19.0",
            "accelerate",
            "bitsandbytes>=0.41.3",
            "peft",
-            # temporary fix for CI issue similar to:
-            # - https://github.com/huggingface/datasets/issues/6737
-            # - https://github.com/huggingface/datasets/issues/6697
-            # TODO try relaxing it back going forward
-            "fsspec<=2023.12.0",
            "transformers",
        ]
