apache-airflow-providers-amazon 9.6.0rc1__py3-none-any.whl → 9.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +15 -18
  3. airflow/providers/amazon/aws/auth_manager/router/login.py +1 -1
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +3 -4
  5. airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
  7. airflow/providers/amazon/aws/hooks/athena.py +1 -1
  8. airflow/providers/amazon/aws/hooks/base_aws.py +12 -15
  9. airflow/providers/amazon/aws/hooks/batch_client.py +11 -0
  10. airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -2
  11. airflow/providers/amazon/aws/hooks/datasync.py +2 -2
  12. airflow/providers/amazon/aws/hooks/dms.py +2 -3
  13. airflow/providers/amazon/aws/hooks/dynamodb.py +1 -2
  14. airflow/providers/amazon/aws/hooks/emr.py +14 -17
  15. airflow/providers/amazon/aws/hooks/glue.py +9 -13
  16. airflow/providers/amazon/aws/hooks/mwaa.py +6 -7
  17. airflow/providers/amazon/aws/hooks/redshift_data.py +1 -1
  18. airflow/providers/amazon/aws/hooks/redshift_sql.py +5 -6
  19. airflow/providers/amazon/aws/hooks/s3.py +3 -6
  20. airflow/providers/amazon/aws/hooks/sagemaker.py +6 -9
  21. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +5 -6
  22. airflow/providers/amazon/aws/links/base_aws.py +2 -2
  23. airflow/providers/amazon/aws/links/emr.py +2 -4
  24. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +3 -5
  25. airflow/providers/amazon/aws/log/s3_task_handler.py +1 -2
  26. airflow/providers/amazon/aws/operators/athena.py +1 -1
  27. airflow/providers/amazon/aws/operators/batch.py +37 -42
  28. airflow/providers/amazon/aws/operators/bedrock.py +1 -1
  29. airflow/providers/amazon/aws/operators/ecs.py +4 -6
  30. airflow/providers/amazon/aws/operators/eks.py +146 -139
  31. airflow/providers/amazon/aws/operators/emr.py +4 -5
  32. airflow/providers/amazon/aws/operators/mwaa.py +1 -1
  33. airflow/providers/amazon/aws/operators/neptune.py +2 -2
  34. airflow/providers/amazon/aws/operators/redshift_data.py +1 -2
  35. airflow/providers/amazon/aws/operators/s3.py +9 -13
  36. airflow/providers/amazon/aws/operators/sagemaker.py +11 -19
  37. airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -2
  38. airflow/providers/amazon/aws/sensors/batch.py +33 -55
  39. airflow/providers/amazon/aws/sensors/eks.py +64 -54
  40. airflow/providers/amazon/aws/sensors/glacier.py +4 -5
  41. airflow/providers/amazon/aws/sensors/glue.py +6 -9
  42. airflow/providers/amazon/aws/sensors/glue_crawler.py +2 -4
  43. airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -1
  44. airflow/providers/amazon/aws/sensors/s3.py +1 -2
  45. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +4 -5
  46. airflow/providers/amazon/aws/sensors/sqs.py +1 -2
  47. airflow/providers/amazon/aws/utils/connection_wrapper.py +1 -1
  48. airflow/providers/amazon/aws/utils/sqs.py +1 -2
  49. airflow/providers/amazon/aws/utils/tags.py +2 -3
  50. airflow/providers/amazon/aws/waiters/mwaa.json +1 -1
  51. {apache_airflow_providers_amazon-9.6.0rc1.dist-info → apache_airflow_providers_amazon-9.6.1.dist-info}/METADATA +11 -10
  52. {apache_airflow_providers_amazon-9.6.0rc1.dist-info → apache_airflow_providers_amazon-9.6.1.dist-info}/RECORD +54 -54
  53. {apache_airflow_providers_amazon-9.6.0rc1.dist-info → apache_airflow_providers_amazon-9.6.1.dist-info}/WHEEL +0 -0
  54. {apache_airflow_providers_amazon-9.6.0rc1.dist-info → apache_airflow_providers_amazon-9.6.1.dist-info}/entry_points.txt +0 -0
@@ -850,9 +850,8 @@ class EmrModifyClusterOperator(BaseOperator):
850
850
 
851
851
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
852
852
  raise AirflowException(f"Modify cluster failed: {response}")
