apache-airflow-providers-amazon 9.5.0rc1__py3-none-any.whl → 9.5.0rc3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (45)
  1. airflow/providers/amazon/aws/auth_manager/avp/entities.py +2 -0
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +67 -18
  3. airflow/providers/amazon/aws/auth_manager/router/login.py +10 -4
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  5. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  7. airflow/providers/amazon/aws/hooks/base_aws.py +9 -1
  8. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  9. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  10. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  11. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  12. airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
  13. airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
  14. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  15. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  16. airflow/providers/amazon/aws/links/athena.py +1 -2
  17. airflow/providers/amazon/aws/links/base_aws.py +2 -1
  18. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  19. airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
  20. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  21. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  22. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  23. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  24. airflow/providers/amazon/aws/operators/eks.py +3 -3
  25. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  26. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  27. airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
  28. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  29. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  30. airflow/providers/amazon/aws/sensors/mwaa.py +59 -11
  31. airflow/providers/amazon/aws/sensors/s3.py +1 -1
  32. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  33. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  34. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  35. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  36. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  37. airflow/providers/amazon/aws/triggers/base.py +10 -1
  38. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  39. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  40. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  41. airflow/providers/amazon/get_provider_info.py +11 -5
  42. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/METADATA +9 -7
  43. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/RECORD +45 -43
  44. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/WHEEL +0 -0
  45. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -44,7 +44,6 @@ from airflow.providers.amazon.aws.utils import trim_none_values, validate_execut
  from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
  from airflow.providers.amazon.aws.utils.tags import format_tags
  from airflow.utils.helpers import prune_dict
- from airflow.utils.json import AirflowJsonEncoder

  if TYPE_CHECKING:
  from airflow.providers.common.compat.openlineage.facet import Dataset
@@ -56,7 +55,7 @@ CHECK_INTERVAL_SECOND: int = 30


  def serialize(result: dict) -> dict:
- return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
+ return json.loads(json.dumps(result, default=repr))


  class SageMakerBaseOperator(BaseOperator):
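In this release, serialize() drops AirflowJsonEncoder and relies on json.dumps(..., default=repr), so any value the stdlib encoder cannot handle is stored as its repr string. A minimal sketch of the new behaviour (the input dict is illustrative):

    import json
    from datetime import datetime

    def serialize(result: dict) -> dict:
        # Values json cannot encode natively (e.g. datetime) fall back to repr()
        # instead of going through Airflow's JSON encoder.
        return json.loads(json.dumps(result, default=repr))

    print(serialize({"CreationTime": datetime(2025, 1, 1), "Status": "Completed"}))
    # {'CreationTime': 'datetime.datetime(2025, 1, 1, 0, 0)', 'Status': 'Completed'}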
@@ -171,7 +170,7 @@ class SageMakerBaseOperator(BaseOperator):
  timestamp = str(
  time.time_ns() // 1000000000
  ) # only keep the relevant datetime (first 10 digits)
- name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
+ name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
  self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
  return name

@@ -179,8 +178,7 @@ class SageMakerBaseOperator(BaseOperator):
  """Raise exception if resource type is not 'model' or 'job'."""
  if resource_type not in ("model", "job"):
  raise AirflowException(
- "Argument resource_type accepts only 'model' and 'job'. "
- f"Provided value: '{resource_type}'."
+ f"Argument resource_type accepts only 'model' and 'job'. Provided value: '{resource_type}'."
  )

  def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
@@ -560,8 +558,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
  self.operation = "update"
  sagemaker_operation = self.hook.update_endpoint
  self.log.warning(
- "cannot create already existing endpoint %s, "
- "updating it with the given config instead",
+ "cannot create already existing endpoint %s, updating it with the given config instead",
  endpoint_info["EndpointName"],
  )
  if "Tags" in endpoint_info:
airflow/providers/amazon/aws/sensors/ec2.py

@@ -18,21 +18,21 @@
  from __future__ import annotations

  from collections.abc import Sequence
- from functools import cached_property
  from typing import TYPE_CHECKING, Any

  from airflow.configuration import conf
  from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
  from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
  from airflow.providers.amazon.aws.utils import validate_execute_complete_event
- from airflow.sensors.base import BaseSensorOperator
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
  from airflow.utils.context import Context


- class EC2InstanceStateSensor(BaseSensorOperator):
+ class EC2InstanceStateSensor(AwsBaseSensor[EC2Hook]):
  """
  Poll the state of the AWS EC2 instance until the instance reaches the target state.

@@ -46,7 +46,8 @@ class EC2InstanceStateSensor(BaseSensorOperator):
  :param deferrable: if True, the sensor will run in deferrable mode
  """

- template_fields: Sequence[str] = ("target_state", "instance_id", "region_name")
+ aws_hook_class = EC2Hook
+ template_fields: Sequence[str] = aws_template_fields("target_state", "instance_id", "region_name")
  ui_color = "#cc8811"
  ui_fgcolor = "#ffffff"
  valid_states = ["running", "stopped", "terminated"]
@@ -56,8 +57,6 @@ class EC2InstanceStateSensor(BaseSensorOperator):
  *,
  target_state: str,
  instance_id: str,
- aws_conn_id: str | None = "aws_default",
- region_name: str | None = None,
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
  **kwargs,
  ):
@@ -66,8 +65,6 @@ class EC2InstanceStateSensor(BaseSensorOperator):
  super().__init__(**kwargs)
  self.target_state = target_state
  self.instance_id = instance_id
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
  self.deferrable = deferrable

  def execute(self, context: Context) -> Any:
@@ -85,10 +82,6 @@
  else:
  super().execute(context=context)

- @cached_property
- def hook(self):
- return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
  def poke(self, context: Context):
  instance_state = self.hook.get_instance_state(instance_id=self.instance_id)
  self.log.info("instance state: %s", instance_state)
airflow/providers/amazon/aws/sensors/glacier.py

@@ -95,5 +95,5 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
  return False
  else:
  raise AirflowException(
- f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}'
+ f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
  )
airflow/providers/amazon/aws/sensors/mwaa.py

@@ -18,13 +18,16 @@
  from __future__ import annotations

  from collections.abc import Collection, Sequence
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any

+ from airflow.configuration import conf
  from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
  from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+ from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
- from airflow.utils.state import State
+ from airflow.utils.state import DagRunState

  if TYPE_CHECKING:
  from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
  (templated)
  :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
  :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
- ``airflow.utils.state.State.success_states`` (templated)
+ ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
  :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
- AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
+ AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+ :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+ module to be installed.
+ (default: False, but can be overridden in config file by setting default_deferrable to True)
+ :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+ :param max_retries: Number of times before returning the current state. (default: 720)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
  """

  aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
  "external_dag_run_id",
  "success_states",
  "failure_states",
+ "deferrable",
+ "max_retries",
+ "poke_interval",
  )

  def __init__(
@@ -68,19 +89,25 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
  external_dag_run_id: str,
  success_states: Collection[str] | None = None,
  failure_states: Collection[str] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ poke_interval: int = 60,
+ max_retries: int = 720,
  **kwargs,
  ):
  super().__init__(**kwargs)

- self.success_states = set(success_states if success_states else State.success_states)
- self.failure_states = set(failure_states if failure_states else State.failed_states)
+ self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+ self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}

  if len(self.success_states & self.failure_states):
- raise AirflowException("allowed_states and failed_states must not have any values in common")
+ raise ValueError("success_states and failure_states must not have any values in common")

  self.external_env_name = external_env_name
  self.external_dag_id = external_dag_id
  self.external_dag_run_id = external_dag_run_id
+ self.deferrable = deferrable
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries

  def poke(self, context: Context) -> bool:
  self.log.info(
@@ -102,12 +129,33 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
  # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun

  state = response["RestApiResponse"]["state"]
- if state in self.success_states:
- return True

  if state in self.failure_states:
  raise AirflowException(
  f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
- f"failed with state {state}."
+ f"failed with state: {state}"
  )
- return False
+
+ return state in self.success_states
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+ validate_execute_complete_event(event)
+
+ def execute(self, context: Context):
+ if self.deferrable:
+ self.defer(
+ trigger=MwaaDagRunCompletedTrigger(
+ external_env_name=self.external_env_name,
+ external_dag_id=self.external_dag_id,
+ external_dag_run_id=self.external_dag_run_id,
+ success_states=self.success_states,
+ failure_states=self.failure_states,
+ # somehow the type of poke_interval is derived as float ??
+ waiter_delay=self.poke_interval, # type: ignore[arg-type]
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ super().execute(context=context)
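With the new deferrable path, MwaaDagRunSensor hands polling off to MwaaDagRunCompletedTrigger on the triggerer instead of occupying a worker slot between pokes. A minimal usage sketch (environment, DAG and run IDs are illustrative):

    from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

    wait_for_remote_dag_run = MwaaDagRunSensor(
        task_id="wait_for_remote_dag_run",
        external_env_name="my-mwaa-env",           # illustrative MWAA environment name
        external_dag_id="remote_etl",              # illustrative DAG id in that environment
        external_dag_run_id="manual__2025-01-01",  # illustrative run id
        deferrable=True,     # poll via MwaaDagRunCompletedTrigger on the triggerer
        poke_interval=60,    # becomes the trigger's waiter_delay
        max_retries=720,     # becomes the trigger's waiter_max_attempts
    )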
airflow/providers/amazon/aws/sensors/s3.py

@@ -192,7 +192,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
  self.defer(
  timeout=timedelta(seconds=self.timeout),
  trigger=S3KeyTrigger(
- bucket_name=cast(str, self.bucket_name),
+ bucket_name=cast("str", self.bucket_name),
  bucket_key=self.bucket_key,
  wildcard_match=self.wildcard_match,
  aws_conn_id=self.aws_conn_id,

airflow/providers/amazon/aws/sensors/step_function.py

@@ -81,5 +81,6 @@ class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
  return False

  self.log.info("Doing xcom_push of output")
- self.xcom_push(context, "output", output)
+
+ context["ti"].xcom_push(key="output", value=output)
  return True
airflow/providers/amazon/aws/transfers/mongo_to_s3.py

@@ -103,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
  if self.is_pipeline:
  results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
  mongo_collection=self.mongo_collection,
- aggregate_query=cast(list, self.mongo_query),
+ aggregate_query=cast("list", self.mongo_query),
  mongo_db=self.mongo_db,
  allowDiskUse=self.allow_disk_use,
  )
@@ -111,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
  else:
  results = MongoHook(self.mongo_conn_id).find(
  mongo_collection=self.mongo_collection,
- query=cast(dict, self.mongo_query),
+ query=cast("dict", self.mongo_query),
  projection=self.mongo_projection,
  mongo_db=self.mongo_db,
  find_one=False,
airflow/providers/amazon/aws/transfers/redshift_to_s3.py

@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
  from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
  from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+ from airflow.utils.types import NOTSET, ArgNotSet

  if TYPE_CHECKING:
  from airflow.utils.context import Context
@@ -102,7 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
  table: str | None = None,
  select_query: str | None = None,
  redshift_conn_id: str = "redshift_default",
- aws_conn_id: str | None = "aws_default",
+ aws_conn_id: str | None | ArgNotSet = NOTSET,
  verify: bool | str | None = None,
  unload_options: list | None = None,
  autocommit: bool = False,
@@ -118,7 +119,6 @@
  self.schema = schema
  self.table = table
  self.redshift_conn_id = redshift_conn_id
- self.aws_conn_id = aws_conn_id
  self.verify = verify
  self.unload_options = unload_options or []
  self.autocommit = autocommit
@@ -127,6 +127,16 @@
  self.table_as_file_name = table_as_file_name
  self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
  self.select_query = select_query
+ # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+ # actually provide a connection note that, because we don't want to let the exception bubble up in
+ # that case (since we're silently injecting a connection on their behalf).
+ self._aws_conn_id: str | None
+ if isinstance(aws_conn_id, ArgNotSet):
+ self.conn_set = False
+ self._aws_conn_id = "aws_default"
+ else:
+ self.conn_set = True
+ self._aws_conn_id = aws_conn_id

  def _build_unload_query(
  self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
@@ -176,11 +186,16 @@
  raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
  else:
  redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
- conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+ conn = (
+ S3Hook.get_connection(conn_id=self._aws_conn_id)
+ # Only fetch the connection if it was set by the user and it is not None
+ if self.conn_set and self._aws_conn_id
+ else None
+ )
  if conn and conn.extra_dejson.get("role_arn", False):
  credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
  else:
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
  credentials = s3_hook.get_credentials()
  credentials_block = build_credentials_block(credentials)
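Switching aws_conn_id to the NOTSET sentinel lets the operator tell "the user never passed aws_conn_id" apart from an explicit value, so the silently injected aws_default connection is not required to exist. A minimal sketch of the pattern in isolation (the helper name is illustrative, not part of the provider):

    from __future__ import annotations

    from airflow.utils.types import NOTSET, ArgNotSet

    def resolve_aws_conn_id(aws_conn_id: str | None | ArgNotSet = NOTSET) -> tuple[bool, str | None]:
        # Mirrors the conn_set/_aws_conn_id bookkeeping added to the transfer operators:
        # returns (was_explicitly_set, effective_conn_id).
        if isinstance(aws_conn_id, ArgNotSet):
            return False, "aws_default"  # injected default; a missing connection is not treated as an error
        return True, aws_conn_id         # explicit value (may be None to force plain boto3 defaults)

    print(resolve_aws_conn_id())          # (False, 'aws_default')
    print(resolve_aws_conn_id("my_aws"))  # (True, 'my_aws')
    print(resolve_aws_conn_id(None))      # (True, None)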
airflow/providers/amazon/aws/transfers/s3_to_redshift.py

@@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
  from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
  from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+ from airflow.utils.types import NOTSET, ArgNotSet

  if TYPE_CHECKING:
  from airflow.utils.context import Context
@@ -93,7 +94,7 @@ class S3ToRedshiftOperator(BaseOperator):
  s3_key: str,
  schema: str | None = None,
  redshift_conn_id: str = "redshift_default",
- aws_conn_id: str | None = "aws_default",
+ aws_conn_id: str | None | ArgNotSet = NOTSET,
  verify: bool | str | None = None,
  column_list: list[str] | None = None,
  copy_options: list | None = None,
@@ -117,6 +118,16 @@
  self.method = method
  self.upsert_keys = upsert_keys
  self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
+ # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+ # actually provide a connection note that, because we don't want to let the exception bubble up in
+ # that case (since we're silently injecting a connection on their behalf).
+ self._aws_conn_id: str | None
+ if isinstance(aws_conn_id, ArgNotSet):
+ self.conn_set = False
+ self._aws_conn_id = "aws_default"
+ else:
+ self.conn_set = True
+ self._aws_conn_id = aws_conn_id

  if self.redshift_data_api_kwargs:
  for arg in ["sql", "parameters"]:
@@ -149,14 +160,19 @@
  else:
  redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

- conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+ conn = (
+ S3Hook.get_connection(conn_id=self._aws_conn_id)
+ # Only fetch the connection if it was set by the user and it is not None
+ if self.conn_set and self._aws_conn_id
+ else None
+ )
  region_info = ""
  if conn and conn.extra_dejson.get("region", False):
  region_info = f"region '{conn.extra_dejson['region']}'"
  if conn and conn.extra_dejson.get("role_arn", False):
  credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
  else:
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
  credentials = s3_hook.get_credentials()
  credentials_block = build_credentials_block(credentials)
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -223,7 +223,7 @@ class SqlToS3Operator(BaseOperator):
  return
  for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
  yield (
- cast(str, group_label),
+ cast("str", group_label),
  grouped_df.get_group(group_label)
  .drop(random_column_name, axis=1, errors="ignore")
  .reset_index(drop=True),
airflow/providers/amazon/aws/triggers/base.py

@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):

  :param waiter_delay: The amount of time in seconds to wait between attempts.
  :param waiter_max_attempts: The maximum number of attempts to be made.
+ :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+ be updated.
  :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
  :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
  :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@
  return_value: Any,
  waiter_delay: int,
  waiter_max_attempts: int,
+ waiter_config_overrides: dict[str, Any] | None = None,
  aws_conn_id: str | None,
  region_name: str | None = None,
  verify: bool | str | None = None,
@@ -91,6 +94,7 @@
  self.failure_message = failure_message
  self.status_message = status_message
  self.status_queries = status_queries
+ self.waiter_config_overrides = waiter_config_overrides

  self.return_key = return_key
  self.return_value = return_value
@@ -140,7 +144,12 @@
  async def run(self) -> AsyncIterator[TriggerEvent]:
  hook = self.hook()
  async with await hook.get_async_conn() as client:
- waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
+ waiter = hook.get_waiter(
+ self.waiter_name,
+ deferrable=True,
+ client=client,
+ config_overrides=self.waiter_config_overrides,
+ )
  await async_wait(
  waiter,
  self.waiter_delay,
airflow/providers/amazon/aws/triggers/mwaa.py

@@ -0,0 +1,128 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Collection
+ from typing import TYPE_CHECKING
+
+ from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+ from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+ from airflow.utils.state import DagRunState
+
+ if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+ class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger when an MWAA Dag Run is complete.
+
+ :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
+ (templated)
+ :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
+ (templated)
+ :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
+ :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
+ ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+ :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
+ AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+ :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+ :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ external_env_name: str,
+ external_dag_id: str,
+ external_dag_run_id: str,
+ success_states: Collection[str] | None = None,
+ failure_states: Collection[str] | None = None,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 720,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+ self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
+
+ if len(self.success_states & self.failure_states):
+ raise ValueError("success_states and failure_states must not have any values in common")
+
+ in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states
+
+ super().__init__(
+ serialized_fields={
+ "external_env_name": external_env_name,
+ "external_dag_id": external_dag_id,
+ "external_dag_run_id": external_dag_run_id,
+ "success_states": success_states,
+ "failure_states": failure_states,
+ },
+ waiter_name="mwaa_dag_run_complete",
+ waiter_args={
+ "Name": external_env_name,
+ "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+ "Method": "GET",
+ },
+ failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+ status_message="State of DAG run",
+ status_queries=["RestApiResponse.state"],
+ return_key="dag_run_id",
+ return_value=external_dag_run_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ waiter_config_overrides={
+ "acceptors": _build_waiter_acceptors(
+ success_states=self.success_states,
+ failure_states=self.failure_states,
+ in_progress_states=in_progress_states,
+ )
+ },
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return MwaaHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
+
+
+ def _build_waiter_acceptors(
+ success_states: set[str], failure_states: set[str], in_progress_states: set[str]
+ ) -> list:
+ acceptors = []
+ for state_set, state_waiter_category in (
+ (success_states, "success"),
+ (failure_states, "failure"),
+ (in_progress_states, "retry"),
+ ):
+ for dag_run_state in state_set:
+ acceptors.append(
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": dag_run_state,
+ "state": state_waiter_category,
+ }
+ )
+
+ return acceptors
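For reference, this is roughly what _build_waiter_acceptors returns for the default states, where the non-terminal DagRunState values ("queued", "running") fall into the retry bucket. The expected output is shown as comments; ordering within each bucket is not guaranteed because the inputs are sets:

    acceptors = _build_waiter_acceptors(
        success_states={"success"},
        failure_states={"failed"},
        in_progress_states={"queued", "running"},
    )
    # [
    #     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "success", "state": "success"},
    #     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "failed", "state": "failure"},
    #     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "queued", "state": "retry"},
    #     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "running", "state": "retry"},
    # ]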
airflow/providers/amazon/aws/utils/waiter_with_logging.py

@@ -136,15 +136,16 @@ async def async_wait(
  last_response = error.last_response

  if "terminal failure" in error_reason:
- log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
- raise AirflowException(f"{failure_message}: {error}")
+ raise AirflowException(
+ f"{failure_message}: {_LazyStatusFormatter(status_args, last_response)}\n{error}"
+ )

  if (
  "An error occurred" in error_reason
  and isinstance(last_response.get("Error"), dict)
  and "Code" in last_response.get("Error")
  ):
- raise AirflowException(f"{failure_message}: {error}")
+ raise AirflowException(f"{failure_message}\n{last_response}\n{error}")

  log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
  else:
airflow/providers/amazon/aws/waiters/mwaa.json

@@ -0,0 +1,36 @@
+ {
+ "version": 2,
+ "waiters": {
+ "mwaa_dag_run_complete": {
+ "delay": 60,
+ "maxAttempts": 720,
+ "operation": "InvokeRestApi",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "queued",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "running",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "success",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "failed",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+ }
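The hook layer loads this custom waiter by name, and the trigger above can override individual keys (here the acceptors) through the config_overrides argument that this release adds to get_waiter in base_aws.py. A hedged sketch of using the waiter directly, assuming the standard custom-waiter behaviour of AwsBaseHook.get_waiter (environment name and path are illustrative):

    from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

    hook = MwaaHook(aws_conn_id="aws_default")
    waiter = hook.get_waiter("mwaa_dag_run_complete")  # resolves to waiters/mwaa.json
    waiter.wait(
        Name="my-mwaa-env",                                  # illustrative environment name
        Path="/dags/remote_etl/dagRuns/manual__2025-01-01",  # illustrative DAG run path
        Method="GET",
        WaiterConfig={"Delay": 60, "MaxAttempts": 720},
    )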