zenml-nightly 0.73.0.dev20250123__py3-none-any.whl → 0.73.0.dev20250125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. zenml/VERSION +1 -1
  2. zenml/analytics/context.py +2 -6
  3. zenml/cli/annotator.py +1 -1
  4. zenml/cli/login.py +17 -6
  5. zenml/cli/server.py +1 -0
  6. zenml/cli/service_connectors.py +5 -5
  7. zenml/cli/stack.py +2 -2
  8. zenml/cli/utils.py +2 -54
  9. zenml/config/pipeline_configurations.py +3 -2
  10. zenml/config/schedule.py +0 -24
  11. zenml/enums.py +1 -0
  12. zenml/event_hub/base_event_hub.py +3 -4
  13. zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +3 -4
  14. zenml/integrations/aws/__init__.py +2 -1
  15. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +15 -0
  16. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +310 -70
  17. zenml/integrations/aws/service_connectors/aws_service_connector.py +8 -13
  18. zenml/integrations/azure/service_connectors/azure_service_connector.py +4 -10
  19. zenml/integrations/gcp/service_connectors/gcp_service_connector.py +3 -3
  20. zenml/integrations/huggingface/__init__.py +1 -6
  21. zenml/integrations/kubernetes/orchestrators/kube_utils.py +3 -3
  22. zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +6 -2
  23. zenml/integrations/whylogs/data_validators/whylogs_data_validator.py +2 -3
  24. zenml/logging/step_logging.py +7 -7
  25. zenml/login/credentials.py +6 -5
  26. zenml/login/credentials_store.py +4 -3
  27. zenml/models/v2/core/api_key.py +5 -2
  28. zenml/models/v2/core/schedule.py +19 -3
  29. zenml/orchestrators/publish_utils.py +4 -4
  30. zenml/orchestrators/step_launcher.py +3 -3
  31. zenml/orchestrators/step_run_utils.py +2 -2
  32. zenml/pipelines/run_utils.py +2 -2
  33. zenml/service_connectors/service_connector.py +7 -4
  34. zenml/stack/stack.py +5 -4
  35. zenml/stack/stack_component.py +10 -2
  36. zenml/stack_deployments/stack_deployment.py +2 -3
  37. zenml/utils/string_utils.py +2 -2
  38. zenml/utils/time_utils.py +138 -0
  39. zenml/zen_server/auth.py +8 -9
  40. zenml/zen_server/cloud_utils.py +4 -6
  41. zenml/zen_server/routers/devices_endpoints.py +2 -4
  42. zenml/zen_server/routers/workspaces_endpoints.py +2 -0
  43. zenml/zen_server/zen_server_api.py +9 -8
  44. zenml/zen_stores/migrations/versions/25155145c545_separate_actions_and_triggers.py +3 -2
  45. zenml/zen_stores/migrations/versions/3dcc5d20e82f_add_last_user_activity.py +3 -3
  46. zenml/zen_stores/migrations/versions/46506f72f0ed_add_server_settings.py +3 -2
  47. zenml/zen_stores/migrations/versions/5994f9ad0489_introduce_role_permissions.py +10 -7
  48. zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +3 -2
  49. zenml/zen_stores/migrations/versions/a91762e6be36_artifact_version_table.py +5 -3
  50. zenml/zen_stores/schemas/action_schemas.py +2 -2
  51. zenml/zen_stores/schemas/api_key_schemas.py +5 -4
  52. zenml/zen_stores/schemas/artifact_schemas.py +3 -3
  53. zenml/zen_stores/schemas/base_schemas.py +5 -7
  54. zenml/zen_stores/schemas/code_repository_schemas.py +2 -2
  55. zenml/zen_stores/schemas/component_schemas.py +2 -2
  56. zenml/zen_stores/schemas/device_schemas.py +5 -4
  57. zenml/zen_stores/schemas/event_source_schemas.py +2 -2
  58. zenml/zen_stores/schemas/flavor_schemas.py +2 -2
  59. zenml/zen_stores/schemas/model_schemas.py +3 -3
  60. zenml/zen_stores/schemas/pipeline_run_schemas.py +11 -3
  61. zenml/zen_stores/schemas/pipeline_schemas.py +2 -2
  62. zenml/zen_stores/schemas/run_template_schemas.py +2 -2
  63. zenml/zen_stores/schemas/schedule_schema.py +20 -4
  64. zenml/zen_stores/schemas/secret_schemas.py +2 -2
  65. zenml/zen_stores/schemas/server_settings_schemas.py +6 -9
  66. zenml/zen_stores/schemas/service_connector_schemas.py +3 -2
  67. zenml/zen_stores/schemas/service_schemas.py +2 -2
  68. zenml/zen_stores/schemas/stack_schemas.py +2 -2
  69. zenml/zen_stores/schemas/step_run_schemas.py +3 -2
  70. zenml/zen_stores/schemas/tag_schemas.py +2 -2
  71. zenml/zen_stores/schemas/trigger_schemas.py +2 -2
  72. zenml/zen_stores/schemas/user_schemas.py +3 -3
  73. zenml/zen_stores/schemas/workspace_schemas.py +2 -2
  74. zenml/zen_stores/sql_zen_store.py +6 -14
  75. {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/METADATA +2 -2
  76. {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/RECORD +79 -78
  77. {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/LICENSE +0 -0
  78. {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/WHEEL +0 -0
  79. {zenml_nightly-0.73.0.dev20250123.dist-info → zenml_nightly-0.73.0.dev20250125.dist-info}/entry_points.txt +0 -0
@@ -35,14 +35,20 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput
35
35
  from sagemaker.workflow.execution_variables import ExecutionVariables
36
36
  from sagemaker.workflow.pipeline import Pipeline
37
37
  from sagemaker.workflow.steps import ProcessingStep, TrainingStep
38
+ from sagemaker.workflow.triggers import PipelineSchedule
38
39
 
40
+ from zenml.client import Client
39
41
  from zenml.config.base_settings import BaseSettings
40
42
  from zenml.constants import (
41
43
  METADATA_ORCHESTRATOR_LOGS_URL,
42
44
  METADATA_ORCHESTRATOR_RUN_ID,
43
45
  METADATA_ORCHESTRATOR_URL,
44
46
  )
45
- from zenml.enums import ExecutionStatus, StackComponentType
47
+ from zenml.enums import (
48
+ ExecutionStatus,
49
+ MetadataResourceTypes,
50
+ StackComponentType,
51
+ )
46
52
  from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
47
53
  SagemakerOrchestratorConfig,
48
54
  SagemakerOrchestratorSettings,
@@ -57,6 +63,7 @@ from zenml.orchestrators import ContainerizedOrchestrator
57
63
  from zenml.orchestrators.utils import get_orchestrator_run_name
58
64
  from zenml.stack import StackValidator
59
65
  from zenml.utils.env_utils import split_environment_variables
66
+ from zenml.utils.time_utils import to_utc_timezone, utc_now_tz_aware
60
67
 
61
68
  if TYPE_CHECKING:
62
69
  from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
@@ -69,6 +76,36 @@ POLLING_DELAY = 30
69
76
  logger = get_logger(__name__)
70
77
 
71
78
 
79
def dissect_schedule_arn(
    schedule_arn: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Extracts the region and the name from an EventBridge schedule ARN.

    The ARN is split on ":"; the fourth field is taken as the region and the
    sixth field must be the resource, starting with "schedule/". Everything
    after that leading prefix is returned as the name, so for a resource of
    the form "schedule/<group>/<name>" the result is "<group>/<name>".

    Args:
        schedule_arn: The ARN of the EventBridge schedule.

    Returns:
        Region Name, Schedule Name (including the group name)

    Raises:
        ValueError: If the input is not a properly formatted ARN.
    """
    resource_prefix = "schedule/"

    # Split the ARN into its colon-separated fields
    arn_parts = schedule_arn.split(":")

    # Validate ARN structure: enough fields and a schedule resource
    if len(arn_parts) < 6 or not arn_parts[5].startswith(resource_prefix):
        raise ValueError("Invalid EventBridge schedule ARN format.")

    # Extract the region
    region = arn_parts[3]

    # Strip only the leading prefix. Splitting on every occurrence of
    # "schedule/" would truncate names that themselves contain that text
    # (for example a group called "my-schedule").
    name = arn_parts[5][len(resource_prefix) :]

    return region, name
107
+
108
+
72
109
  def dissect_pipeline_execution_arn(
73
110
  pipeline_execution_arn: str,
74
111
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
@@ -237,21 +274,15 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
237
274
  environment.
238
275
 
239
276
  Raises:
240
- RuntimeError: If a connector is used that does not return a
241
- `boto3.Session` object.
277
+ RuntimeError: If there is an error creating or scheduling the
278
+ pipeline.
242
279
  TypeError: If the network_config passed is not compatible with the
243
280
  AWS SageMaker NetworkConfig class.
281
+ ValueError: If the schedule is not valid.
244
282
 
245
283
  Yields:
246
284
  A dictionary of metadata related to the pipeline run.
247
285
  """
248
- if deployment.schedule:
249
- logger.warning(
250
- "The Sagemaker Orchestrator currently does not support the "
251
- "use of schedules. The `schedule` will be ignored "
252
- "and the pipeline will be run immediately."
253
- )
254
-
255
286
  # sagemaker requires pipelineName to use alphanum and hyphens only
256
287
  unsanitized_orchestrator_run_name = get_orchestrator_run_name(
257
288
  pipeline_name=deployment.pipeline_configuration.name
@@ -459,7 +490,7 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
459
490
 
460
491
  sagemaker_steps.append(sagemaker_step)
461
492
 
462
- # construct the pipeline from the sagemaker_steps
493
+ # Create the pipeline
463
494
  pipeline = Pipeline(
464
495
  name=orchestrator_run_name,
465
496
  steps=sagemaker_steps,
@@ -479,39 +510,211 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
479
510
  if settings.pipeline_tags
480
511
  else None,
481
512
  )
482
- execution = pipeline.start()
483
- logger.warning(
484
- "Steps can take 5-15 minutes to start running "
485
- "when using the Sagemaker Orchestrator."
486
- )
487
513
 
488
- # Yield metadata based on the generated execution object
489
- yield from self.compute_metadata(
490
- execution=execution, settings=settings
491
- )
514
+ # Handle scheduling if specified
515
+ if deployment.schedule:
516
+ if settings.synchronous:
517
+ logger.warning(
518
+ "The 'synchronous' setting is ignored for scheduled "
519
+ "pipelines since they run independently of the "
520
+ "deployment process."
521
+ )
492
522
 
493
- # mainly for testing purposes, we wait for the pipeline to finish
494
- if settings.synchronous:
495
- logger.info(
496
- "Executing synchronously. Waiting for pipeline to finish... \n"
497
- "At this point you can `Ctrl-C` out without cancelling the "
498
- "execution."
523
+ schedule_name = orchestrator_run_name
524
+ next_execution = None
525
+ start_date = (
526
+ to_utc_timezone(deployment.schedule.start_time)
527
+ if deployment.schedule.start_time
528
+ else None
499
529
  )
530
+
531
+ # Create PipelineSchedule based on schedule type
532
+ if deployment.schedule.cron_expression:
533
+ cron_exp = self._validate_cron_expression(
534
+ deployment.schedule.cron_expression
535
+ )
536
+ schedule = PipelineSchedule(
537
+ name=schedule_name,
538
+ cron=cron_exp,
539
+ start_date=start_date,
540
+ enabled=True,
541
+ )
542
+ elif deployment.schedule.interval_second:
543
+ # This is necessary because SageMaker's PipelineSchedule rate
544
+ # expressions require minutes as the minimum time unit.
545
+ # Even if a user specifies an interval of less than 60 seconds,
546
+ # it will be rounded up to 1 minute.
547
+ minutes = max(
548
+ 1,
549
+ int(
550
+ deployment.schedule.interval_second.total_seconds()
551
+ / 60
552
+ ),
553
+ )
554
+ schedule = PipelineSchedule(
555
+ name=schedule_name,
556
+ rate=(minutes, "minutes"),
557
+ start_date=start_date,
558
+ enabled=True,
559
+ )
560
+ next_execution = (
561
+ deployment.schedule.start_time or utc_now_tz_aware()
562
+ ) + deployment.schedule.interval_second
563
+ else:
564
+ # One-time schedule
565
+ execution_time = (
566
+ deployment.schedule.run_once_start_time
567
+ or deployment.schedule.start_time
568
+ )
569
+ if not execution_time:
570
+ raise ValueError(
571
+ "A start time must be specified for one-time "
572
+ "schedule execution"
573
+ )
574
+ schedule = PipelineSchedule(
575
+ name=schedule_name,
576
+ at=to_utc_timezone(execution_time),
577
+ enabled=True,
578
+ )
579
+ next_execution = execution_time
580
+
581
+ # Get the current role ARN if not explicitly configured
582
+ if self.config.scheduler_role is None:
583
+ logger.info(
584
+ "No scheduler_role configured. Trying to extract it from "
585
+ "the client side authentication."
586
+ )
587
+ sts = session.boto_session.client("sts")
588
+ try:
589
+ scheduler_role_arn = sts.get_caller_identity()["Arn"]
590
+ # If this is a user ARN, try to get the role ARN
591
+ if ":user/" in scheduler_role_arn:
592
+ logger.warning(
593
+ f"Using IAM user credentials "
594
+ f"({scheduler_role_arn}). For production "
595
+ "environments, it's recommended to use IAM roles "
596
+ "instead."
597
+ )
598
+ # If this is an assumed role, extract the role ARN
599
+ elif ":assumed-role/" in scheduler_role_arn:
600
+ # Convert assumed-role ARN format to role ARN format
601
+ # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name
602
+ # To: arn:aws:iam::123456789012:role/role-name
603
+ scheduler_role_arn = re.sub(
604
+ r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*",
605
+ r"arn:aws:iam::\1:role/\2",
606
+ scheduler_role_arn,
607
+ )
608
+ elif ":role/" not in scheduler_role_arn:
609
+ raise RuntimeError(
610
+ f"Unexpected credential type "
611
+ f"({scheduler_role_arn}). Please use IAM "
612
+ f"roles for SageMaker pipeline scheduling."
613
+ )
614
+ else:
615
+ raise RuntimeError(
616
+ "The ARN of the caller identity "
617
+ f"`{scheduler_role_arn}` does not "
618
+ "include a user or a proper role."
619
+ )
620
+ except Exception:
621
+ raise RuntimeError(
622
+ "Failed to get current role ARN. This means the "
623
+ "your client side credentials that you are "
624
+ "is not configured correctly to schedule sagemaker "
625
+ "pipelines. For more information, please check:"
626
+ "https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules"
627
+ )
628
+ else:
629
+ scheduler_role_arn = self.config.scheduler_role
630
+
631
+ # Attach schedule to pipeline
632
+ triggers = pipeline.put_triggers(
633
+ triggers=[schedule],
634
+ role_arn=scheduler_role_arn,
635
+ )
636
+ logger.info(f"The schedule ARN is: {triggers[0]}")
637
+
500
638
  try:
501
- execution.wait(
502
- delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
639
+ from zenml.models import RunMetadataResource
640
+
641
+ schedule_metadata = self.generate_schedule_metadata(
642
+ schedule_arn=triggers[0]
503
643
  )
504
- logger.info("Pipeline completed successfully.")
505
- except WaiterError:
506
- raise RuntimeError(
507
- "Timed out while waiting for pipeline execution to "
508
- "finish. For long-running pipelines we recommend "
509
- "configuring your orchestrator for asynchronous execution. "
510
- "The following command does this for you: \n"
511
- f"`zenml orchestrator update {self.name} "
512
- f"--synchronous=False`"
644
+
645
+ Client().create_run_metadata(
646
+ metadata=schedule_metadata, # type: ignore[arg-type]
647
+ resources=[
648
+ RunMetadataResource(
649
+ id=deployment.schedule.id,
650
+ type=MetadataResourceTypes.SCHEDULE,
651
+ )
652
+ ],
653
+ )
654
+ except Exception as e:
655
+ logger.debug(
656
+ "There was an error attaching metadata to the "
657
+ f"schedule: {e}"
513
658
  )
514
659
 
660
+ logger.info(
661
+ f"Successfully scheduled pipeline with name: {schedule_name}\n"
662
+ + (
663
+ f"First execution will occur at: "
664
+ f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}"
665
+ if next_execution
666
+ else f"Using cron expression: "
667
+ f"{deployment.schedule.cron_expression}"
668
+ )
669
+ + (
670
+ f" (and every {minutes} minutes after)"
671
+ if deployment.schedule.interval_second
672
+ else ""
673
+ )
674
+ )
675
+ logger.info(
676
+ "\n\nIn order to cancel the schedule, you can use execute "
677
+ "the following command:\n"
678
+ )
679
+ logger.info(
680
+ f"`aws scheduler delete-schedule --name {schedule_name}`"
681
+ )
682
+ else:
683
+ # Execute the pipeline immediately if no schedule is specified
684
+ execution = pipeline.start()
685
+ logger.warning(
686
+ "Steps can take 5-15 minutes to start running "
687
+ "when using the Sagemaker Orchestrator."
688
+ )
689
+
690
+ # Yield metadata based on the generated execution object
691
+ yield from self.compute_metadata(
692
+ execution_arn=execution.arn, settings=settings
693
+ )
694
+
695
+ # mainly for testing purposes, we wait for the pipeline to finish
696
+ if settings.synchronous:
697
+ logger.info(
698
+ "Executing synchronously. Waiting for pipeline to "
699
+ "finish... \n"
700
+ "At this point you can `Ctrl-C` out without cancelling the "
701
+ "execution."
702
+ )
703
+ try:
704
+ execution.wait(
705
+ delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
706
+ )
707
+ logger.info("Pipeline completed successfully.")
708
+ except WaiterError:
709
+ raise RuntimeError(
710
+ "Timed out while waiting for pipeline execution to "
711
+ "finish. For long-running pipelines we recommend "
712
+ "configuring your orchestrator for asynchronous "
713
+ "execution. The following command does this for you: \n"
714
+ f"`zenml orchestrator update {self.name} "
715
+ f"--synchronous=False`"
716
+ )
717
+
515
718
  def get_pipeline_run_metadata(
516
719
  self, run_id: UUID
517
720
  ) -> Dict[str, "MetadataType"]:
@@ -523,10 +726,20 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
523
726
  Returns:
524
727
  A dictionary of metadata.
525
728
  """
526
- pipeline_execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
527
- run_metadata: Dict[str, "MetadataType"] = {
528
- "pipeline_execution_arn": pipeline_execution_arn,
529
- }
729
+ execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
730
+
731
+ run_metadata: Dict[str, "MetadataType"] = {}
732
+
733
+ settings = cast(
734
+ SagemakerOrchestratorSettings,
735
+ self.get_settings(Client().get_pipeline_run(run_id)),
736
+ )
737
+
738
+ for metadata in self.compute_metadata(
739
+ execution_arn=execution_arn,
740
+ settings=settings,
741
+ ):
742
+ run_metadata.update(metadata)
530
743
 
531
744
  return run_metadata
532
745
 
@@ -588,56 +801,57 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
588
801
 
589
802
  def compute_metadata(
590
803
  self,
591
- execution: Any,
804
+ execution_arn: str,
592
805
  settings: SagemakerOrchestratorSettings,
593
806
  ) -> Iterator[Dict[str, MetadataType]]:
594
807
  """Generate run metadata based on the generated Sagemaker Execution.
595
808
 
596
809
  Args:
597
- execution: The corresponding _PipelineExecution object.
810
+ execution_arn: The ARN of the pipeline execution.
598
811
  settings: The Sagemaker orchestrator settings.
599
812
 
600
813
  Yields:
601
814
  A dictionary of metadata related to the pipeline run.
602
815
  """
603
- # Metadata
604
- metadata: Dict[str, MetadataType] = {}
605
-
606
816
  # Orchestrator Run ID
607
- if run_id := self._compute_orchestrator_run_id(execution):
608
- metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id
817
+ metadata: Dict[str, MetadataType] = {
818
+ "pipeline_execution_arn": execution_arn,
819
+ METADATA_ORCHESTRATOR_RUN_ID: execution_arn,
820
+ }
609
821
 
610
822
  # URL to the Sagemaker's pipeline view
611
- if orchestrator_url := self._compute_orchestrator_url(execution):
823
+ if orchestrator_url := self._compute_orchestrator_url(
824
+ execution_arn=execution_arn
825
+ ):
612
826
  metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)
613
827
 
614
828
  # URL to the corresponding CloudWatch page
615
829
  if logs_url := self._compute_orchestrator_logs_url(
616
- execution, settings
830
+ execution_arn=execution_arn, settings=settings
617
831
  ):
618
832
  metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)
619
833
 
620
834
  yield metadata
621
835
 
622
- @staticmethod
623
836
  def _compute_orchestrator_url(
624
- pipeline_execution: Any,
837
+ self,
838
+ execution_arn: Any,
625
839
  ) -> Optional[str]:
626
840
  """Generate the Orchestrator Dashboard URL upon pipeline execution.
627
841
 
628
842
  Args:
629
- pipeline_execution: The corresponding _PipelineExecution object.
843
+ execution_arn: The ARN of the pipeline execution.
630
844
 
631
845
  Returns:
632
846
  the URL to the dashboard view in SageMaker.
633
847
  """
634
848
  try:
635
849
  region_name, pipeline_name, execution_id = (
636
- dissect_pipeline_execution_arn(pipeline_execution.arn)
850
+ dissect_pipeline_execution_arn(execution_arn)
637
851
  )
638
852
 
639
853
  # Get the Sagemaker session
640
- session = pipeline_execution.sagemaker_session
854
+ session = self._get_sagemaker_session()
641
855
 
642
856
  # List the Studio domains and get the Studio Domain ID
643
857
  domains_response = session.sagemaker_client.list_domains()
@@ -657,13 +871,13 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
657
871
 
658
872
  @staticmethod
659
873
  def _compute_orchestrator_logs_url(
660
- pipeline_execution: Any,
874
+ execution_arn: Any,
661
875
  settings: SagemakerOrchestratorSettings,
662
876
  ) -> Optional[str]:
663
877
  """Generate the CloudWatch URL upon pipeline execution.
664
878
 
665
879
  Args:
666
- pipeline_execution: The corresponding _PipelineExecution object.
880
+ execution_arn: The ARN of the pipeline execution.
667
881
  settings: The Sagemaker orchestrator settings.
668
882
 
669
883
  Returns:
@@ -671,7 +885,7 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
671
885
  """
672
886
  try:
673
887
  region_name, _, execution_id = dissect_pipeline_execution_arn(
674
- pipeline_execution.arn
888
+ execution_arn
675
889
  )
676
890
 
677
891
  use_training_jobs = True
@@ -693,22 +907,48 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
693
907
  return None
694
908
 
695
909
  @staticmethod
696
- def _compute_orchestrator_run_id(
697
- pipeline_execution: Any,
698
- ) -> Optional[str]:
699
- """Fetch the Orchestrator Run ID upon pipeline execution.
910
def generate_schedule_metadata(schedule_arn: str) -> Dict[str, str]:
    """Builds the metadata describing a schedule created on the AWS side.

    The region and the schedule name are extracted from the ARN and combined
    into a link to the schedule in the AWS scheduler console.

    Args:
        schedule_arn: The ARN of the trigger that was generated on the AWS
            side.

    Returns:
        A dictionary with a single "trigger_url" entry holding the console
        URL of the schedule.
    """
    aws_region, schedule_path = dissect_schedule_arn(
        schedule_arn=schedule_arn
    )

    console_home = (
        f"https://{aws_region}.console.aws.amazon.com/scheduler/home"
    )
    trigger_url = (
        f"{console_home}?region={aws_region}#schedules/{schedule_path}"
    )

    return {"trigger_url": trigger_url}
927
+
928
+ @staticmethod
929
+ def _validate_cron_expression(cron_expression: str) -> str:
930
+ """Validates and formats a cron expression for SageMaker schedules.
931
+
932
+ Args:
933
+ cron_expression: The cron expression to validate
934
+
935
+ Returns:
936
+ The formatted cron expression
937
+
938
+ Raises:
939
+ ValueError: If the cron expression is invalid
940
+ """
941
+ # Strip any "cron(" prefix if it exists
942
+ cron_exp = cron_expression.replace("cron(", "").replace(")", "")
943
+
944
+ # Split into components
945
+ parts = cron_exp.split()
946
+ if len(parts) not in [6, 7]: # AWS cron requires 6 or 7 fields
947
+ raise ValueError(
948
+ f"Invalid cron expression: {cron_expression}. AWS cron "
949
+ "expressions must have 6 or 7 fields: minute hour day-of-month "
950
+ "month day-of-week year(optional). Example: '15 10 ? * 6L "
951
+ "2022-2023'"
713
952
  )
714
- return None
953
+
954
+ return cron_exp
@@ -66,6 +66,7 @@ from zenml.service_connectors.service_connector import (
66
66
  )
67
67
  from zenml.utils.enum_utils import StrEnum
68
68
  from zenml.utils.secret_utils import PlainSerializedSecretStr
69
+ from zenml.utils.time_utils import utc_now_tz_aware
69
70
 
70
71
  logger = get_logger(__name__)
71
72
 
@@ -711,7 +712,7 @@ class AWSServiceConnector(ServiceConnector):
711
712
  return session, None
712
713
 
713
714
  # Refresh expired sessions
714
- now = datetime.datetime.now(datetime.timezone.utc)
715
+ now = utc_now_tz_aware()
715
716
  expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
716
717
  # check if the token expires in the near future
717
718
  if expires_at > now + datetime.timedelta(
@@ -959,9 +960,7 @@ class AWSServiceConnector(ServiceConnector):
959
960
  # determine the expiration time of the temporary credentials
960
961
  # from the boto3 session, so we assume the default IAM role
961
962
  # expiration date is used
962
- expiration_time = datetime.datetime.now(
963
- tz=datetime.timezone.utc
964
- ) + datetime.timedelta(
963
+ expiration_time = utc_now_tz_aware() + datetime.timedelta(
965
964
  seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
966
965
  )
967
966
  return session, expiration_time
@@ -1673,9 +1672,7 @@ class AWSServiceConnector(ServiceConnector):
1673
1672
  # expiration time of the temporary credentials from the
1674
1673
  # boto3 session, so we assume the default IAM role
1675
1674
  # expiration period is used
1676
- expires_at = datetime.datetime.now(
1677
- tz=datetime.timezone.utc
1678
- ) + datetime.timedelta(
1675
+ expires_at = utc_now_tz_aware() + datetime.timedelta(
1679
1676
  seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
1680
1677
  )
1681
1678
 
@@ -1720,9 +1717,7 @@ class AWSServiceConnector(ServiceConnector):
1720
1717
  aws_secret_access_key=credentials["SecretAccessKey"],
1721
1718
  aws_session_token=credentials["SessionToken"],
1722
1719
  )
1723
- expires_at = datetime.datetime.now(
1724
- tz=datetime.timezone.utc
1725
- ) + datetime.timedelta(
1720
+ expires_at = utc_now_tz_aware() + datetime.timedelta(
1726
1721
  seconds=DEFAULT_STS_TOKEN_EXPIRATION
1727
1722
  )
1728
1723
 
@@ -2130,9 +2125,9 @@ class AWSServiceConnector(ServiceConnector):
2130
2125
  # Kubernetes authentication tokens issued by AWS EKS have a fixed
2131
2126
  # expiration time of 15 minutes
2132
2127
  # source: https://aws.github.io/aws-eks-best-practices/security/docs/iam/#controlling-access-to-eks-clusters
2133
- expires_at = datetime.datetime.now(
2134
- tz=datetime.timezone.utc
2135
- ) + datetime.timedelta(minutes=EKS_KUBE_API_TOKEN_EXPIRATION)
2128
+ expires_at = utc_now_tz_aware() + datetime.timedelta(
2129
+ minutes=EKS_KUBE_API_TOKEN_EXPIRATION
2130
+ )
2136
2131
 
2137
2132
  # get cluster details
2138
2133
  cluster_arn = cluster["cluster"]["arn"]
@@ -58,6 +58,7 @@ from zenml.service_connectors.service_connector import (
58
58
  )
59
59
  from zenml.utils.enum_utils import StrEnum
60
60
  from zenml.utils.secret_utils import PlainSerializedSecretStr
61
+ from zenml.utils.time_utils import to_local_tz, to_utc_timezone, utc_now
61
62
 
62
63
  # Configure the logging level for azure.identity
63
64
  logging.getLogger("azure.identity").setLevel(logging.WARNING)
@@ -171,12 +172,7 @@ class ZenMLAzureTokenCredential(TokenCredential):
171
172
  self.token = token
172
173
 
173
174
  # Convert the expiration time from UTC to local time
174
- expires_at.replace(tzinfo=datetime.timezone.utc)
175
- expires_at = expires_at.astimezone(
176
- datetime.datetime.now().astimezone().tzinfo
177
- )
178
-
179
- self.expires_on = int(expires_at.timestamp())
175
+ self.expires_on = int(to_local_tz(expires_at).timestamp())
180
176
 
181
177
  def get_token(self, *scopes: str, **kwargs: Any) -> Any:
182
178
  """Get token.
@@ -604,11 +600,9 @@ class AzureServiceConnector(ServiceConnector):
604
600
  return session, None
605
601
 
606
602
  # Refresh expired sessions
607
- now = datetime.datetime.now(datetime.timezone.utc)
608
- expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
609
603
 
610
604
  # check if the token expires in the near future
611
- if expires_at > now + datetime.timedelta(
605
+ if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
612
606
  minutes=AZURE_SESSION_EXPIRATION_BUFFER
613
607
  ):
614
608
  return session, expires_at
@@ -1137,7 +1131,7 @@ class AzureServiceConnector(ServiceConnector):
1137
1131
  # format.
1138
1132
  expires_at = datetime.datetime.fromtimestamp(token.expires_on)
1139
1133
  # Convert the expiration timestamp from local time to UTC time.
1140
- expires_at = expires_at.astimezone(datetime.timezone.utc)
1134
+ expires_at = to_utc_timezone(expires_at)
1141
1135
 
1142
1136
  auth_config = AzureAccessTokenConfig(
1143
1137
  token=token.token,
@@ -78,6 +78,7 @@ from zenml.service_connectors.service_connector import (
78
78
  from zenml.utils.enum_utils import StrEnum
79
79
  from zenml.utils.pydantic_utils import before_validator_handler
80
80
  from zenml.utils.secret_utils import PlainSerializedSecretStr
81
+ from zenml.utils.time_utils import utc_now
81
82
 
82
83
  logger = get_logger(__name__)
83
84
 
@@ -1124,10 +1125,9 @@ class GCPServiceConnector(ServiceConnector):
1124
1125
  return session, None
1125
1126
 
1126
1127
  # Refresh expired sessions
1127
- now = datetime.datetime.now(datetime.timezone.utc)
1128
- expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
1128
+
1129
1129
  # check if the token expires in the near future
1130
- if expires_at > now + datetime.timedelta(
1130
+ if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
1131
1131
  minutes=GCP_SESSION_EXPIRATION_BUFFER
1132
1132
  ):
1133
1133
  return session, expires_at
@@ -47,16 +47,11 @@ class HuggingfaceIntegration(Integration):
47
47
  A list of requirements.
48
48
  """
49
49
  requirements = [
50
- "datasets",
50
+ "datasets>=2.16.0",
51
51
  "huggingface_hub>0.19.0",
52
52
  "accelerate",
53
53
  "bitsandbytes>=0.41.3",
54
54
  "peft",
55
- # temporary fix for CI issue similar to:
56
- # - https://github.com/huggingface/datasets/issues/6737
57
- # - https://github.com/huggingface/datasets/issues/6697
58
- # TODO try relaxing it back going forward
59
- "fsspec<=2023.12.0",
60
55
  "transformers",
61
56
  ]
62
57