853
- else:
854
- self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"])
855
- return response["StepConcurrencyLevel"]
853
+ self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"])
854
+ return response["StepConcurrencyLevel"]
856
855
 
857
856
 
858
857
  class EmrTerminateJobFlowOperator(BaseOperator):
@@ -1070,7 +1069,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
1070
1069
  if event is None:
1071
1070
  self.log.error("Trigger error: event is None")
1072
1071
  raise AirflowException("Trigger error: event is None")
1073
- elif event["status"] != "success":
1072
+ if event["status"] != "success":
1074
1073
  raise AirflowException(f"Application {event['application_id']} failed to create")
1075
1074
  self.log.info("Starting application %s", event["application_id"])
1076
1075
  self.hook.conn.start_application(applicationId=event["application_id"])
@@ -1533,7 +1532,7 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
1533
1532
  if event is None:
1534
1533
  self.log.error("Trigger error: event is None")
1535
1534
  raise AirflowException("Trigger error: event is None")
1536
- elif event["status"] == "success":
1535
+ if event["status"] == "success":
1537
1536
  self.hook.conn.stop_application(applicationId=self.application_id)
1538
1537
  self.defer(
1539
1538
  trigger=EmrServerlessStopApplicationTrigger(
@@ -97,7 +97,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
97
97
  note: str | None = None,
98
98
  wait_for_completion: bool = False,
99
99
  waiter_delay: int = 60,
100
- waiter_max_attempts: int = 720,
100
+ waiter_max_attempts: int = 20,
101
101
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
102
102
  **kwargs,
103
103
  ):
@@ -139,7 +139,7 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
139
139
  if status.lower() in NeptuneHook.AVAILABLE_STATES:
140
140
  self.log.info("Neptune cluster %s is already available.", self.cluster_id)
141
141
  return {"db_cluster_id": self.cluster_id}
142
- elif status.lower() in NeptuneHook.ERROR_STATES:
142
+ if status.lower() in NeptuneHook.ERROR_STATES:
143
143
  # some states will not allow you to start the cluster
144
144
  self.log.error(
145
145
  "Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status
@@ -259,7 +259,7 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
259
259
  if status.lower() in NeptuneHook.STOPPED_STATES:
260
260
  self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
261
261
  return {"db_cluster_id": self.cluster_id}
262
- elif status.lower() in NeptuneHook.ERROR_STATES:
262
+ if status.lower() in NeptuneHook.ERROR_STATES:
263
263
  # some states will not allow you to stop the cluster
264
264
  self.log.error(
265
265
  "Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status
@@ -224,8 +224,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
224
224
  results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids]
225
225
  self.log.debug("Statement result(s): %s", results)
226
226
  return results
227
- else:
228
- return statement_ids
227
+ return statement_ids
229
228
 
230
229
  def on_kill(self) -> None:
231
230
  """Cancel the submitted redshift query."""
@@ -158,9 +158,8 @@ class S3GetBucketTaggingOperator(AwsBaseOperator[S3Hook]):
158
158
  if self.hook.check_for_bucket(self.bucket_name):
159
159
  self.log.info("Getting tags for bucket %s", self.bucket_name)
160
160
  return self.hook.get_bucket_tagging(self.bucket_name)
161
- else:
162
- self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
163
- return None
161
+ self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
162
+ return None
164
163
 
165
164
 
166
165
  class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]):
@@ -213,9 +212,8 @@ class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]):
213
212
  return self.hook.put_bucket_tagging(
214
213
  key=self.key, value=self.value, tag_set=self.tag_set, bucket_name=self.bucket_name
215
214
  )
216
- else:
217
- self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
218
- return None
215
+ self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
216
+ return None
219
217
 
220
218
 
221
219
  class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]):
@@ -254,9 +252,8 @@ class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]):
254
252
  if self.hook.check_for_bucket(self.bucket_name):
255
253
  self.log.info("Deleting tags for bucket %s", self.bucket_name)
256
254
  return self.hook.delete_bucket_tagging(self.bucket_name)
257
- else:
258
- self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
259
- return None
255
+ self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
256
+ return None
260
257
 
261
258
 
262
259
  class S3CopyObjectOperator(AwsBaseOperator[S3Hook]):
@@ -725,10 +722,9 @@ class S3FileTransformOperator(AwsBaseOperator[S3Hook]):
725
722
 
726
723
  if process.returncode:
727
724
  raise AirflowException(f"Transform script failed: {process.returncode}")
