apache-airflow-providers-amazon 8.26.0rc1__py3-none-any.whl → 8.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
  3. airflow/providers/amazon/aws/datasets/__init__.py +16 -0
  4. airflow/providers/amazon/aws/datasets/s3.py +45 -0
  5. airflow/providers/amazon/aws/executors/batch/batch_executor.py +27 -17
  6. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +31 -13
  7. airflow/providers/amazon/aws/hooks/kinesis_analytics.py +65 -0
  8. airflow/providers/amazon/aws/hooks/rds.py +3 -3
  9. airflow/providers/amazon/aws/hooks/s3.py +26 -1
  10. airflow/providers/amazon/aws/hooks/step_function.py +18 -0
  11. airflow/providers/amazon/aws/operators/athena.py +16 -17
  12. airflow/providers/amazon/aws/operators/emr.py +23 -23
  13. airflow/providers/amazon/aws/operators/kinesis_analytics.py +348 -0
  14. airflow/providers/amazon/aws/operators/rds.py +17 -20
  15. airflow/providers/amazon/aws/operators/redshift_cluster.py +71 -53
  16. airflow/providers/amazon/aws/operators/s3.py +18 -12
  17. airflow/providers/amazon/aws/operators/sagemaker.py +12 -27
  18. airflow/providers/amazon/aws/operators/step_function.py +12 -2
  19. airflow/providers/amazon/aws/sensors/kinesis_analytics.py +234 -0
  20. airflow/providers/amazon/aws/sensors/s3.py +11 -5
  21. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -0
  22. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  23. airflow/providers/amazon/aws/triggers/emr.py +3 -1
  24. airflow/providers/amazon/aws/triggers/kinesis_analytics.py +69 -0
  25. airflow/providers/amazon/aws/triggers/sagemaker.py +9 -1
  26. airflow/providers/amazon/aws/waiters/kinesisanalyticsv2.json +151 -0
  27. airflow/providers/amazon/aws/waiters/rds.json +253 -0
  28. airflow/providers/amazon/get_provider_info.py +35 -2
  29. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/METADATA +32 -25
  30. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/RECORD +32 -24
  31. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/WHEEL +0 -0
  32. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -324,8 +324,7 @@ class S3CopyObjectOperator(BaseOperator):
         )

     def get_openlineage_facets_on_start(self):
-        from openlineage.client.run import Dataset
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
         from airflow.providers.openlineage.extractors import OperatorLineage

         dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
@@ -439,8 +438,7 @@ class S3CreateObjectOperator(BaseOperator):
         s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)

     def get_openlineage_facets_on_start(self):
-        from openlineage.client.run import Dataset
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
         from airflow.providers.openlineage.extractors import OperatorLineage

         bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
@@ -546,13 +544,12 @@ class S3DeleteObjectsOperator(BaseOperator):

     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement _on_complete because object keys are resolved in execute()."""
-        from openlineage.client.facet import (
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
             LifecycleStateChange,
             LifecycleStateChangeDatasetFacet,
-            LifecycleStateChangeDatasetFacetPreviousIdentifier,
+            PreviousIdentifier,
         )
-        from openlineage.client.run import Dataset
-
         from airflow.providers.openlineage.extractors import OperatorLineage

         if not self._keys:
@@ -570,7 +567,7 @@ class S3DeleteObjectsOperator(BaseOperator):
                 facets={
                     "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
                         lifecycleStateChange=LifecycleStateChange.DROP.value,
-                        previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
+                        previousIdentifier=PreviousIdentifier(
                             namespace=bucket_url,
                             name=key,
                         ),
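
The OpenLineage Dataset and lifecycle-facet classes used by these S3 operators now come from airflow.providers.common.compat.openlineage.facet instead of openlineage.client. A minimal sketch of building facets against the new import path, mirroring the S3DeleteObjectsOperator change above; the bucket and key values and the operator hosting this method are placeholders, not part of this diff:

from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    LifecycleStateChange,
    LifecycleStateChangeDatasetFacet,
    PreviousIdentifier,
)
from airflow.providers.openlineage.extractors import OperatorLineage


def get_openlineage_facets_on_complete(self, task_instance):
    # Hypothetical sketch: report a deleted S3 object as a dropped input dataset,
    # using the same facet classes re-exported by the compat shim.
    bucket_url = "s3://my-bucket"  # placeholder bucket
    key = "data/obsolete-file.csv"  # placeholder key
    dataset = Dataset(
        namespace=bucket_url,
        name=key,
        facets={
            "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
                lifecycleStateChange=LifecycleStateChange.DROP.value,
                previousIdentifier=PreviousIdentifier(namespace=bucket_url, name=key),
            )
        },
    )
    return OperatorLineage(inputs=[dataset])
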
@@ -610,6 +607,7 @@ class S3FileTransformOperator(BaseOperator):
     :param dest_s3_key: The key to be written from S3. (templated)
     :param transform_script: location of the executable transformation script
     :param select_expression: S3 Select expression
+    :param select_expr_serialization_config: A dictionary that contains input and output serialization configurations for S3 Select.
     :param script_args: arguments for transformation script (templated)
     :param source_aws_conn_id: source s3 connection
     :param source_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -641,6 +639,7 @@ class S3FileTransformOperator(BaseOperator):
         dest_s3_key: str,
         transform_script: str | None = None,
         select_expression=None,
+        select_expr_serialization_config: dict[str, dict[str, dict]] | None = None,
         script_args: Sequence[str] | None = None,
         source_aws_conn_id: str | None = "aws_default",
         source_verify: bool | str | None = None,
@@ -659,6 +658,7 @@ class S3FileTransformOperator(BaseOperator):
         self.replace = replace
         self.transform_script = transform_script
         self.select_expression = select_expression
+        self.select_expr_serialization_config = select_expr_serialization_config or {}
         self.script_args = script_args or []
         self.output_encoding = sys.getdefaultencoding()

@@ -678,7 +678,14 @@ class S3FileTransformOperator(BaseOperator):
             self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name)

             if self.select_expression is not None:
-                content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression)
+                input_serialization = self.select_expr_serialization_config.get("input_serialization")
+                output_serialization = self.select_expr_serialization_config.get("output_serialization")
+                content = source_s3.select_key(
+                    key=self.source_s3_key,
+                    expression=self.select_expression,
+                    input_serialization=input_serialization,
+                    output_serialization=output_serialization,
+                )
                 f_source.write(content.encode("utf-8"))
             else:
                 source_s3_key_object.download_fileobj(Fileobj=f_source)
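
The new select_expr_serialization_config argument is forwarded to S3Hook.select_key as input_serialization and output_serialization. A hedged usage sketch; the bucket names, keys, and query are placeholders, and the serialization dictionaries follow the boto3 select_object_content shapes:

from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator

# Sketch: run an S3 Select query while copying, passing explicit input/output
# serialization instead of relying on the hook defaults. /bin/cp acts as a
# pass-through transform script.
select_and_transform = S3FileTransformOperator(
    task_id="select_and_transform",
    source_s3_key="s3://my-source-bucket/raw/data.csv",
    dest_s3_key="s3://my-dest-bucket/filtered/data.csv",
    select_expression="SELECT * FROM s3object s",
    select_expr_serialization_config={
        "input_serialization": {"CSV": {"FileHeaderInfo": "USE"}, "CompressionType": "NONE"},
        "output_serialization": {"CSV": {}},
    },
    transform_script="/bin/cp",
)
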
@@ -715,8 +722,7 @@ class S3FileTransformOperator(BaseOperator):
         self.log.info("Upload successful")

     def get_openlineage_facets_on_start(self):
-        from openlineage.client.run import Dataset
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
         from airflow.providers.openlineage.extractors import OperatorLineage

         dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -36,7 +36,6 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
 )
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
-    SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
@@ -46,8 +45,7 @@ from airflow.utils.helpers import prune_dict
 from airflow.utils.json import AirflowJsonEncoder

 if TYPE_CHECKING:
-    from openlineage.client.run import Dataset
-
+    from airflow.providers.common.compat.openlineage.facet import Dataset
     from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context

@@ -208,7 +206,7 @@ class SageMakerBaseOperator(BaseOperator):

     @staticmethod
     def path_to_s3_dataset(path) -> Dataset:
-        from openlineage.client.run import Dataset
+        from airflow.providers.common.compat.openlineage.facet import Dataset

         path = path.replace("s3://", "")
         split_path = path.split("/")
@@ -361,7 +359,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")

         self.log.info(event["message"])
-        self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
+        self.serialized_job = serialize(self.hook.describe_processing_job(event["job_name"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}

@@ -612,12 +610,11 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):

         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        endpoint_info = self.config.get("Endpoint", self.config)
+
+        response = self.hook.describe_endpoint(event["job_name"])
         return {
-            "EndpointConfig": serialize(
-                self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
-            ),
-            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+            "EndpointConfig": serialize(self.hook.describe_endpoint_config(response["EndpointConfigName"])),
+            "Endpoint": serialize(self.hook.describe_endpoint(response["EndpointName"])),
         }


@@ -997,9 +994,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):

         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        return {
-            "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
-        }
+        return {"Tuning": serialize(self.hook.describe_tuning_job(event["job_name"]))}


 class SageMakerModelOperator(SageMakerBaseOperator):
@@ -1199,25 +1194,15 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             if self.max_ingestion_time:
                 timeout = datetime.timedelta(seconds=self.max_ingestion_time)

-            trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
-            if self.print_log:
-                trigger = SageMakerTrainingPrintLogTrigger(
-                    job_name=self.config["TrainingJobName"],
-                    poke_interval=self.check_interval,
-                    aws_conn_id=self.aws_conn_id,
-                )
-            else:
-                trigger = SageMakerTrigger(
+            self.defer(
+                timeout=timeout,
+                trigger=SageMakerTrigger(
                     job_name=self.config["TrainingJobName"],
                     job_type="Training",
                     poke_interval=self.check_interval,
                     max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                )
-
-            self.defer(
-                timeout=timeout,
-                trigger=trigger,
+                ),
                 method_name="execute_complete",
             )

--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -48,6 +48,8 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):

     :param state_machine_arn: ARN of the Step Function State Machine
     :param name: The name of the execution.
+    :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
+        complete successfully in the last 14 days.
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is None or empty then the default boto3 behaviour is used. If
@@ -73,7 +75,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
     """

     aws_hook_class = StepFunctionHook
-    template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
+    template_fields: Sequence[str] = aws_template_fields(
+        "state_machine_arn", "name", "input", "is_redrive_execution"
+    )
     ui_color = "#f9c915"
     operator_extra_links = (StateMachineDetailsLink(), StateMachineExecutionsDetailsLink())

@@ -82,6 +86,7 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         *,
         state_machine_arn: str,
         name: str | None = None,
+        is_redrive_execution: bool = False,
         state_machine_input: dict | str | None = None,
         waiter_max_attempts: int = 30,
         waiter_delay: int = 60,
@@ -91,6 +96,7 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         super().__init__(**kwargs)
         self.state_machine_arn = state_machine_arn
         self.name = name
+        self.is_redrive_execution = is_redrive_execution
         self.input = state_machine_input
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
@@ -105,7 +111,11 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
             state_machine_arn=self.state_machine_arn,
         )

-        if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
+        if not (
+            execution_arn := self.hook.start_execution(
+                self.state_machine_arn, self.name, self.input, self.is_redrive_execution
+            )
+        ):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")

         StateMachineExecutionsDetailsLink.persist(
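
A hedged sketch of the new is_redrive_execution flag, which the operator passes through to StepFunctionHook.start_execution; the ARN below is a placeholder for the workflow whose unsuccessful execution should be redriven:

from airflow.providers.amazon.aws.operators.step_function import (
    StepFunctionStartExecutionOperator,
)

# Sketch: redrive instead of starting a fresh execution. Only Standard workflows
# that did not complete successfully in the last 14 days can be redriven.
redrive_execution = StepFunctionStartExecutionOperator(
    task_id="redrive_execution",
    state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:my-state-machine",
    is_redrive_execution=True,
)
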
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/kinesis_analytics.py
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.kinesis_analytics import (
+    KinesisAnalyticsV2ApplicationOperationCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class KinesisAnalyticsV2BaseSensor(AwsBaseSensor[KinesisAnalyticsV2Hook]):
+    """
+    General sensor behaviour for AWS Managed Service for Apache Flink.
+
+    Subclasses must set the following fields:
+        - ``INTERMEDIATE_STATES``
+        - ``FAILURE_STATES``
+        - ``SUCCESS_STATES``
+        - ``FAILURE_MESSAGE``
+        - ``SUCCESS_MESSAGE``
+
+    :param application_name: Application name.
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+
+    """
+
+    aws_hook_class = KinesisAnalyticsV2Hook
+    ui_color = "#66c3ff"
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+    SUCCESS_MESSAGE = ""
+
+    def __init__(
+        self,
+        application_name: str,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.application_name = application_name
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        status = self.hook.conn.describe_application(ApplicationName=self.application_name)[
+            "ApplicationDetail"
+        ]["ApplicationStatus"]
+
+        self.log.info(
+            "Poking for AWS Managed Service for Apache Flink application: %s status: %s",
+            self.application_name,
+            status,
+        )
+
+        if status in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if status in self.SUCCESS_STATES:
+            self.log.info(
+                "%s `%s`.",
+                self.SUCCESS_MESSAGE,
+                self.application_name,
+            )
+            return True
+
+        return False
+
+
+class KinesisAnalyticsV2StartApplicationCompletedSensor(KinesisAnalyticsV2BaseSensor):
+    """
+    Waits for AWS Managed Service for Apache Flink application to start.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:KinesisAnalyticsV2StartApplicationCompletedSensor`
+
+    :param application_name: Application name.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_INTERMEDIATE_STATES
+    FAILURE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_FAILURE_STATES
+    SUCCESS_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_SUCCESS_STATES
+
+    FAILURE_MESSAGE = "AWS Managed Service for Apache Flink application start failed."
+    SUCCESS_MESSAGE = "AWS Managed Service for Apache Flink application started successfully"
+
+    template_fields: Sequence[str] = aws_template_fields("application_name")
+
+    def __init__(
+        self,
+        *,
+        application_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(application_name=application_name, **kwargs)
+        self.application_name = application_name
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=KinesisAnalyticsV2ApplicationOperationCompleteTrigger(
+                    application_name=self.application_name,
+                    waiter_name="application_start_complete",
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+
+class KinesisAnalyticsV2StopApplicationCompletedSensor(KinesisAnalyticsV2BaseSensor):
+    """
+    Waits for AWS Managed Service for Apache Flink application to stop.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:KinesisAnalyticsV2StopApplicationCompletedSensor`
+
+    :param application_name: Application name.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_INTERMEDIATE_STATES
+    FAILURE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_FAILURE_STATES
+    SUCCESS_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_SUCCESS_STATES
+
+    FAILURE_MESSAGE = "AWS Managed Service for Apache Flink application stop failed."
+    SUCCESS_MESSAGE = "AWS Managed Service for Apache Flink application stopped successfully"
+
+    template_fields: Sequence[str] = aws_template_fields("application_name")
+
+    def __init__(
+        self,
+        *,
+        application_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(application_name=application_name, **kwargs)
+        self.application_name = application_name
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=KinesisAnalyticsV2ApplicationOperationCompleteTrigger(
+                    application_name=self.application_name,
+                    waiter_name="application_stop_complete",
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
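
A short usage sketch for the new sensors; the application name is a placeholder and every parameter shown is defined in the file above:

from airflow.providers.amazon.aws.sensors.kinesis_analytics import (
    KinesisAnalyticsV2StartApplicationCompletedSensor,
)

# Sketch: wait, in deferrable mode, for a Managed Service for Apache Flink
# application started earlier in the DAG to finish starting up.
wait_for_flink_start = KinesisAnalyticsV2StartApplicationCompletedSensor(
    task_id="wait_for_flink_start",
    application_name="my-flink-application",
    deferrable=True,
    poke_interval=120,
    max_retries=75,
)
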
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -18,6 +18,7 @@
 from __future__ import annotations

 import fnmatch
+import inspect
 import os
 import re
 from datetime import datetime, timedelta
@@ -57,13 +58,13 @@ class S3KeySensor(BaseSensorOperator):
         refers to this bucket
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
-    :param check_fn: Function that receives the list of the S3 objects,
+    :param check_fn: Function that receives the list of the S3 objects with the context values,
         and returns a boolean:
        - ``True``: the criteria is met
        - ``False``: the criteria isn't met
        **Example**: Wait for any S3 object size more than 1 megabyte ::

-            def check_fn(files: List) -> bool:
+            def check_fn(files: List, **kwargs) -> bool:
                 return any(f.get('Size', 0) > 1048576 for f in files)
     :param aws_conn_id: a reference to the s3 connection
     :param verify: Whether to verify SSL certificates for S3 connection.
@@ -112,7 +113,7 @@ class S3KeySensor(BaseSensorOperator):
         self.use_regex = use_regex
         self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

-    def _check_key(self, key):
+    def _check_key(self, key, context: Context):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

@@ -167,15 +168,20 @@ class S3KeySensor(BaseSensorOperator):
             files = [metadata]

         if self.check_fn is not None:
+            # For backwards compatibility, check if the function takes a context argument
+            signature = inspect.signature(self.check_fn)
+            if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
+                return self.check_fn(files, **context)
+            # Otherwise, just pass the files
             return self.check_fn(files)

         return True

     def poke(self, context: Context):
         if isinstance(self.bucket_key, str):
-            return self._check_key(self.bucket_key)
+            return self._check_key(self.bucket_key, context=context)
         else:
-            return all(self._check_key(key) for key in self.bucket_key)
+            return all(self._check_key(key, context=context) for key in self.bucket_key)

     def execute(self, context: Context) -> None:
         """Airflow runs this method on the worker and defers using the trigger."""
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -84,6 +84,7 @@ class RedshiftToS3Operator(BaseOperator):
         "unload_options",
         "select_query",
         "redshift_conn_id",
+        "redshift_data_api_kwargs",
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"select_query": "sql"}
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -77,6 +77,7 @@ class S3ToRedshiftOperator(BaseOperator):
         "copy_options",
         "redshift_conn_id",
         "method",
+        "redshift_data_api_kwargs",
         "aws_conn_id",
     )
     template_ext: Sequence[str] = ()
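
Adding redshift_data_api_kwargs to template_fields on both transfer operators means its values can now be Jinja-templated. A hedged sketch, assuming the keys accepted by the provider's Redshift Data API hook (database, cluster_identifier, db_user, ...); all names and params below are placeholders:

from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

# Sketch: templated values inside redshift_data_api_kwargs are now rendered at runtime.
load_to_redshift = S3ToRedshiftOperator(
    task_id="load_to_redshift",
    s3_bucket="my-bucket",
    s3_key="exports/{{ ds }}/data.csv",
    schema="public",
    table="my_table",
    copy_options=["CSV"],
    redshift_data_api_kwargs={
        "database": "{{ params.database }}",
        "cluster_identifier": "{{ params.cluster_id }}",
        "db_user": "awsuser",
        "wait_for_completion": True,
    },
)
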
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import sys
 import warnings
 from typing import TYPE_CHECKING

@@ -174,6 +175,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
     :param job_id: job_id to check the state
     :param aws_conn_id: Reference to AWS connection id
     :param waiter_delay: polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made. Defaults to an infinite wait.
     """

     def __init__(
@@ -183,7 +185,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
         aws_conn_id: str | None = "aws_default",
         poll_interval: int | None = None,  # deprecated
         waiter_delay: int = 30,
-        waiter_max_attempts: int = 600,
+        waiter_max_attempts: int = sys.maxsize,
     ):
         if poll_interval is not None:
             warnings.warn(
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/kinesis_analytics.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class KinesisAnalyticsV2ApplicationOperationCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Managed Service for Apache Flink application Start or Stop is complete.
+
+    :param application_name: Application name.
+    :param waiter_name: The name of the waiter for stop or start application.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        application_name: str,
+        waiter_name: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"application_name": application_name, "waiter_name": waiter_name},
+            waiter_name=waiter_name,
+            waiter_args={"ApplicationName": application_name},
+            failure_message=f"AWS Managed Service for Apache Flink Application {application_name} failed.",
+            status_message=f"Status of AWS Managed Service for Apache Flink Application {application_name} is",
+            status_queries=["ApplicationDetail.ApplicationStatus"],
+            return_key="application_name",
+            return_value=application_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return KinesisAnalyticsV2Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -25,8 +25,9 @@ from functools import cached_property
 from typing import Any, AsyncIterator

 from botocore.exceptions import WaiterError
+from deprecated import deprecated

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -199,6 +200,13 @@ class SageMakerPipelineTrigger(BaseTrigger):
                 raise AirflowException("Waiter error: max attempts reached")


+@deprecated(
+    reason=(
+        "`airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` "
+        "has been deprecated and will be removed in future. Please use ``SageMakerTrigger`` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class SageMakerTrainingPrintLogTrigger(BaseTrigger):
     """
     SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.