apache-airflow-providers-amazon 9.6.1rc1__py3-none-any.whl → 9.7.0rc1__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- airflow/providers/amazon/__init__.py +3 -3
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/schema.json +33 -7
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +8 -5
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +6 -9
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +2 -12
- airflow/providers/amazon/aws/auth_manager/datamodels/login.py +26 -0
- airflow/providers/amazon/aws/auth_manager/routes/__init__.py +16 -0
- airflow/providers/amazon/aws/auth_manager/{router → routes}/login.py +29 -10
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +1 -5
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -6
- airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -4
- airflow/providers/amazon/aws/operators/emr.py +147 -142
- airflow/providers/amazon/aws/operators/glue.py +56 -48
- airflow/providers/amazon/aws/queues/__init__.py +16 -0
- airflow/providers/amazon/aws/queues/sqs.py +52 -0
- airflow/providers/amazon/aws/sensors/emr.py +49 -52
- airflow/providers/amazon/get_provider_info.py +2 -7
- airflow/providers/amazon/version_compat.py +0 -1
- {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.0rc1.dist-info}/METADATA +37 -30
- {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.0rc1.dist-info}/RECORD +24 -20
- /airflow/providers/amazon/aws/auth_manager/{router → datamodels}/__init__.py +0 -0
- {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/glue.py

@@ -20,14 +20,12 @@ from __future__ import annotations
 import os
 import urllib.parse
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from botocore.exceptions import ClientError

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
@@ -38,12 +36,13 @@ from airflow.providers.amazon.aws.triggers.glue import (
     GlueJobCompleteTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class GlueJobOperator(BaseOperator):
+class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
     """
     Create an AWS Glue Job.

@@ -82,7 +81,8 @@ class GlueJobOperator(BaseOperator):
         For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA
     """

-    template_fields: Sequence[str] = (
+    aws_hook_class = GlueJobHook
+    template_fields: Sequence[str] = aws_template_fields(
         "job_name",
         "script_location",
         "script_args",
@@ -112,8 +112,6 @@ class GlueJobOperator(BaseOperator):
         script_args: dict | None = None,
         retry_limit: int = 0,
         num_of_dpus: int | float | None = None,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         s3_bucket: str | None = None,
         iam_role_name: str | None = None,
         iam_role_arn: str | None = None,
@@ -137,8 +135,6 @@ class GlueJobOperator(BaseOperator):
         self.script_args = script_args or {}
         self.retry_limit = retry_limit
         self.num_of_dpus = num_of_dpus
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.s3_bucket = s3_bucket
         self.iam_role_name = iam_role_name
         self.iam_role_arn = iam_role_arn
@@ -155,39 +151,49 @@ class GlueJobOperator(BaseOperator):
         self.stop_job_run_on_kill = stop_job_run_on_kill
         self._job_run_id: str | None = None
         self.sleep_before_return: int = sleep_before_return
+        self.s3_script_location: str | None = None

-    @cached_property
-    def glue_job_hook(self) -> GlueJobHook:
+    @property
+    def _hook_parameters(self):
+        # Upload script to S3 before creating the hook.
         if self.script_location is None:
-            s3_script_location = None
 [old lines 163-189, the remainder of the removed hook construction, are not recoverable from this diff]
+            self.s3_script_location = None
+        # location provided, but it's not in S3 yet.
+        elif self.script_location and self.s3_script_location is None:
+            if not self.script_location.startswith(self.s3_protocol):
+                self.upload_etl_script_to_s3()
+            else:
+                self.s3_script_location = self.script_location
+
+        return {
+            **super()._hook_parameters,
+            "job_name": self.job_name,
+            "desc": self.job_desc,
+            "concurrent_run_limit": self.concurrent_run_limit,
+            "script_location": self.s3_script_location,
+            "retry_limit": self.retry_limit,
+            "num_of_dpus": self.num_of_dpus,
+            "aws_conn_id": self.aws_conn_id,
+            "region_name": self.region_name,
+            "s3_bucket": self.s3_bucket,
+            "iam_role_name": self.iam_role_name,
+            "iam_role_arn": self.iam_role_arn,
+            "create_job_kwargs": self.create_job_kwargs,
+            "update_config": self.update_config,
+            "job_poll_interval": self.job_poll_interval,
+        }
+
+    def upload_etl_script_to_s3(self):
+        """Upload the ETL script to S3."""
+        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
+        script_name = os.path.basename(self.script_location)
+        s3_hook.load_file(
+            self.script_location,
+            self.s3_artifacts_prefix + script_name,
+            bucket_name=self.s3_bucket,
+            replace=self.replace_script_file,
         )
+        self.s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"

     def execute(self, context: Context):
         """
@@ -200,19 +206,19 @@ class GlueJobOperator(BaseOperator):
             self.job_name,
             self.wait_for_completion,
         )
-        glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
+        glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
         self._job_run_id = glue_job_run["JobRunId"]
         glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
-            aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
-            region_name=self.glue_job_hook.conn_region_name,
+            aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
+            region_name=self.hook.conn_region_name,
             job_name=urllib.parse.quote(self.job_name, safe=""),
             job_run_id=self._job_run_id,
         )
         GlueJobRunDetailsLink.persist(
             context=context,
             operator=self,
-            region_name=self.glue_job_hook.conn_region_name,
-            aws_partition=self.glue_job_hook.conn_partition,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
             job_name=urllib.parse.quote(self.job_name, safe=""),
             job_run_id=self._job_run_id,
         )
@@ -230,7 +236,7 @@
                 method_name="execute_complete",
             )
         elif self.wait_for_completion:
-            glue_job_run = self.glue_job_hook.job_completion(
+            glue_job_run = self.hook.job_completion(
                 self.job_name, self._job_run_id, self.verbose, self.sleep_before_return
             )
             self.log.info(
@@ -254,7 +260,7 @@
         """Cancel the running AWS Glue Job."""
         if self.stop_job_run_on_kill:
             self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
-            response = self.glue_job_hook.conn.batch_stop_job_run(
+            response = self.hook.conn.batch_stop_job_run(
                 JobName=self.job_name,
                 JobRunIds=[self._job_run_id],
             )
@@ -290,7 +296,9 @@ class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
     """

     aws_hook_class = GlueDataQualityHook
-    template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
+    template_fields: Sequence[str] = aws_template_fields(
+        "name", "ruleset", "description", "data_quality_ruleset_kwargs"
+    )

     template_fields_renderers = {
         "data_quality_ruleset_kwargs": "json",
@@ -387,7 +395,7 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualityHook]):

     aws_hook_class = GlueDataQualityHook

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "datasource",
         "role",
         "rule_set_names",
@@ -553,7 +561,7 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
     """

     aws_hook_class = GlueDataQualityHook
-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "datasource",
         "role",
         "recommendation_run_kwargs",
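With this release GlueJobOperator derives from AwsBaseOperator[GlueJobHook], so aws_conn_id, region_name and verify are handled by the shared base class instead of operator-specific __init__ arguments, and the hook is configured through _hook_parameters. A minimal usage sketch follows; all concrete names (bucket, script, role, region) are placeholders, not taken from the diff:

from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

submit_glue_job = GlueJobOperator(
    task_id="submit_glue_job",
    job_name="example_glue_job",                   # placeholder job name
    script_location="s3://example-bucket/etl.py",  # placeholder script location
    iam_role_name="ExampleGlueRole",               # placeholder IAM role
    aws_conn_id="aws_default",                     # now consumed by AwsBaseOperator
    region_name="us-east-1",                       # likewise handled by the base class
    wait_for_completion=True,
)

Existing DAGs that already passed aws_conn_id or region_name keep working; the kwargs simply flow through the base operator into the hook parameters shown above.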
airflow/providers/amazon/aws/queues/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
airflow/providers/amazon/aws/queues/sqs.py (new file)

@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import re
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+
+try:
+    from airflow.providers.common.messaging.providers.base_provider import BaseMessageQueueProvider
+except ImportError:
+    raise AirflowOptionalProviderFeatureException(
+        "This feature requires the 'common.messaging' provider to be installed in version >= 1.0.1."
+    )
+
+if TYPE_CHECKING:
+    from airflow.triggers.base import BaseEventTrigger
+
+# [START queue_regexp]
+QUEUE_REGEXP = r"^https://sqs\.[^.]+\.amazonaws\.com/[0-9]+/.+"
+# [END queue_regexp]
+
+
+class SqsMessageQueueProvider(BaseMessageQueueProvider):
+    """Configuration for SQS integration with common-messaging."""
+
+    def queue_matches(self, queue: str) -> bool:
+        return bool(re.match(QUEUE_REGEXP, queue))
+
+    def trigger_class(self) -> type[BaseEventTrigger]:
+        return SqsSensorTrigger
+
+    def trigger_kwargs(self, queue: str, **kwargs) -> dict:
+        return {
+            "sqs_queue": queue,
+        }
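The new SqsMessageQueueProvider registers SQS with the optional common.messaging integration: queue_matches() decides whether a queue URI belongs to this provider, trigger_class() names the trigger to instantiate, and trigger_kwargs() maps the URI onto the trigger's arguments. A small sketch exercising the class exactly as defined above (the queue URL is a made-up example):

from airflow.providers.amazon.aws.queues.sqs import SqsMessageQueueProvider

provider = SqsMessageQueueProvider()
queue_url = "https://sqs.us-east-1.amazonaws.com/123456789012/example-queue"

assert provider.queue_matches(queue_url)                               # URL matches QUEUE_REGEXP
assert provider.trigger_kwargs(queue_url) == {"sqs_queue": queue_url}  # passed to the trigger
# provider.trigger_class() returns SqsSensorTrigger, which the
# common.messaging machinery instantiates with the kwargs above.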
airflow/providers/amazon/aws/sensors/emr.py

@@ -19,7 +19,6 @@ from __future__ import annotations

 from collections.abc import Iterable, Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from airflow.configuration import conf
@@ -28,19 +27,20 @@ from airflow.exceptions import (
 )
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.emr import (
     EmrContainerTrigger,
     EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class EmrBaseSensor(BaseSensorOperator):
+class EmrBaseSensor(AwsBaseSensor[EmrHook]):
     """
     Contains general sensor behavior for EMR.

@@ -52,24 +52,23 @@ class EmrBaseSensor(BaseSensorOperator):
     Subclasses should set ``target_states`` and ``failed_states`` fields.

     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

+    aws_hook_class = EmrHook
     ui_color = "#66c3ff"

-    def __init__(self, *, aws_conn_id: str | None = "aws_default", **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.target_states: Iterable[str] = []  # will be set in subclasses
         self.failed_states: Iterable[str] = []  # will be set in subclasses

-    @cached_property
-    def hook(self) -> EmrHook:
-        return EmrHook(aws_conn_id=self.aws_conn_id)
-
     def poke(self, context: Context):
         response = self.get_emr_response(context=context)

@@ -117,7 +116,7 @@ class EmrBaseSensor(BaseSensorOperator):
         raise NotImplementedError("Please implement failure_message_from_response() in subclass")


-class EmrServerlessJobSensor(BaseSensorOperator):
+class EmrServerlessJobSensor(AwsBaseSensor[EmrServerlessHook]):
     """
     Poll the state of the job run until it reaches a terminal state; fails if the job run fails.

@@ -128,14 +127,18 @@ class EmrServerlessJobSensor(BaseSensorOperator):
     :param application_id: application_id to check the state of
     :param job_run_id: job_run_id to check the state of
     :param target_states: a set of states to wait for, defaults to 'SUCCESS'
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    template_fields: Sequence[str] = (
+    aws_hook_class = EmrServerlessHook
+    template_fields: Sequence[str] = aws_template_fields(
         "application_id",
         "job_run_id",
     )
@@ -146,10 +149,8 @@ class EmrServerlessJobSensor(BaseSensorOperator):
         application_id: str,
         job_run_id: str,
         target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
-        aws_conn_id: str | None = "aws_default",
         **kwargs: Any,
     ) -> None:
-        self.aws_conn_id = aws_conn_id
         self.target_states = target_states
         self.application_id = application_id
         self.job_run_id = job_run_id
@@ -167,11 +168,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):

         return state in self.target_states

-    @cached_property
-    def hook(self) -> EmrServerlessHook:
-        """Create and return an EmrServerlessHook."""
-        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
-
     @staticmethod
     def failure_message_from_response(response: dict[str, Any]) -> str | None:
         """
@@ -183,7 +179,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
         return response["jobRun"]["stateDetails"]


-class EmrServerlessApplicationSensor(BaseSensorOperator):
+class EmrServerlessApplicationSensor(AwsBaseSensor[EmrServerlessHook]):
     """
     Poll the state of the application until it reaches a terminal state; fails if the application fails.

@@ -193,24 +189,28 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):

     :param application_id: application_id to check the state of
     :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'}
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    template_fields: Sequence[str] = ("application_id",)
+    aws_hook_class = EmrServerlessHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "application_id",
+    )

     def __init__(
         self,
         *,
         application_id: str,
         target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
-        aws_conn_id: str | None = "aws_default",
         **kwargs: Any,
     ) -> None:
-        self.aws_conn_id = aws_conn_id
         self.target_states = target_states
         self.application_id = application_id
         super().__init__(**kwargs)
@@ -227,11 +227,6 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):

         return state in self.target_states

-    @cached_property
-    def hook(self) -> EmrServerlessHook:
-        """Create and return an EmrServerlessHook."""
-        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
-
     @staticmethod
     def failure_message_from_response(response: dict[str, Any]) -> str | None:
         """
@@ -243,7 +238,7 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
         return response["application"]["stateDetails"]


-class EmrContainerSensor(BaseSensorOperator):
+class EmrContainerSensor(AwsBaseSensor[EmrContainerHook]):
     """
     Poll the state of the job run until it reaches a terminal state; fail if the job run fails.

@@ -254,11 +249,14 @@ class EmrContainerSensor(BaseSensorOperator):
     :param job_id: job_id to check the state of
     :param max_retries: Number of times to poll for query state before
         returning the current state, defaults to None
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param poll_interval: Time in seconds to wait between two consecutive call to
         check query status on athena, defaults to 10
     :param deferrable: Run sensor in the deferrable mode.
@@ -276,7 +274,8 @@ class EmrContainerSensor(BaseSensorOperator):
     )
     SUCCESS_STATES = ("COMPLETED",)

-    template_fields: Sequence[str] = ("virtual_cluster_id", "job_id")
+    aws_hook_class = EmrContainerHook
+    template_fields: Sequence[str] = aws_template_fields("virtual_cluster_id", "job_id")
     template_ext: Sequence[str] = ()
     ui_color = "#66c3ff"

@@ -286,22 +285,20 @@ class EmrContainerSensor(BaseSensorOperator):
         virtual_cluster_id: str,
         job_id: str,
         max_retries: int | None = None,
-        aws_conn_id: str | None = "aws_default",
         poll_interval: int = 10,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.virtual_cluster_id = virtual_cluster_id
         self.job_id = job_id
         self.poll_interval = poll_interval
         self.max_retries = max_retries
         self.deferrable = deferrable

-    @cached_property
-    def hook(self) -> EmrContainerHook:
-        return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+    @property
+    def _hook_parameters(self):
+        return {**super()._hook_parameters, "virtual_cluster_id": self.virtual_cluster_id}

     def poke(self, context: Context) -> bool:
         state = self.hook.poll_query_status(
@@ -369,7 +366,9 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
     Default failed_states is ``FAILED``.
     """

-    template_fields: Sequence[str] = ("notebook_execution_id",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "notebook_execution_id",
+    )

     FAILURE_STATES = {"FAILED"}
     COMPLETED_STATES = {"FINISHED"}
@@ -387,10 +386,9 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
         self.failed_states = failed_states or self.FAILURE_STATES

     def get_emr_response(self, context: Context) -> dict[str, Any]:
-        emr_client = self.hook.conn
         self.log.info("Poking notebook %s", self.notebook_execution_id)

-        return emr_client.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)
+        return self.hook.conn.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)

     @staticmethod
     def state_from_response(response: dict[str, Any]) -> str:
@@ -438,7 +436,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
     :param deferrable: Run sensor in the deferrable mode.
     """

-    template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
+    template_fields: Sequence[str] = aws_template_fields("job_flow_id", "target_states", "failed_states")
     template_ext: Sequence[str] = ()
     operator_extra_links = (
         EmrClusterLink(),
@@ -471,9 +469,8 @@ class EmrJobFlowSensor(EmrBaseSensor):

         :return: response
         """
-        emr_client = self.hook.conn
         self.log.info("Poking cluster %s", self.job_flow_id)
-        response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
+        response = self.hook.conn.describe_cluster(ClusterId=self.job_flow_id)

         EmrClusterLink.persist(
             context=context,
@@ -563,7 +560,9 @@ class EmrStepSensor(EmrBaseSensor):
     :param deferrable: Run sensor in the deferrable mode.
     """

-    template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_flow_id", "step_id", "target_states", "failed_states"
+    )
     template_ext: Sequence[str] = ()
     operator_extra_links = (
         EmrClusterLink(),
@@ -598,10 +597,8 @@ class EmrStepSensor(EmrBaseSensor):

         :return: response
         """
-        emr_client = self.hook.conn
-
         self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id)
-        response = emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
+        response = self.hook.conn.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)

         EmrClusterLink.persist(
             context=context,
@@ -616,7 +613,7 @@ class EmrStepSensor(EmrBaseSensor):
             region_name=self.hook.conn_region_name,
             aws_partition=self.hook.conn_partition,
             job_flow_id=self.job_flow_id,
-            log_uri=get_log_uri(emr_client=emr_client, job_flow_id=self.job_flow_id),
+            log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self.job_flow_id),
         )

         return response
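The EMR sensors now extend AwsBaseSensor, so the per-sensor aws_conn_id plumbing and cached hook properties are gone; aws_conn_id, region_name and verify are accepted by the shared base class instead. A minimal sketch of the resulting call pattern (the application and job run IDs below are placeholders, not real values):

from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor

wait_for_job = EmrServerlessJobSensor(
    task_id="wait_for_serverless_job",
    application_id="00fexample1234567",  # placeholder application id
    job_run_id="00fexamplerun7654321",   # placeholder job run id
    aws_conn_id="aws_default",
    region_name="eu-west-1",             # now handled by AwsBaseSensor
)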
airflow/providers/amazon/get_provider_info.py

@@ -1258,13 +1258,6 @@ def get_provider_info():
         "aws_auth_manager": {
             "description": "This section only applies if you are using the AwsAuthManager. In other words, if you set\n``[core] auth_manager = airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager`` in\nAirflow's configuration.\n",
             "options": {
-                "enable": {
-                    "description": "AWS auth manager is not ready to be used. Turn on this flag to use it anyway.\nDo that at your own risk since the AWS auth manager is not in an usable state.\n",
-                    "version_added": "8.12.0",
-                    "type": "boolean",
-                    "example": "True",
-                    "default": "False",
-                },
                 "conn_id": {
                     "description": "The Airflow connection (i.e. credentials) used by the AWS auth manager to make API calls to AWS\nIdentity Center and Amazon Verified Permissions.\n",
                     "version_added": "8.12.0",
@@ -1297,4 +1290,6 @@ def get_provider_info():
         },
     },
     "executors": ["airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"],
+    "auth-managers": ["airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager"],
+    "queues": ["airflow.providers.amazon.aws.queues.sqs.SqsMessageQueueProvider"],
 }
airflow/providers/amazon/version_compat.py

@@ -32,5 +32,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
     return airflow_version.major, airflow_version.minor, airflow_version.micro


-AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0)
 AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
|