728
- else:
729
- self.log.info(
730
- "Transform script successful. Output temporarily located at %s", f_dest.name
731
- )
725
+ self.log.info(
726
+ "Transform script successful. Output temporarily located at %s", f_dest.name
727
+ )
732
728
 
733
729
  self.log.info("Uploading transformed file to S3")
734
730
  f_dest.flush()
@@ -165,13 +165,10 @@ class SageMakerBaseOperator(BaseOperator):
165
165
  # in case there is collision.
166
166
  if fail_if_exists:
167
167
  raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
168
- else:
169
- max_name_len = 63
170
- timestamp = str(
171
- time.time_ns() // 1000000000
172
- ) # only keep the relevant datetime (first 10 digits)
173
- name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
174
- self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
168
+ max_name_len = 63
169
+ timestamp = str(time.time_ns() // 1000000000) # only keep the relevant datetime (first 10 digits)
170
+ name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp
171
+ self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
175
172
  return name
176
173
 
177
174
  def _check_resource_type(self, resource_type: str):
@@ -197,8 +194,7 @@ class SageMakerBaseOperator(BaseOperator):
197
194
  except ClientError as e:
198
195
  if e.response["Error"]["Code"] == "ValidationException":
199
196
  return False # ValidationException is thrown when the resource could not be found
200
- else:
201
- raise e
197
+ raise e
202
198
 
203
199
  def execute(self, context: Context):
204
200
  raise NotImplementedError("Please implement execute() in sub class!")
@@ -326,7 +322,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
326
322
  status = response["ProcessingJobStatus"]
327
323
  if status in self.hook.failed_states:
328
324
  raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
329
- elif status == "Completed":
325
+ if status == "Completed":
330
326
  self.log.info("%s completed successfully.", self.task_id)
331
327
  return {"Processing": serialize(response)}
332
328
 
@@ -430,12 +426,9 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
430
426
  response = self.hook.create_endpoint_config(self.config)
431
427
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
432
428
  raise AirflowException(f"Sagemaker endpoint config creation failed: {response}")
433
- else:
434
- return {
435
- "EndpointConfig": serialize(
436
- self.hook.describe_endpoint_config(self.config["EndpointConfigName"])
437
- )
438
- }
429
+ return {
430
+ "EndpointConfig": serialize(self.hook.describe_endpoint_config(self.config["EndpointConfigName"]))
431
+ }
439
432
 
440
433
 
441
434
  class SageMakerEndpointOperator(SageMakerBaseOperator):
@@ -1038,8 +1031,7 @@ class SageMakerModelOperator(SageMakerBaseOperator):
1038
1031
  response = self.hook.create_model(self.config)
1039
1032
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
1040
1033
  raise AirflowException(f"Sagemaker model creation failed: {response}")
1041
- else:
1042
- return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))}
1034
+ return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))}
1043
1035
 
1044
1036
 
1045
1037
  class SageMakerTrainingOperator(SageMakerBaseOperator):
@@ -1177,7 +1169,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
1177
1169
  if status in self.hook.failed_states:
1178
1170
  reason = description.get("FailureReason", "(No reason provided)")
1179
1171
  raise AirflowException(f"SageMaker job failed because {reason}")
1180
- elif status == "Completed":
1172
+ if status == "Completed":
1181
1173
  log_message = f"{self.task_id} completed successfully."
1182
1174
  if self.print_log:
