apache-airflow-providers-amazon 9.6.1rc1__py3-none-any.whl → 9.7.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. airflow/providers/amazon/__init__.py +3 -3
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/schema.json +33 -7
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +8 -5
  5. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +6 -9
  6. airflow/providers/amazon/aws/auth_manager/cli/definition.py +2 -12
  7. airflow/providers/amazon/aws/auth_manager/datamodels/login.py +26 -0
  8. airflow/providers/amazon/aws/auth_manager/routes/__init__.py +16 -0
  9. airflow/providers/amazon/aws/auth_manager/{router → routes}/login.py +29 -10
  10. airflow/providers/amazon/aws/executors/batch/batch_executor.py +1 -5
  11. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -6
  12. airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -4
  13. airflow/providers/amazon/aws/operators/emr.py +147 -142
  14. airflow/providers/amazon/aws/operators/glue.py +56 -48
  15. airflow/providers/amazon/aws/queues/__init__.py +16 -0
  16. airflow/providers/amazon/aws/queues/sqs.py +52 -0
  17. airflow/providers/amazon/aws/sensors/emr.py +49 -52
  18. airflow/providers/amazon/get_provider_info.py +2 -7
  19. airflow/providers/amazon/version_compat.py +0 -1
  20. {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.1a1.dist-info}/METADATA +17 -11
  21. {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.1a1.dist-info}/RECORD +24 -20
  22. /airflow/providers/amazon/aws/auth_manager/{router → datamodels}/__init__.py +0 -0
  23. {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.1a1.dist-info}/WHEEL +0 -0
  24. {apache_airflow_providers_amazon-9.6.1rc1.dist-info → apache_airflow_providers_amazon-9.7.1a1.dist-info}/entry_points.txt +0 -0
@@ -20,14 +20,12 @@ from __future__ import annotations
  import os
  import urllib.parse
  from collections.abc import Sequence
- from functools import cached_property
  from typing import TYPE_CHECKING, Any

  from botocore.exceptions import ClientError

  from airflow.configuration import conf
  from airflow.exceptions import AirflowException
- from airflow.models import BaseOperator
  from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
  from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
@@ -38,12 +36,13 @@ from airflow.providers.amazon.aws.triggers.glue import (
  GlueJobCompleteTrigger,
  )
  from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
  from airflow.utils.context import Context


- class GlueJobOperator(BaseOperator):
+ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
  """
  Create an AWS Glue Job.

@@ -82,7 +81,8 @@ class GlueJobOperator(BaseOperator):
  For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA
  """

- template_fields: Sequence[str] = (
+ aws_hook_class = GlueJobHook
+ template_fields: Sequence[str] = aws_template_fields(
  "job_name",
  "script_location",
  "script_args",
@@ -112,8 +112,6 @@ class GlueJobOperator(BaseOperator):
  script_args: dict | None = None,
  retry_limit: int = 0,
  num_of_dpus: int | float | None = None,
- aws_conn_id: str | None = "aws_default",
- region_name: str | None = None,
  s3_bucket: str | None = None,
  iam_role_name: str | None = None,
  iam_role_arn: str | None = None,
@@ -137,8 +135,6 @@ class GlueJobOperator(BaseOperator):
  self.script_args = script_args or {}
  self.retry_limit = retry_limit
  self.num_of_dpus = num_of_dpus
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
  self.s3_bucket = s3_bucket
  self.iam_role_name = iam_role_name
  self.iam_role_arn = iam_role_arn
@@ -155,39 +151,49 @@ class GlueJobOperator(BaseOperator):
  self.stop_job_run_on_kill = stop_job_run_on_kill
  self._job_run_id: str | None = None
  self.sleep_before_return: int = sleep_before_return
+ self.s3_script_location: str | None = None

- @cached_property
- def glue_job_hook(self) -> GlueJobHook:
+ @property
+ def _hook_parameters(self):
+ # Upload script to S3 before creating the hook.
  if self.script_location is None:
- s3_script_location = None
- elif not self.script_location.startswith(self.s3_protocol):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
- script_name = os.path.basename(self.script_location)
- s3_hook.load_file(
- self.script_location,
- self.s3_artifacts_prefix + script_name,
- bucket_name=self.s3_bucket,
- replace=self.replace_script_file,
- )
- s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
- else:
- s3_script_location = self.script_location
- return GlueJobHook(
- job_name=self.job_name,
- desc=self.job_desc,
- concurrent_run_limit=self.concurrent_run_limit,
- script_location=s3_script_location,
- retry_limit=self.retry_limit,
- num_of_dpus=self.num_of_dpus,
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name,
- s3_bucket=self.s3_bucket,
- iam_role_name=self.iam_role_name,
- iam_role_arn=self.iam_role_arn,
- create_job_kwargs=self.create_job_kwargs,
- update_config=self.update_config,
- job_poll_interval=self.job_poll_interval,
+ self.s3_script_location = None
+ # location provided, but it's not in S3 yet.
+ elif self.script_location and self.s3_script_location is None:
+ if not self.script_location.startswith(self.s3_protocol):
+ self.upload_etl_script_to_s3()
+ else:
+ self.s3_script_location = self.script_location
+
+ return {
+ **super()._hook_parameters,
+ "job_name": self.job_name,
+ "desc": self.job_desc,
+ "concurrent_run_limit": self.concurrent_run_limit,
+ "script_location": self.s3_script_location,
+ "retry_limit": self.retry_limit,
+ "num_of_dpus": self.num_of_dpus,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "s3_bucket": self.s3_bucket,
+ "iam_role_name": self.iam_role_name,
+ "iam_role_arn": self.iam_role_arn,
+ "create_job_kwargs": self.create_job_kwargs,
+ "update_config": self.update_config,
+ "job_poll_interval": self.job_poll_interval,
+ }
+
+ def upload_etl_script_to_s3(self):
+ """Upload the ETL script to S3."""
+ s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
+ script_name = os.path.basename(self.script_location)
+ s3_hook.load_file(
+ self.script_location,
+ self.s3_artifacts_prefix + script_name,
+ bucket_name=self.s3_bucket,
+ replace=self.replace_script_file,
  )
+ self.s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"

  def execute(self, context: Context):
  """
@@ -200,19 +206,19 @@ class GlueJobOperator(BaseOperator):
  self.job_name,
  self.wait_for_completion,
  )
- glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
+ glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
  self._job_run_id = glue_job_run["JobRunId"]
  glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
- aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
- region_name=self.glue_job_hook.conn_region_name,
+ aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
+ region_name=self.hook.conn_region_name,
  job_name=urllib.parse.quote(self.job_name, safe=""),
  job_run_id=self._job_run_id,
  )
  GlueJobRunDetailsLink.persist(
  context=context,
  operator=self,
- region_name=self.glue_job_hook.conn_region_name,
- aws_partition=self.glue_job_hook.conn_partition,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
  job_name=urllib.parse.quote(self.job_name, safe=""),
  job_run_id=self._job_run_id,
  )
@@ -230,7 +236,7 @@ class GlueJobOperator(BaseOperator):
  method_name="execute_complete",
  )
  elif self.wait_for_completion:
- glue_job_run = self.glue_job_hook.job_completion(
+ glue_job_run = self.hook.job_completion(
  self.job_name, self._job_run_id, self.verbose, self.sleep_before_return
  )
  self.log.info(
@@ -254,7 +260,7 @@ class GlueJobOperator(BaseOperator):
  """Cancel the running AWS Glue Job."""
  if self.stop_job_run_on_kill:
  self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
- response = self.glue_job_hook.conn.batch_stop_job_run(
+ response = self.hook.conn.batch_stop_job_run(
  JobName=self.job_name,
  JobRunIds=[self._job_run_id],
  )
@@ -290,7 +296,9 @@ class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
  """

  aws_hook_class = GlueDataQualityHook
- template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
+ template_fields: Sequence[str] = aws_template_fields(
+ "name", "ruleset", "description", "data_quality_ruleset_kwargs"
+ )

  template_fields_renderers = {
  "data_quality_ruleset_kwargs": "json",
@@ -387,7 +395,7 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit

  aws_hook_class = GlueDataQualityHook

- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
  "datasource",
  "role",
  "rule_set_names",
@@ -553,7 +561,7 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
  """

  aws_hook_class = GlueDataQualityHook
- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
  "datasource",
  "role",
  "recommendation_run_kwargs",
@@ -0,0 +1,16 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
@@ -0,0 +1,52 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import re
+ from typing import TYPE_CHECKING
+
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
+ from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+
+ try:
+ from airflow.providers.common.messaging.providers.base_provider import BaseMessageQueueProvider
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "This feature requires the 'common.messaging' provider to be installed in version >= 1.0.1."
+ )
+
+ if TYPE_CHECKING:
+ from airflow.triggers.base import BaseEventTrigger
+
+ # [START queue_regexp]
+ QUEUE_REGEXP = r"^https://sqs\.[^.]+\.amazonaws\.com/[0-9]+/.+"
+ # [END queue_regexp]
+
+
+ class SqsMessageQueueProvider(BaseMessageQueueProvider):
+ """Configuration for SQS integration with common-messaging."""
+
+ def queue_matches(self, queue: str) -> bool:
+ return bool(re.match(QUEUE_REGEXP, queue))
+
+ def trigger_class(self) -> type[BaseEventTrigger]:
+ return SqsSensorTrigger
+
+ def trigger_kwargs(self, queue: str, **kwargs) -> dict:
+ return {
+ "sqs_queue": queue,
+ }
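This new module (aws/queues/sqs.py in the file list above) plugs SQS into the common message-queue interface: `queue_matches` recognises SQS queue URLs via QUEUE_REGEXP, while `trigger_class` and `trigger_kwargs` tell Airflow how to watch such a queue. A hedged usage sketch; the queue URL and account id are invented.

from airflow.providers.amazon.aws.queues.sqs import SqsMessageQueueProvider

provider = SqsMessageQueueProvider()
queue_url = "https://sqs.eu-west-1.amazonaws.com/123456789012/my-queue"

provider.queue_matches(queue_url)        # True: matches QUEUE_REGEXP
provider.queue_matches("kafka://topic")  # False: not an SQS queue URL

trigger_cls = provider.trigger_class()       # SqsSensorTrigger
kwargs = provider.trigger_kwargs(queue_url)  # {"sqs_queue": queue_url}
# trigger_cls(**kwargs) would then poll that queue, assuming the trigger's
# remaining parameters (batch size, wait times, ...) are left at their defaults.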
@@ -19,7 +19,6 @@ from __future__ import annotations

  from collections.abc import Iterable, Sequence
  from datetime import timedelta
- from functools import cached_property
  from typing import TYPE_CHECKING, Any

  from airflow.configuration import conf
@@ -28,19 +27,20 @@ from airflow.exceptions import (
  )
  from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
  from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
  from airflow.providers.amazon.aws.triggers.emr import (
  EmrContainerTrigger,
  EmrStepSensorTrigger,
  EmrTerminateJobFlowTrigger,
  )
  from airflow.providers.amazon.aws.utils import validate_execute_complete_event
- from airflow.sensors.base import BaseSensorOperator
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
  from airflow.utils.context import Context


- class EmrBaseSensor(BaseSensorOperator):
+ class EmrBaseSensor(AwsBaseSensor[EmrHook]):
  """
  Contains general sensor behavior for EMR.

@@ -52,24 +52,23 @@ class EmrBaseSensor(BaseSensorOperator):
  Subclasses should set ``target_states`` and ``failed_states`` fields.

  :param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
  running Airflow in a distributed manner and aws_conn_id is None or
  empty, then default boto3 configuration would be used (and must be
  maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
  """

+ aws_hook_class = EmrHook
  ui_color = "#66c3ff"

- def __init__(self, *, aws_conn_id: str | None = "aws_default", **kwargs):
+ def __init__(self, **kwargs):
  super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
  self.target_states: Iterable[str] = [] # will be set in subclasses
  self.failed_states: Iterable[str] = [] # will be set in subclasses

- @cached_property
- def hook(self) -> EmrHook:
- return EmrHook(aws_conn_id=self.aws_conn_id)
-
  def poke(self, context: Context):
  response = self.get_emr_response(context=context)
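With EmrBaseSensor now built on AwsBaseSensor[EmrHook], the hook comes from the base class (`aws_hook_class = EmrHook`) and the connection arguments are handled by `super().__init__`. A hedged sketch of what a subclass looks like under this pattern; the sensor below is invented for illustration and is not part of the provider.

from __future__ import annotations

from typing import Any

from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


class EmrClusterIdleSensor(EmrBaseSensor):
    """Hypothetical sensor that waits for an EMR cluster to reach WAITING."""

    template_fields = aws_template_fields("job_flow_id")

    def __init__(self, *, job_flow_id: str, **kwargs):
        super().__init__(**kwargs)  # aws_conn_id, region_name, verify accepted here
        self.job_flow_id = job_flow_id
        self.target_states = ["WAITING"]
        self.failed_states = ["TERMINATED_WITH_ERRORS"]

    def get_emr_response(self, context) -> dict[str, Any]:
        # self.hook.conn is the boto3 EMR client created from aws_hook_class = EmrHook.
        return self.hook.conn.describe_cluster(ClusterId=self.job_flow_id)

    @staticmethod
    def state_from_response(response: dict[str, Any]) -> str:
        return response["Cluster"]["Status"]["State"]

    @staticmethod
    def failure_message_from_response(response: dict[str, Any]) -> str | None:
        return response["Cluster"]["Status"].get("StateChangeReason", {}).get("Message")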
@@ -117,7 +116,7 @@ class EmrBaseSensor(BaseSensorOperator):
  raise NotImplementedError("Please implement failure_message_from_response() in subclass")


- class EmrServerlessJobSensor(BaseSensorOperator):
+ class EmrServerlessJobSensor(AwsBaseSensor[EmrServerlessHook]):
  """
  Poll the state of the job run until it reaches a terminal state; fails if the job run fails.

@@ -128,14 +127,18 @@ class EmrServerlessJobSensor(BaseSensorOperator):
  :param application_id: application_id to check the state of
  :param job_run_id: job_run_id to check the state of
  :param target_states: a set of states to wait for, defaults to 'SUCCESS'
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
- If this is None or empty then the default boto3 behaviour is used. If
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
  running Airflow in a distributed manner and aws_conn_id is None or
  empty, then default boto3 configuration would be used (and must be
  maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
  """

- template_fields: Sequence[str] = (
+ aws_hook_class = EmrServerlessHook
+ template_fields: Sequence[str] = aws_template_fields(
  "application_id",
  "job_run_id",
  )
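Because template_fields is now built with `aws_template_fields`, the connection-level arguments (aws_conn_id, region_name, ...) are templatable alongside the sensor-specific ones. A hedged DAG-level sketch; the task ids, Airflow Variables, and upstream XCom producers are invented.

from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor

# Assumed to live inside a DAG definition; "start_app" and "start_job" are
# hypothetical upstream tasks that push the application and job run ids.
wait_for_job = EmrServerlessJobSensor(
    task_id="wait_for_emr_serverless_job",
    application_id="{{ ti.xcom_pull(task_ids='start_app') }}",
    job_run_id="{{ ti.xcom_pull(task_ids='start_job') }}",
    aws_conn_id="{{ var.value.get('aws_conn_id', 'aws_default') }}",  # templatable via aws_template_fields
    region_name="{{ var.value.get('aws_region', 'us-east-1') }}",
)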
@@ -146,10 +149,8 @@ class EmrServerlessJobSensor(BaseSensorOperator):
  application_id: str,
  job_run_id: str,
  target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
- aws_conn_id: str | None = "aws_default",
  **kwargs: Any,
  ) -> None:
- self.aws_conn_id = aws_conn_id
  self.target_states = target_states
  self.application_id = application_id
  self.job_run_id = job_run_id
@@ -167,11 +168,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):

  return state in self.target_states

- @cached_property
- def hook(self) -> EmrServerlessHook:
- """Create and return an EmrServerlessHook."""
- return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
-
  @staticmethod
  def failure_message_from_response(response: dict[str, Any]) -> str | None:
  """
@@ -183,7 +179,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
  return response["jobRun"]["stateDetails"]


- class EmrServerlessApplicationSensor(BaseSensorOperator):
+ class EmrServerlessApplicationSensor(AwsBaseSensor[EmrServerlessHook]):
  """
  Poll the state of the application until it reaches a terminal state; fails if the application fails.

@@ -193,24 +189,28 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):

  :param application_id: application_id to check the state of
  :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'}
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
- If this is None or empty then the default boto3 behaviour is used. If
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
  running Airflow in a distributed manner and aws_conn_id is None or
  empty, then default boto3 configuration would be used (and must be
  maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
  """

- template_fields: Sequence[str] = ("application_id",)
+ aws_hook_class = EmrServerlessHook
+ template_fields: Sequence[str] = aws_template_fields(
+ "application_id",
+ )

  def __init__(
  self,
  *,
  application_id: str,
  target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
- aws_conn_id: str | None = "aws_default",
  **kwargs: Any,
  ) -> None:
- self.aws_conn_id = aws_conn_id
  self.target_states = target_states
  self.application_id = application_id
  super().__init__(**kwargs)
@@ -227,11 +227,6 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):

  return state in self.target_states

- @cached_property
- def hook(self) -> EmrServerlessHook:
- """Create and return an EmrServerlessHook."""
- return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
-
  @staticmethod
  def failure_message_from_response(response: dict[str, Any]) -> str | None:
  """
@@ -243,7 +238,7 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
  return response["application"]["stateDetails"]


- class EmrContainerSensor(BaseSensorOperator):
+ class EmrContainerSensor(AwsBaseSensor[EmrContainerHook]):
  """
  Poll the state of the job run until it reaches a terminal state; fail if the job run fails.

@@ -254,11 +249,14 @@ class EmrContainerSensor(BaseSensorOperator):
  :param job_id: job_id to check the state of
  :param max_retries: Number of times to poll for query state before
  returning the current state, defaults to None
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
- If this is None or empty then the default boto3 behaviour is used. If
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
  running Airflow in a distributed manner and aws_conn_id is None or
  empty, then default boto3 configuration would be used (and must be
  maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
  :param poll_interval: Time in seconds to wait between two consecutive call to
  check query status on athena, defaults to 10
  :param deferrable: Run sensor in the deferrable mode.
@@ -276,7 +274,8 @@ class EmrContainerSensor(BaseSensorOperator):
  )
  SUCCESS_STATES = ("COMPLETED",)

- template_fields: Sequence[str] = ("virtual_cluster_id", "job_id")
+ aws_hook_class = EmrContainerHook
+ template_fields: Sequence[str] = aws_template_fields("virtual_cluster_id", "job_id")
  template_ext: Sequence[str] = ()
  ui_color = "#66c3ff"

@@ -286,22 +285,20 @@ class EmrContainerSensor(BaseSensorOperator):
  virtual_cluster_id: str,
  job_id: str,
  max_retries: int | None = None,
- aws_conn_id: str | None = "aws_default",
  poll_interval: int = 10,
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
  **kwargs: Any,
  ) -> None:
  super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
  self.virtual_cluster_id = virtual_cluster_id
  self.job_id = job_id
  self.poll_interval = poll_interval
  self.max_retries = max_retries
  self.deferrable = deferrable

- @cached_property
- def hook(self) -> EmrContainerHook:
- return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+ @property
+ def _hook_parameters(self):
+ return {**super()._hook_parameters, "virtual_cluster_id": self.virtual_cluster_id}

  def poke(self, context: Context) -> bool:
  state = self.hook.poll_query_status(
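For EmrContainerSensor the removed cached_property is replaced by the same `_hook_parameters` mechanism: the base class builds the EmrContainerHook, and `virtual_cluster_id` is simply an extra keyword merged into those parameters. A hedged usage sketch; the ids below are invented.

from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor

# region_name / verify are now accepted (and templatable) through the AwsBaseSensor
# base class rather than through a hand-rolled hook property.
watch_emr_on_eks_job = EmrContainerSensor(
    task_id="watch_emr_on_eks_job",
    virtual_cluster_id="vc-0123456789abcdef0",
    job_id="job-0123456789abcdef0",
    region_name="us-east-2",
    max_retries=30,
    poll_interval=15,
    deferrable=True,
)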
@@ -369,7 +366,9 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
  Default failed_states is ``FAILED``.
  """

- template_fields: Sequence[str] = ("notebook_execution_id",)
+ template_fields: Sequence[str] = aws_template_fields(
+ "notebook_execution_id",
+ )

  FAILURE_STATES = {"FAILED"}
  COMPLETED_STATES = {"FINISHED"}
@@ -387,10 +386,9 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
  self.failed_states = failed_states or self.FAILURE_STATES

  def get_emr_response(self, context: Context) -> dict[str, Any]:
- emr_client = self.hook.conn
  self.log.info("Poking notebook %s", self.notebook_execution_id)

- return emr_client.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)
+ return self.hook.conn.describe_notebook_execution(NotebookExecutionId=self.notebook_execution_id)

  @staticmethod
  def state_from_response(response: dict[str, Any]) -> str:
@@ -438,7 +436,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
  :param deferrable: Run sensor in the deferrable mode.
  """

- template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
+ template_fields: Sequence[str] = aws_template_fields("job_flow_id", "target_states", "failed_states")
  template_ext: Sequence[str] = ()
  operator_extra_links = (
  EmrClusterLink(),
@@ -471,9 +469,8 @@ class EmrJobFlowSensor(EmrBaseSensor):

  :return: response
  """
- emr_client = self.hook.conn
  self.log.info("Poking cluster %s", self.job_flow_id)
- response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
+ response = self.hook.conn.describe_cluster(ClusterId=self.job_flow_id)

  EmrClusterLink.persist(
  context=context,
@@ -563,7 +560,9 @@ class EmrStepSensor(EmrBaseSensor):
  :param deferrable: Run sensor in the deferrable mode.
  """

- template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
+ template_fields: Sequence[str] = aws_template_fields(
+ "job_flow_id", "step_id", "target_states", "failed_states"
+ )
  template_ext: Sequence[str] = ()
  operator_extra_links = (
  EmrClusterLink(),
@@ -598,10 +597,8 @@ class EmrStepSensor(EmrBaseSensor):

  :return: response
  """
- emr_client = self.hook.conn
-
  self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id)
- response = emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
+ response = self.hook.conn.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)

  EmrClusterLink.persist(
  context=context,
@@ -616,7 +613,7 @@ class EmrStepSensor(EmrBaseSensor):
  region_name=self.hook.conn_region_name,
  aws_partition=self.hook.conn_partition,
  job_flow_id=self.job_flow_id,
- log_uri=get_log_uri(emr_client=emr_client, job_flow_id=self.job_flow_id),
+ log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self.job_flow_id),
  )

  return response
@@ -1258,13 +1258,6 @@ def get_provider_info():
  "aws_auth_manager": {
  "description": "This section only applies if you are using the AwsAuthManager. In other words, if you set\n``[core] auth_manager = airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager`` in\nAirflow's configuration.\n",
  "options": {
- "enable": {
- "description": "AWS auth manager is not ready to be used. Turn on this flag to use it anyway.\nDo that at your own risk since the AWS auth manager is not in an usable state.\n",
- "version_added": "8.12.0",
- "type": "boolean",
- "example": "True",
- "default": "False",
- },
  "conn_id": {
  "description": "The Airflow connection (i.e. credentials) used by the AWS auth manager to make API calls to AWS\nIdentity Center and Amazon Verified Permissions.\n",
  "version_added": "8.12.0",
@@ -1297,4 +1290,6 @@ def get_provider_info():
  },
  },
  "executors": ["airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"],
+ "auth-managers": ["airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager"],
+ "queues": ["airflow.providers.amazon.aws.queues.sqs.SqsMessageQueueProvider"],
  }
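The new "auth-managers" and "queues" keys are plain entries in the provider-info dictionary, so the registrations can be inspected like any other provider metadata. A minimal hedged sketch:

from airflow.providers.amazon.get_provider_info import get_provider_info

info = get_provider_info()
print(info["queues"])         # ["airflow.providers.amazon.aws.queues.sqs.SqsMessageQueueProvider"]
print(info["auth-managers"])  # ["airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager"]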
@@ -32,5 +32,4 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
  return airflow_version.major, airflow_version.minor, airflow_version.micro


- AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0)
  AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)