apache-airflow-providers-amazon 9.5.0rc1__py3-none-any.whl → 9.5.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +2 -0
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +67 -18
- airflow/providers/amazon/aws/auth_manager/router/login.py +10 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +9 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +2 -1
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +59 -11
- airflow/providers/amazon/aws/sensors/s3.py +1 -1
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/base.py +10 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +11 -5
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/METADATA +9 -7
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/RECORD +45 -43
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -44,7 +44,6 @@ from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import prune_dict
-from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
     from airflow.providers.common.compat.openlineage.facet import Dataset
@@ -56,7 +55,7 @@ CHECK_INTERVAL_SECOND: int = 30
 
 
 def serialize(result: dict) -> dict:
-    return json.loads(json.dumps(result,
+    return json.loads(json.dumps(result, default=repr))
 
 
 class SageMakerBaseOperator(BaseOperator):
@@ -171,7 +170,7 @@ class SageMakerBaseOperator(BaseOperator):
         timestamp = str(
             time.time_ns() // 1000000000
         )  # only keep the relevant datetime (first 10 digits)
-        name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
+        name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
         self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
         return name
 
@@ -179,8 +178,7 @@ class SageMakerBaseOperator(BaseOperator):
         """Raise exception if resource type is not 'model' or 'job'."""
         if resource_type not in ("model", "job"):
             raise AirflowException(
-                "Argument resource_type accepts only 'model' and 'job'. "
-                f"Provided value: '{resource_type}'."
+                f"Argument resource_type accepts only 'model' and 'job'. Provided value: '{resource_type}'."
             )
 
     def _check_if_job_exists(self, job_name: str, describe_func: Callable[[str], Any]) -> bool:
@@ -560,8 +558,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
             self.operation = "update"
             sagemaker_operation = self.hook.update_endpoint
             self.log.warning(
-                "cannot create already existing endpoint %s, "
-                "updating it with the given config instead",
+                "cannot create already existing endpoint %s, updating it with the given config instead",
                 endpoint_info["EndpointName"],
             )
         if "Tags" in endpoint_info:
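For orientation, here is a minimal, self-contained sketch of what the reworked serialize() helper now does: any value that json cannot encode natively is rendered through repr() instead of the removed AirflowJsonEncoder. The sample payload below is illustrative only, not taken from the provider.

import datetime
import json


def serialize(result: dict) -> dict:
    # Mirrors the new implementation: unencodable values fall back to repr().
    return json.loads(json.dumps(result, default=repr))


# Example payload: a datetime would previously have needed a custom JSON encoder.
print(serialize({"EndpointName": "demo", "CreationTime": datetime.datetime(2025, 1, 1)}))
# {'EndpointName': 'demo', 'CreationTime': 'datetime.datetime(2025, 1, 1, 0, 0)'}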
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -18,21 +18,21 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
-from airflow.
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class EC2InstanceStateSensor(BaseSensorOperator):
+class EC2InstanceStateSensor(AwsBaseSensor[EC2Hook]):
     """
     Poll the state of the AWS EC2 instance until the instance reaches the target state.
 
@@ -46,7 +46,8 @@ class EC2InstanceStateSensor(BaseSensorOperator):
     :param deferrable: if True, the sensor will run in deferrable mode
     """
 
-
+    aws_hook_class = EC2Hook
+    template_fields: Sequence[str] = aws_template_fields("target_state", "instance_id", "region_name")
     ui_color = "#cc8811"
     ui_fgcolor = "#ffffff"
     valid_states = ["running", "stopped", "terminated"]
@@ -56,8 +57,6 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         *,
         target_state: str,
         instance_id: str,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
@@ -66,8 +65,6 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         super().__init__(**kwargs)
         self.target_state = target_state
         self.instance_id = instance_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.deferrable = deferrable
 
     def execute(self, context: Context) -> Any:
@@ -85,10 +82,6 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         else:
             super().execute(context=context)
 
-    @cached_property
-    def hook(self):
-        return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
     def poke(self, context: Context):
         instance_state = self.hook.get_instance_state(instance_id=self.instance_id)
         self.log.info("instance state: %s", instance_state)
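As a hedged usage sketch (not part of the diff; the task id, instance id and region are placeholders), the refactored sensor keeps its public interface, with aws_conn_id and region_name now resolved by the AwsBaseSensor base class rather than by a hand-rolled cached hook:

from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor

wait_for_instance = EC2InstanceStateSensor(
    task_id="wait_for_instance_running",
    instance_id="i-0123456789abcdef0",  # placeholder instance id
    target_state="running",  # must be one of: running, stopped, terminated
    aws_conn_id="aws_default",  # handled by AwsBaseSensor[EC2Hook]
    region_name="us-east-1",  # placeholder region, also handled by the base class
    deferrable=True,  # optionally hand the wait off to the triggerer
)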
--- a/airflow/providers/amazon/aws/sensors/glacier.py
+++ b/airflow/providers/amazon/aws/sensors/glacier.py
@@ -95,5 +95,5 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
             return False
         else:
             raise AirflowException(
-                f
+                f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
             )
--- a/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -18,13 +18,16 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import
+from airflow.utils.state import DagRunState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         (templated)
     :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
     :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
-        ``airflow.utils.state.
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
     :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
-        AirflowException, default is ``airflow.utils.state.
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         "external_dag_run_id",
         "success_states",
         "failure_states",
+        "deferrable",
+        "max_retries",
+        "poke_interval",
     )
 
     def __init__(
@@ -68,19 +89,25 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         external_dag_run_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 60,
+        max_retries: int = 720,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        self.success_states = set(success_states if success_states else
-        self.failure_states = set(failure_states if failure_states else
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
 
         if len(self.success_states & self.failure_states):
-            raise
+            raise ValueError("success_states and failure_states must not have any values in common")
 
         self.external_env_name = external_env_name
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
         self.log.info(
@@ -102,12 +129,33 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun
 
         state = response["RestApiResponse"]["state"]
-        if state in self.success_states:
-            return True
 
         if state in self.failure_states:
             raise AirflowException(
                 f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
-                f"failed with state {state}
+                f"failed with state: {state}"
             )
-
+
+        return state in self.success_states
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        validate_execute_complete_event(event)
+
+    def execute(self, context: Context):
+        if self.deferrable:
+            self.defer(
+                trigger=MwaaDagRunCompletedTrigger(
+                    external_env_name=self.external_env_name,
+                    external_dag_id=self.external_dag_id,
+                    external_dag_run_id=self.external_dag_run_id,
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    # somehow the type of poke_interval is derived as float ??
+                    waiter_delay=self.poke_interval,  # type: ignore[arg-type]
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -192,7 +192,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
             self.defer(
                 timeout=timedelta(seconds=self.timeout),
                 trigger=S3KeyTrigger(
-                    bucket_name=cast(str, self.bucket_name),
+                    bucket_name=cast("str", self.bucket_name),
                     bucket_key=self.bucket_key,
                     wildcard_match=self.wildcard_match,
                     aws_conn_id=self.aws_conn_id,
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -103,7 +103,7 @@ class MongoToS3Operator(BaseOperator):
         if self.is_pipeline:
             results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
                 mongo_collection=self.mongo_collection,
-                aggregate_query=cast(list, self.mongo_query),
+                aggregate_query=cast("list", self.mongo_query),
                 mongo_db=self.mongo_db,
                 allowDiskUse=self.allow_disk_use,
             )
@@ -111,7 +111,7 @@ class MongoToS3Operator(BaseOperator):
         else:
             results = MongoHook(self.mongo_conn_id).find(
                 mongo_collection=self.mongo_collection,
-                query=cast(dict, self.mongo_query),
+                query=cast("dict", self.mongo_query),
                 projection=self.mongo_projection,
                 mongo_db=self.mongo_db,
                 find_one=False,
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -102,7 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
         table: str | None = None,
         select_query: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None =
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         unload_options: list | None = None,
         autocommit: bool = False,
@@ -118,7 +119,6 @@ class RedshiftToS3Operator(BaseOperator):
         self.schema = schema
         self.table = table
         self.redshift_conn_id = redshift_conn_id
-        self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.unload_options = unload_options or []
         self.autocommit = autocommit
@@ -127,6 +127,16 @@ class RedshiftToS3Operator(BaseOperator):
         self.table_as_file_name = table_as_file_name
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
         self.select_query = select_query
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id
 
     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
@@ -176,11 +186,16 @@ class RedshiftToS3Operator(BaseOperator):
                     raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
         else:
             redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
-        conn =
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)
 
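A hedged sketch of how the new NOTSET sentinel is interpreted (bucket, key, schema, table and connection id are placeholders). When aws_conn_id is omitted the operator quietly falls back to aws_default without looking the connection up, so a missing connection no longer raises; when it is passed explicitly the connection is fetched so a role_arn extra can be turned into an aws_iam_role credentials block. S3ToRedshiftOperator in the next section applies the same pattern.

from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

# 1. aws_conn_id omitted: conn_set is False, "aws_default" is used for credentials,
#    and the connection is never fetched just to inspect its extras.
unload_default = RedshiftToS3Operator(
    task_id="unload_default",
    s3_bucket="my-bucket",  # placeholder bucket
    s3_key="exports/my_table",  # placeholder key prefix
    schema="public",
    table="my_table",
)

# 2. aws_conn_id passed explicitly: the connection is fetched, and a role_arn extra
#    (if present) becomes "aws_iam_role=..." in the UNLOAD credentials block.
unload_with_role = RedshiftToS3Operator(
    task_id="unload_with_role",
    s3_bucket="my-bucket",
    s3_key="exports/my_table",
    schema="public",
    table="my_table",
    aws_conn_id="aws_redshift_unload",  # placeholder connection id
)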
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -93,7 +94,7 @@ class S3ToRedshiftOperator(BaseOperator):
         s3_key: str,
         schema: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None =
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         column_list: list[str] | None = None,
         copy_options: list | None = None,
@@ -117,6 +118,16 @@ class S3ToRedshiftOperator(BaseOperator):
         self.method = method
         self.upsert_keys = upsert_keys
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id
 
         if self.redshift_data_api_kwargs:
             for arg in ["sql", "parameters"]:
@@ -149,14 +160,19 @@ class S3ToRedshiftOperator(BaseOperator):
         else:
             redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
 
-        conn =
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         region_info = ""
         if conn and conn.extra_dejson.get("region", False):
             region_info = f"region '{conn.extra_dejson['region']}'"
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)
 
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -223,7 +223,7 @@ class SqlToS3Operator(BaseOperator):
             return
         for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
-                cast(str, group_label),
+                cast("str", group_label),
                 grouped_df.get_group(group_label)
                 .drop(random_column_name, axis=1, errors="ignore")
                 .reset_index(drop=True),
--- a/airflow/providers/amazon/aws/triggers/base.py
+++ b/airflow/providers/amazon/aws/triggers/base.py
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
 
     :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides
 
         self.return_key = return_key
         self.return_value = return_value
@@ -140,7 +144,12 @@ class AwsBaseWaiterTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
         async with await hook.get_async_conn() as client:
-            waiter = hook.get_waiter(
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an MWAA Dag Run is complete.
+
+    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
+    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_dag_run_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of DAG run",
+            status_queries=["RestApiResponse.state"],
+            return_key="dag_run_id",
+            return_value=external_dag_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+def _build_waiter_acceptors(
+    success_states: set[str], failure_states: set[str], in_progress_states: set[str]
+) -> list:
+    acceptors = []
+    for state_set, state_waiter_category in (
+        (success_states, "success"),
+        (failure_states, "failure"),
+        (in_progress_states, "retry"),
+    ):
+        for dag_run_state in state_set:
+            acceptors.append(
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": dag_run_state,
+                    "state": state_waiter_category,
+                }
+            )
+
+    return acceptors
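As a worked illustration (not part of the diff), this is roughly the acceptor list _build_waiter_acceptors produces for the default state sets; it matches the static definitions in the mwaa.json waiter file further below.

from airflow.providers.amazon.aws.triggers.mwaa import _build_waiter_acceptors

acceptors = _build_waiter_acceptors(
    success_states={"success"},
    failure_states={"failed"},
    in_progress_states={"queued", "running"},  # the remaining DagRunState values
)
# Each entry matches RestApiResponse.state against a single state, e.g.:
#   {"matcher": "path", "argument": "RestApiResponse.state", "expected": "success", "state": "success"}
#   {"matcher": "path", "argument": "RestApiResponse.state", "expected": "failed", "state": "failure"}
#   {"matcher": "path", "argument": "RestApiResponse.state", "expected": "queued", "state": "retry"}
#   {"matcher": "path", "argument": "RestApiResponse.state", "expected": "running", "state": "retry"}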
--- a/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -136,15 +136,16 @@ async def async_wait(
             last_response = error.last_response
 
             if "terminal failure" in error_reason:
-
-
+                raise AirflowException(
+                    f"{failure_message}: {_LazyStatusFormatter(status_args, last_response)}\n{error}"
+                )
 
             if (
                 "An error occurred" in error_reason
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}
+                raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
 
             log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
         else:
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/mwaa.json
@@ -0,0 +1,36 @@
+{
+    "version": 2,
+    "waiters": {
+        "mwaa_dag_run_complete": {
+            "delay": 60,
+            "maxAttempts": 720,
+            "operation": "InvokeRestApi",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "queued",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "running",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "success",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": "failed",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
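A hedged sketch of the mechanism this file relies on: a custom waiter definition like the one above can be turned into a botocore waiter and pointed at the MWAA InvokeRestApi operation. The provider wires this up through its own hook.get_waiter() helpers; the file path, region, environment name and REST API path below are placeholders.

import json

import boto3
from botocore.waiter import WaiterModel, create_waiter_with_client

with open("mwaa.json") as f:  # placeholder path to the waiter definition
    waiter_model = WaiterModel(json.load(f))

client = boto3.client("mwaa", region_name="us-east-1")  # placeholder region
waiter = create_waiter_with_client("mwaa_dag_run_complete", waiter_model, client)

# The waiter repeatedly calls InvokeRestApi with these arguments and applies the
# acceptors above to RestApiResponse.state until it reaches success or failure.
waiter.wait(
    Name="my-mwaa-environment",  # placeholder environment name
    Path="/dags/remote_dag/dagRuns/manual__2025-01-01T00:00:00",  # placeholder path
    Method="GET",
)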