1183
1175
  billable_seconds = SageMakerHook.count_billable_seconds(
@@ -224,8 +224,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
224
224
  standardized_secret_dict = self._standardize_secret_keys(secret_dict)
225
225
  standardized_secret = json.dumps(standardized_secret_dict)
226
226
  return standardized_secret
227
- else:
228
- return secret
227
+ return secret
229
228
 
230
229
  def get_variable(self, key: str) -> str | None:
231
230
  """
@@ -18,20 +18,20 @@ from __future__ import annotations
18
18
 
19
19
  from collections.abc import Sequence
20
20
  from datetime import timedelta
21
- from functools import cached_property
22
21
  from typing import TYPE_CHECKING, Any
23
22
 
24
23
  from airflow.configuration import conf
25
24
  from airflow.exceptions import AirflowException
26
25
  from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
26
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
27
27
  from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
28
- from airflow.sensors.base import BaseSensorOperator
28
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
29
29
 
30
30
  if TYPE_CHECKING:
31
31
  from airflow.utils.context import Context
32
32
 
33
33
 
34
- class BatchSensor(BaseSensorOperator):
34
+ class BatchSensor(AwsBaseSensor[BatchClientHook]):
35
35
  """
36
36
  Poll the state of the Batch Job until it reaches a terminal state; fails if the job fails.
37
37
 
@@ -40,19 +40,24 @@ class BatchSensor(BaseSensorOperator):
40
40
  :ref:`howto/sensor:BatchSensor`
41
41
 
42
42
  :param job_id: Batch job_id to check the state for
43
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
44
- If this is None or empty then the default boto3 behaviour is used. If
43
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
44
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
45
45
  running Airflow in a distributed manner and aws_conn_id is None or
46
46
  empty, then default boto3 configuration would be used (and must be
47
47
  maintained on each worker node).
48
- :param region_name: aws region name associated with the client
48
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
49
+ :param verify: Whether or not to verify SSL certificates. See:
50
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
49
51
  :param deferrable: Run sensor in the deferrable mode.
50
52
  :param poke_interval: polling period in seconds to check for the status of the job.
51
53
  :param max_retries: Number of times to poll for job state before
52
54
  returning the current state.
53
55
  """
54
56
 
55
- template_fields: Sequence[str] = ("job_id",)
57
+ aws_hook_class = BatchClientHook
58
+ template_fields: Sequence[str] = aws_template_fields(
59
+ "job_id",
60
+ )
56
61
  template_ext: Sequence[str] = ()
57
62
  ui_color = "#66c3ff"
58
63
 
@@ -60,8 +65,6 @@ class BatchSensor(BaseSensorOperator):
60
65
  self,
61
66
  *,
62
67
  job_id: str,
63
- aws_conn_id: str | None = "aws_default",
64
- region_name: str | None = None,
65
68
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
66
69
  poke_interval: float = 30,
67
70
  max_retries: int = 4200,
@@ -69,8 +72,6 @@ class BatchSensor(BaseSensorOperator):
69
72
  ):
70
73
  super().__init__(**kwargs)
71
74
  self.job_id = job_id
72
- self.aws_conn_id = aws_conn_id
73
- self.region_name = region_name
74
75
  self.deferrable = deferrable
75
76
  self.poke_interval = poke_interval
76
77
  self.max_retries = max_retries
@@ -119,15 +120,8 @@ class BatchSensor(BaseSensorOperator):
119
120
  job_id = event["job_id"]
120
121
  self.log.info("Batch Job %s complete", job_id)
121
122
 
122
- @cached_property
123
- def hook(self) -> BatchClientHook:
124
- return BatchClientHook(
125
- aws_conn_id=self.aws_conn_id,
126
- region_name=self.region_name,
127
- )
128
-
129
123
 
130
- class BatchComputeEnvironmentSensor(BaseSensorOperator):
124
+ class BatchComputeEnvironmentSensor(AwsBaseSensor[BatchClientHook]):
131
125
  """
132
126
  Poll the state of the Batch environment until it reaches a terminal state; fails if the environment fails.
133
127
 
@@ -137,38 +131,31 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
137
131
 
138
132
  :param compute_environment: Batch compute environment name
139
133
 
140
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
141
- If this is None or empty then the default boto3 behaviour is used. If
134
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
135
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
142
136
  running Airflow in a distributed manner and aws_conn_id is None or
143
137
  empty, then default boto3 configuration would be used (and must be
144
138
  maintained on each worker node).
139
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
140
+ :param verify: Whether or not to verify SSL certificates. See:
141
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
145
142
 
146
- :param region_name: aws region name associated with the client
147
143
  """
148
144
 
149
- template_fields: Sequence[str] = ("compute_environment",)
145
+ aws_hook_class = BatchClientHook
146
+ template_fields: Sequence[str] = aws_template_fields(
147
+ "compute_environment",
148
+ )
150
149
  template_ext: Sequence[str] = ()
151
150
  ui_color = "#66c3ff"
152
151
 
153
152
  def __init__(
154
153
  self,
155
154
  compute_environment: str,
156
- aws_conn_id: str | None = "aws_default",
157
- region_name: str | None = None,
158
155
  **kwargs,
159
156
  ):
160
157
  super().__init__(**kwargs)
161
158
  self.compute_environment = compute_environment
162
- self.aws_conn_id = aws_conn_id
163
- self.region_name = region_name
164
-
165
- @cached_property
166
- def hook(self) -> BatchClientHook:
167
- """Create and return a BatchClientHook."""
168
- return BatchClientHook(
169
- aws_conn_id=self.aws_conn_id,
170
- region_name=self.region_name,
171
- )
172
159
 
173
160
  def poke(self, context: Context) -> bool:
174
161
  response = self.hook.client.describe_compute_environments( # type: ignore[union-attr]
@@ -191,7 +178,7 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
191
178
  )
192
179
 
193
180
 
194
- class BatchJobQueueSensor(BaseSensorOperator):
181
+ class BatchJobQueueSensor(AwsBaseSensor[BatchClientHook]):
195
182
  """
196
183
  Poll the state of the Batch job queue until it reaches a terminal state; fails if the queue fails.
197
184
 
@@ -204,16 +191,20 @@ class BatchJobQueueSensor(BaseSensorOperator):
204
191
  :param treat_non_existing_as_deleted: If True, a non-existing Batch job queue is considered as a deleted
205
192
  queue and as such a valid case.
206
193
 
207
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
208
- If this is None or empty then the default boto3 behaviour is used. If
194
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
195
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
209
196
  running Airflow in a distributed manner and aws_conn_id is None or
210
197
  empty, then default boto3 configuration would be used (and must be
211
198
  maintained on each worker node).
212
-
213
- :param region_name: aws region name associated with the client
199
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
200
+ :param verify: Whether or not to verify SSL certificates. See:
201
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
214
202
  """
215
203
 
216
- template_fields: Sequence[str] = ("job_queue",)
204
+ aws_hook_class = BatchClientHook
205
+ template_fields: Sequence[str] = aws_template_fields(
206
+ "job_queue",
207
+ )
217
208
  template_ext: Sequence[str] = ()
218
209
  ui_color = "#66c3ff"
219
210
 
@@ -221,23 +212,11 @@ class BatchJobQueueSensor(BaseSensorOperator):
221
212
  self,
222
213
  job_queue: str,
223
214
  treat_non_existing_as_deleted: bool = False,
224
- aws_conn_id: str | None = "aws_default",
225
- region_name: str | None = None,
226
215
  **kwargs,
227
216
  ):
228
217
  super().__init__(**kwargs)
229
218
  self.job_queue = job_queue
230
219
  self.treat_non_existing_as_deleted = treat_non_existing_as_deleted
231
- self.aws_conn_id = aws_conn_id
232
- self.region_name = region_name
233
-
234
- @cached_property
235
- def hook(self) -> BatchClientHook:
236
- """Create and return a BatchClientHook."""
237
- return BatchClientHook(
238
- aws_conn_id=self.aws_conn_id,
239
- region_name=self.region_name,
240
- )
241
220
 
242
221
  def poke(self, context: Context) -> bool:
243
222
  response = self.hook.client.describe_job_queues( # type: ignore[union-attr]
@@ -247,8 +226,7 @@ class BatchJobQueueSensor(BaseSensorOperator):
247
226
  if not response["jobQueues"]:
248
227
  if self.treat_non_existing_as_deleted:
249
228
  return True
250
- else:
251
- raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
229
+ raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
252
230
 
253
231
  status = response["jobQueues"][0]["status"]
254
232
 
@@ -18,19 +18,20 @@
18
18
 
19
19
  from __future__ import annotations
20
20
 
21
+ import warnings
21
22
  from abc import abstractmethod
22
23
  from collections.abc import Sequence
23
- from functools import cached_property
24
24
  from typing import TYPE_CHECKING
25
25
 
26
- from airflow.exceptions import AirflowException
26
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
27
27
  from airflow.providers.amazon.aws.hooks.eks import (
28
28
  ClusterStates,
29
29
  EksHook,
30
30
  FargateProfileStates,
31
31
  NodegroupStates,
32
32
  )
33
- from airflow.sensors.base import BaseSensorOperator
33
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
34
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
34
35
 
35
36
  if TYPE_CHECKING:
36
37
  from airflow.utils.context import Context
@@ -57,7 +58,7 @@ NODEGROUP_TERMINAL_STATES = frozenset(
57
58
  )
58
59
 
59
60
 
60
- class EksBaseSensor(BaseSensorOperator):
61
+ class EksBaseSensor(AwsBaseSensor):
61
62
  """
62
63
  Base class to check various EKS states.
63
64
 
@@ -68,41 +69,33 @@ class EksBaseSensor(BaseSensorOperator):
68
69
  :param target_state_type: The enum containing the states,
69
70
  will be used to convert the target state if it has to be converted from a string
70
71
  :param aws_conn_id: The Airflow connection used for AWS credentials.
71
- If this is None or empty then the default boto3 behaviour is used. If
72
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
72
73
  running Airflow in a distributed manner and aws_conn_id is None or
73
- empty, then the default boto3 configuration would be used (and must be
74
+ empty, then default boto3 configuration would be used (and must be
74
75
  maintained on each worker node).
75
- :param region: Which AWS region the connection should use.
76
- If this is None or empty then the default boto3 behaviour is used.
76
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
77
+ :param verify: Whether or not to verify SSL certificates. See:
78
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
77
79
  """
78
80
 
81
+ aws_hook_class = EksHook
82
+
79
83
  def __init__(
80
84
  self,
81
85
  *,
82
86
  cluster_name: str,
83
87
  target_state: ClusterStates | NodegroupStates | FargateProfileStates,
84
88
  target_state_type: type,
85
- aws_conn_id: str | None = DEFAULT_CONN_ID,
86
- region: str | None = None,
87
89
  **kwargs,
88
90
  ):
89
91
  super().__init__(**kwargs)
90
92
  self.cluster_name = cluster_name
91
- self.aws_conn_id = aws_conn_id
92
- self.region = region
93
93
  self.target_state = (
94
94
  target_state
95
95
  if isinstance(target_state, target_state_type)
96
96
  else target_state_type(str(target_state).upper())
97
97
  )
98
98
 
99
- @cached_property
100
- def hook(self) -> EksHook:
101
- return EksHook(
102
- aws_conn_id=self.aws_conn_id,
103
- region_name=self.region,
104
- )
105
-
106
99
  def poke(self, context: Context) -> bool:
107
100
  state = self.get_state()
108
101
  self.log.info("Current state: %s", state)
@@ -130,16 +123,17 @@ class EksClusterStateSensor(EksBaseSensor):
130
123
 
131
124
  :param cluster_name: The name of the Cluster to watch. (templated)
132
125
  :param target_state: Target state of the Cluster. (templated)
133
- :param region: Which AWS region the connection should use. (templated)
134
- If this is None or empty then the default boto3 behaviour is used.
135
- :param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
136
- If this is None or empty then the default boto3 behaviour is used. If
137
- running Airflow in a distributed manner and aws_conn_id is None or
138
- empty, then the default boto3 configuration would be used (and must be
139
- maintained on each worker node).
126
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
127
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
128
+ running Airflow in a distributed manner and aws_conn_id is None or
129
+ empty, then default boto3 configuration would be used (and must be
130
+ maintained on each worker node).
131
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
132
+ :param verify: Whether or not to verify SSL certificates. See:
133
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
140
134
  """
141
135
 
142
- template_fields: Sequence[str] = ("cluster_name", "target_state", "aws_conn_id", "region")
136
+ template_fields: Sequence[str] = aws_template_fields("cluster_name", "target_state")
143
137
  ui_color = "#ff9900"
144
138
  ui_fgcolor = "#232F3E"
145
139
 
@@ -147,8 +141,16 @@ class EksClusterStateSensor(EksBaseSensor):
147
141
  self,
148
142
  *,
149
143
  target_state: ClusterStates = ClusterStates.ACTIVE,
144
+ region: str | None = None,
150
145
  **kwargs,
151
146
  ):
147
+ if region is not None:
148
+ warnings.warn(
149
+ message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
150
+ category=AirflowProviderDeprecationWarning,
151
+ stacklevel=2,
152
+ )
153
+ kwargs["region_name"] = region
152
154
  super().__init__(target_state=target_state, target_state_type=ClusterStates, **kwargs)
153
155
 
154
156
  def get_state(self) -> ClusterStates:
@@ -169,21 +171,18 @@ class EksFargateProfileStateSensor(EksBaseSensor):
169
171
  :param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated)
170
172
  :param fargate_profile_name: The name of the Fargate profile to watch. (templated)
171
173
  :param target_state: Target state of the Fargate profile. (templated)
172
- :param region: Which AWS region the connection should use. (templated)
173
- If this is None or empty then the default boto3 behaviour is used.
174
- :param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
175
- If this is None or empty then the default boto3 behaviour is used. If
176
- running Airflow in a distributed manner and aws_conn_id is None or
177
- empty, then the default boto3 configuration would be used (and must be
178
- maintained on each worker node).
174
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
175
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
176
+ running Airflow in a distributed manner and aws_conn_id is None or
177
+ empty, then default boto3 configuration would be used (and must be
178
+ maintained on each worker node).
179
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
180
+ :param verify: Whether or not to verify SSL certificates. See:
181
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
179
182
  """
180
183
 
181
- template_fields: Sequence[str] = (
182
- "cluster_name",
183
- "fargate_profile_name",
184
- "target_state",
185
- "aws_conn_id",
186
- "region",
184
+ template_fields: Sequence[str] = aws_template_fields(
185
+ "cluster_name", "fargate_profile_name", "target_state"
187
186
  )
188
187
  ui_color = "#ff9900"
189
188
  ui_fgcolor = "#232F3E"
@@ -192,9 +191,17 @@ class EksFargateProfileStateSensor(EksBaseSensor):
192
191
  self,
193
192
  *,
194
193
  fargate_profile_name: str,
194
+ region: str | None = None,
195
195
  target_state: FargateProfileStates = FargateProfileStates.ACTIVE,
196
196
  **kwargs,
197
197
  ):
198
+ if region is not None:
199
+ warnings.warn(
200
+ message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
201
+ category=AirflowProviderDeprecationWarning,
202
+ stacklevel=2,
203
+ )
204
+ kwargs["region_name"] = region
198
205
  super().__init__(target_state=target_state, target_state_type=FargateProfileStates, **kwargs)
199
206
  self.fargate_profile_name = fargate_profile_name
200
207
 
@@ -218,22 +225,17 @@ class EksNodegroupStateSensor(EksBaseSensor):
218
225
  :param cluster_name: The name of the Cluster which the Nodegroup is attached to. (templated)
219
226
  :param nodegroup_name: The name of the Nodegroup to watch. (templated)
220
227
  :param target_state: Target state of the Nodegroup. (templated)
221
- :param region: Which AWS region the connection should use. (templated)
222
- If this is None or empty then the default boto3 behaviour is used.
223
- :param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
224
- If this is None or empty then the default boto3 behaviour is used. If
225
- running Airflow in a distributed manner and aws_conn_id is None or
226
- empty, then the default boto3 configuration would be used (and must be
227
- maintained on each worker node).
228
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
229
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
230
+ running Airflow in a distributed manner and aws_conn_id is None or
231
+ empty, then default boto3 configuration would be used (and must be
232
+ maintained on each worker node).
233
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
234
+ :param verify: Whether or not to verify SSL certificates. See:
235
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
228
236
  """
229
237
 
230
- template_fields: Sequence[str] = (
231
- "cluster_name",
232
- "nodegroup_name",
233
- "target_state",
234
- "aws_conn_id",
235
- "region",
236
- )
238
+ template_fields: Sequence[str] = aws_template_fields("cluster_name", "nodegroup_name", "target_state")
237
239
  ui_color = "#ff9900"
238
240
  ui_fgcolor = "#232F3E"
239
241
 
@@ -242,8 +244,16 @@ class EksNodegroupStateSensor(EksBaseSensor):
242
244
  *,
243
245
  nodegroup_name: str,
244
246
  target_state: NodegroupStates = NodegroupStates.ACTIVE,
247
+ region: str | None = None,
245
248
  **kwargs,
246
249
  ):
250
+ if region is not None:
251
+ warnings.warn(
252
+ message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
253
+ category=AirflowProviderDeprecationWarning,
254
+ stacklevel=2,
255
+ )
256
+ kwargs["region_name"] = region
247
257
  super().__init__(target_state=target_state, target_state_type=NodegroupStates, **kwargs)
248
258
  self.nodegroup_name = nodegroup_name
249
259
 
@@ -89,11 +89,10 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
89
89
  self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"])
90
90
  self.log.info("Job finished successfully")
91
91
  return True
92
- elif response["StatusCode"] == JobStatus.IN_PROGRESS.value:
92
+ if response["StatusCode"] == JobStatus.IN_PROGRESS.value:
93
93
  self.log.info("Processing...")
94
94
  self.log.warning("Code status: %s", response["StatusCode"])
95
95
  return False
96
- else:
97
- raise AirflowException(
98
- f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
99
- )
96
+ raise AirflowException(
97
+ f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
98
+ )