apache-airflow-providers-amazon 9.6.0__py3-none-any.whl → 9.6.1rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +15 -18
- airflow/providers/amazon/aws/auth_manager/router/login.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +3 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +1 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +12 -15
- airflow/providers/amazon/aws/hooks/batch_client.py +11 -0
- airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -2
- airflow/providers/amazon/aws/hooks/datasync.py +2 -2
- airflow/providers/amazon/aws/hooks/dms.py +2 -3
- airflow/providers/amazon/aws/hooks/dynamodb.py +1 -2
- airflow/providers/amazon/aws/hooks/emr.py +14 -17
- airflow/providers/amazon/aws/hooks/glue.py +9 -13
- airflow/providers/amazon/aws/hooks/mwaa.py +6 -7
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_sql.py +5 -6
- airflow/providers/amazon/aws/hooks/s3.py +3 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +6 -9
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +5 -6
- airflow/providers/amazon/aws/links/base_aws.py +2 -2
- airflow/providers/amazon/aws/links/emr.py +2 -4
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +3 -5
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -2
- airflow/providers/amazon/aws/operators/athena.py +1 -1
- airflow/providers/amazon/aws/operators/batch.py +37 -42
- airflow/providers/amazon/aws/operators/bedrock.py +1 -1
- airflow/providers/amazon/aws/operators/ecs.py +4 -6
- airflow/providers/amazon/aws/operators/eks.py +146 -139
- airflow/providers/amazon/aws/operators/emr.py +4 -5
- airflow/providers/amazon/aws/operators/mwaa.py +1 -1
- airflow/providers/amazon/aws/operators/neptune.py +2 -2
- airflow/providers/amazon/aws/operators/redshift_data.py +1 -2
- airflow/providers/amazon/aws/operators/s3.py +9 -13
- airflow/providers/amazon/aws/operators/sagemaker.py +11 -19
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -2
- airflow/providers/amazon/aws/sensors/batch.py +33 -55
- airflow/providers/amazon/aws/sensors/eks.py +64 -54
- airflow/providers/amazon/aws/sensors/glacier.py +4 -5
- airflow/providers/amazon/aws/sensors/glue.py +6 -9
- airflow/providers/amazon/aws/sensors/glue_crawler.py +2 -4
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -1
- airflow/providers/amazon/aws/sensors/s3.py +1 -2
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +4 -5
- airflow/providers/amazon/aws/sensors/sqs.py +1 -2
- airflow/providers/amazon/aws/utils/connection_wrapper.py +1 -1
- airflow/providers/amazon/aws/utils/sqs.py +1 -2
- airflow/providers/amazon/aws/utils/tags.py +2 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +1 -1
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/METADATA +11 -10
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/RECORD +54 -54
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py

@@ -850,9 +850,8 @@ class EmrModifyClusterOperator(BaseOperator):
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Modify cluster failed: {response}")
-        else:
-            self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"])
-            return response["StepConcurrencyLevel"]
+        self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"])
+        return response["StepConcurrencyLevel"]
 
 
 class EmrTerminateJobFlowOperator(BaseOperator):
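The operator hunks in this release are mostly the same mechanical cleanup: an `else`/`elif` branch that directly follows a `raise` or `return` is dedented, since the preceding statement already exits the function. A minimal before/after sketch of the pattern (the standalone functions are illustrative, not provider code):

```python
from airflow.exceptions import AirflowException


def step_concurrency_before(response: dict) -> int:
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        raise AirflowException(f"Modify cluster failed: {response}")
    else:  # redundant: the raise above already exits the function
        return response["StepConcurrencyLevel"]


def step_concurrency_after(response: dict) -> int:
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        raise AirflowException(f"Modify cluster failed: {response}")
    return response["StepConcurrencyLevel"]  # same behaviour, one level less nesting
```

Both variants raise identically on a non-200 response and return the same value otherwise; only the nesting changes.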
@@ -1070,7 +1069,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
         if event is None:
             self.log.error("Trigger error: event is None")
             raise AirflowException("Trigger error: event is None")
-        elif event["status"] != "success":
+        if event["status"] != "success":
             raise AirflowException(f"Application {event['application_id']} failed to create")
         self.log.info("Starting application %s", event["application_id"])
         self.hook.conn.start_application(applicationId=event["application_id"])
@@ -1533,7 +1532,7 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
         if event is None:
             self.log.error("Trigger error: event is None")
             raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        if event["status"] == "success":
             self.hook.conn.stop_application(applicationId=self.application_id)
             self.defer(
                 trigger=EmrServerlessStopApplicationTrigger(
airflow/providers/amazon/aws/operators/mwaa.py

@@ -97,7 +97,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         note: str | None = None,
         wait_for_completion: bool = False,
         waiter_delay: int = 60,
-        waiter_max_attempts: int =
+        waiter_max_attempts: int = 20,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
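With the 60-second `waiter_delay` shown above, the new default of 20 attempts bounds the synchronous wait at roughly 60 s × 20 = 20 minutes. A minimal usage sketch; `env_name` and `trigger_dag_id` are assumed parameter names that do not appear in this hunk, and all values are placeholders:

```python
from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator

trigger_remote = MwaaTriggerDagRunOperator(
    task_id="trigger_remote_dag",
    env_name="my-mwaa-environment",  # assumed parameter name
    trigger_dag_id="remote_dag",     # assumed parameter name
    wait_for_completion=True,
    waiter_delay=60,                 # seconds between polls
    waiter_max_attempts=20,          # new default: ~20 minutes of waiting in total
)
```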
airflow/providers/amazon/aws/operators/neptune.py

@@ -139,7 +139,7 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         if status.lower() in NeptuneHook.AVAILABLE_STATES:
             self.log.info("Neptune cluster %s is already available.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-        elif status.lower() in NeptuneHook.ERROR_STATES:
+        if status.lower() in NeptuneHook.ERROR_STATES:
             # some states will not allow you to start the cluster
             self.log.error(
                 "Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status
@@ -259,7 +259,7 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         if status.lower() in NeptuneHook.STOPPED_STATES:
             self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-        elif status.lower() in NeptuneHook.ERROR_STATES:
+        if status.lower() in NeptuneHook.ERROR_STATES:
             # some states will not allow you to stop the cluster
             self.log.error(
                 "Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status
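Both Neptune operators short-circuit when the cluster is already in the requested state and log an error for states that cannot transition. A minimal usage sketch, assuming `db_cluster_id` is the constructor parameter backing `self.cluster_id` (only the matching XCom key is visible in these hunks):

```python
from airflow.providers.amazon.aws.operators.neptune import (
    NeptuneStartDbClusterOperator,
    NeptuneStopDbClusterOperator,
)

start_cluster = NeptuneStartDbClusterOperator(
    task_id="start_neptune",
    db_cluster_id="my-neptune-cluster",  # assumed parameter name; placeholder ID
)
stop_cluster = NeptuneStopDbClusterOperator(
    task_id="stop_neptune",
    db_cluster_id="my-neptune-cluster",
)
```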
airflow/providers/amazon/aws/operators/redshift_data.py

@@ -224,8 +224,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids]
             self.log.debug("Statement result(s): %s", results)
             return results
-        else:
-            return statement_ids
+        return statement_ids
 
     def on_kill(self) -> None:
         """Cancel the submitted redshift query."""
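The branch above returns full `GetStatementResult` payloads when result fetching is enabled and plain statement IDs otherwise. A minimal sketch, assuming the branch is guarded by a `return_sql_result` flag (the flag's name is not visible in this hunk):

```python
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

select_users = RedshiftDataOperator(
    task_id="select_users",
    database="dev",
    sql="SELECT id, name FROM users LIMIT 10;",
    return_sql_result=True,  # assumed flag: True -> GetStatementResult payloads,
                             # False -> the submitted statement IDs
)
```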
airflow/providers/amazon/aws/operators/s3.py

@@ -158,9 +158,8 @@ class S3GetBucketTaggingOperator(AwsBaseOperator[S3Hook]):
         if self.hook.check_for_bucket(self.bucket_name):
             self.log.info("Getting tags for bucket %s", self.bucket_name)
             return self.hook.get_bucket_tagging(self.bucket_name)
-        else:
-            self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
-            return None
+        self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
+        return None
 
 
 class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]):
@@ -213,9 +212,8 @@ class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]):
             return self.hook.put_bucket_tagging(
                 key=self.key, value=self.value, tag_set=self.tag_set, bucket_name=self.bucket_name
             )
-        else:
-            self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
-            return None
+        self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
+        return None
 
 
 class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]):
@@ -254,9 +252,8 @@ class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]):
         if self.hook.check_for_bucket(self.bucket_name):
             self.log.info("Deleting tags for bucket %s", self.bucket_name)
             return self.hook.delete_bucket_tagging(self.bucket_name)
-        else:
-            self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
-            return None
+        self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
+        return None
 
 
 class S3CopyObjectOperator(AwsBaseOperator[S3Hook]):
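All three tagging operators now share the same flattened shape: act if the bucket exists, otherwise log `BUCKET_DOES_NOT_EXIST_MSG` and return `None`. A minimal DAG sketch with placeholder bucket and tag values:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.s3 import (
    S3DeleteBucketTaggingOperator,
    S3GetBucketTaggingOperator,
    S3PutBucketTaggingOperator,
)

with DAG(dag_id="s3_bucket_tagging_demo", start_date=datetime(2025, 1, 1), schedule=None):
    put_tag = S3PutBucketTaggingOperator(
        task_id="put_tag",
        bucket_name="my-bucket",  # placeholder; a missing bucket yields a warning and None
        key="team",
        value="data-eng",
    )
    get_tags = S3GetBucketTaggingOperator(task_id="get_tags", bucket_name="my-bucket")
    delete_tags = S3DeleteBucketTaggingOperator(task_id="delete_tags", bucket_name="my-bucket")

    put_tag >> get_tags >> delete_tags
```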
@@ -725,10 +722,9 @@ class S3FileTransformOperator(AwsBaseOperator[S3Hook]):
 
             if process.returncode:
                 raise AirflowException(f"Transform script failed: {process.returncode}")
-            else:
-                self.log.info(
-                    "Transform script successful. Output temporarily located at %s", f_dest.name
-                )
+            self.log.info(
+                "Transform script successful. Output temporarily located at %s", f_dest.name
+            )
 
             self.log.info("Uploading transformed file to S3")
             f_dest.flush()
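A minimal usage sketch for the operator touched above; the S3 keys and script path are placeholders, and the transform script is handed the downloaded source file and the destination temp file to write to:

```python
from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator

transform_report = S3FileTransformOperator(
    task_id="transform_report",
    source_s3_key="s3://my-bucket/raw/report.csv",   # placeholder
    dest_s3_key="s3://my-bucket/clean/report.csv",   # placeholder
    transform_script="/usr/local/bin/clean_csv.py",  # receives source and dest paths
    replace=True,
)
```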
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -165,13 +165,10 @@ class SageMakerBaseOperator(BaseOperator):
             # in case there is collision.
             if fail_if_exists:
                 raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
-            else:
-                max_name_len = 63
-                timestamp = str(
-                    time.time_ns() // 1000000000
-                )  # only keep the relevant datetime (first 10 digits)
-                name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
-                self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
+            max_name_len = 63
+            timestamp = str(time.time_ns() // 1000000000)  # only keep the relevant datetime (first 10 digits)
+            name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
+            self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
         return name
 
     def _check_resource_type(self, resource_type: str):
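The collision-avoidance rename above keeps names within SageMaker's 63-character limit: a 10-digit epoch-seconds timestamp plus one dash leaves 52 characters of the proposed name. The same arithmetic as a standalone sketch:

```python
import time

max_name_len = 63  # SageMaker resource-name length limit used by the operator
proposed_name = "a-very-long-sagemaker-training-job-name-that-would-otherwise-collide"

timestamp = str(time.time_ns() // 1_000_000_000)  # epoch seconds: 10 digits
# Keep max_name_len - len(timestamp) - 1 characters, reserving one for the dash.
name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"

assert len(name) <= max_name_len
print(name)  # 52 truncated characters + "-" + 10-digit timestamp
```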
@@ -197,8 +194,7 @@ class SageMakerBaseOperator(BaseOperator):
         except ClientError as e:
             if e.response["Error"]["Code"] == "ValidationException":
                 return False  # ValidationException is thrown when the resource could not be found
-            else:
-                raise e
+            raise e
 
     def execute(self, context: Context):
         raise NotImplementedError("Please implement execute() in sub class!")
@@ -326,7 +322,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             status = response["ProcessingJobStatus"]
             if status in self.hook.failed_states:
                 raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
-            elif status == "Completed":
+            if status == "Completed":
                 self.log.info("%s completed successfully.", self.task_id)
                 return {"Processing": serialize(response)}
 
@@ -430,12 +426,9 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
         response = self.hook.create_endpoint_config(self.config)
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker endpoint config creation failed: {response}")
-        else:
-            return {
-                "EndpointConfig": serialize(
-                    self.hook.describe_endpoint_config(self.config["EndpointConfigName"])
-                )
-            }
+        return {
+            "EndpointConfig": serialize(self.hook.describe_endpoint_config(self.config["EndpointConfigName"]))
+        }
 
 
 class SageMakerEndpointOperator(SageMakerBaseOperator):
@@ -1038,8 +1031,7 @@ class SageMakerModelOperator(SageMakerBaseOperator):
         response = self.hook.create_model(self.config)
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker model creation failed: {response}")
-        else:
-            return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))}
+        return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))}
 
 
 class SageMakerTrainingOperator(SageMakerBaseOperator):
@@ -1177,7 +1169,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             if status in self.hook.failed_states:
                 reason = description.get("FailureReason", "(No reason provided)")
                 raise AirflowException(f"SageMaker job failed because {reason}")
-            elif status == "Completed":
+            if status == "Completed":
                 log_message = f"{self.task_id} completed successfully."
                 if self.print_log:
                     billable_seconds = SageMakerHook.count_billable_seconds(
airflow/providers/amazon/aws/secrets/secrets_manager.py

@@ -224,8 +224,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
             standardized_secret_dict = self._standardize_secret_keys(secret_dict)
             standardized_secret = json.dumps(standardized_secret_dict)
             return standardized_secret
-        else:
-            return secret
+        return secret
 
     def get_variable(self, key: str) -> str | None:
         """
airflow/providers/amazon/aws/sensors/batch.py

@@ -18,20 +18,20 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class BatchSensor(BaseSensorOperator):
+class BatchSensor(AwsBaseSensor[BatchClientHook]):
     """
     Poll the state of the Batch Job until it reaches a terminal state; fails if the job fails.
 
@@ -40,19 +40,24 @@ class BatchSensor(BaseSensorOperator):
         :ref:`howto/sensor:BatchSensor`
 
     :param job_id: Batch job_id to check the state for
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name:
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param deferrable: Run sensor in the deferrable mode.
     :param poke_interval: polling period in seconds to check for the status of the job.
     :param max_retries: Number of times to poll for job state before
         returning the current state.
     """
 
-    template_fields: Sequence[str] = ("job_id",)
+    aws_hook_class = BatchClientHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_id",
+    )
     template_ext: Sequence[str] = ()
     ui_color = "#66c3ff"
 
@@ -60,8 +65,6 @@ class BatchSensor(BaseSensorOperator):
         self,
         *,
         job_id: str,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         poke_interval: float = 30,
         max_retries: int = 4200,
@@ -69,8 +72,6 @@ class BatchSensor(BaseSensorOperator):
     ):
         super().__init__(**kwargs)
         self.job_id = job_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.deferrable = deferrable
         self.poke_interval = poke_interval
         self.max_retries = max_retries
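After the migration to `AwsBaseSensor[BatchClientHook]`, `aws_conn_id`, `region_name`, and `verify` are handled by the base class rather than declared per sensor. A minimal usage sketch with a placeholder job ID:

```python
from airflow.providers.amazon.aws.sensors.batch import BatchSensor

wait_for_batch_job = BatchSensor(
    task_id="wait_for_batch_job",
    job_id="12345678-aaaa-bbbb-cccc-1234567890ab",  # placeholder AWS Batch job ID
    region_name="us-east-1",  # now accepted via AwsBaseSensor
    deferrable=True,          # hand the poll loop to the triggerer
    poke_interval=30,
    max_retries=4200,
)
```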
@@ -119,15 +120,8 @@ class BatchSensor(BaseSensorOperator):
         job_id = event["job_id"]
         self.log.info("Batch Job %s complete", job_id)
 
-    @cached_property
-    def hook(self) -> BatchClientHook:
-        return BatchClientHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name,
-        )
-
 
-class BatchComputeEnvironmentSensor(BaseSensorOperator):
+class BatchComputeEnvironmentSensor(AwsBaseSensor[BatchClientHook]):
     """
     Poll the state of the Batch environment until it reaches a terminal state; fails if the environment fails.
 
@@ -137,38 +131,31 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
 
     :param compute_environment: Batch compute environment name
 
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
 
-    :param region_name: aws region name associated with the client
     """
 
-    template_fields: Sequence[str] = ("compute_environment",)
+    aws_hook_class = BatchClientHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "compute_environment",
+    )
     template_ext: Sequence[str] = ()
     ui_color = "#66c3ff"
 
     def __init__(
         self,
         compute_environment: str,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.compute_environment = compute_environment
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-
-    @cached_property
-    def hook(self) -> BatchClientHook:
-        """Create and return a BatchClientHook."""
-        return BatchClientHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name,
-        )
 
     def poke(self, context: Context) -> bool:
         response = self.hook.client.describe_compute_environments(  # type: ignore[union-attr]
@@ -191,7 +178,7 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
         )
 
 
-class BatchJobQueueSensor(BaseSensorOperator):
+class BatchJobQueueSensor(AwsBaseSensor[BatchClientHook]):
     """
     Poll the state of the Batch job queue until it reaches a terminal state; fails if the queue fails.
 
@@ -204,16 +191,20 @@ class BatchJobQueueSensor(BaseSensorOperator):
     :param treat_non_existing_as_deleted: If True, a non-existing Batch job queue is considered as a deleted
         queue and as such a valid case.
 
-    :param aws_conn_id:
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-
-    :param region_name: aws region name associated with the client
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """
 
-    template_fields: Sequence[str] = ("job_queue",)
+    aws_hook_class = BatchClientHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_queue",
+    )
     template_ext: Sequence[str] = ()
     ui_color = "#66c3ff"
 
@@ -221,23 +212,11 @@ class BatchJobQueueSensor(BaseSensorOperator):
         self,
         job_queue: str,
         treat_non_existing_as_deleted: bool = False,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.job_queue = job_queue
         self.treat_non_existing_as_deleted = treat_non_existing_as_deleted
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-
-    @cached_property
-    def hook(self) -> BatchClientHook:
-        """Create and return a BatchClientHook."""
-        return BatchClientHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region_name,
-        )
 
     def poke(self, context: Context) -> bool:
         response = self.hook.client.describe_job_queues(  # type: ignore[union-attr]
@@ -247,8 +226,7 @@ class BatchJobQueueSensor(BaseSensorOperator):
         if not response["jobQueues"]:
             if self.treat_non_existing_as_deleted:
                 return True
-            else:
-                raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
+            raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
 
         status = response["jobQueues"][0]["status"]
 
airflow/providers/amazon/aws/sensors/eks.py

@@ -18,19 +18,20 @@
 
 from __future__ import annotations
 
+import warnings
 from abc import abstractmethod
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.eks import (
     ClusterStates,
     EksHook,
     FargateProfileStates,
     NodegroupStates,
 )
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -57,7 +58,7 @@ NODEGROUP_TERMINAL_STATES = frozenset(
 )
 
 
-class EksBaseSensor(BaseSensorOperator):
+class EksBaseSensor(AwsBaseSensor):
     """
     Base class to check various EKS states.
 
|
@@ -68,41 +69,33 @@ class EksBaseSensor(BaseSensorOperator):
|
|
68
69
|
:param target_state_type: The enum containing the states,
|
69
70
|
will be used to convert the target state if it has to be converted from a string
|
70
71
|
:param aws_conn_id: The Airflow connection used for AWS credentials.
|
71
|
-
If this is None or empty then the default boto3 behaviour is used. If
|
72
|
+
If this is ``None`` or empty then the default boto3 behaviour is used. If
|
72
73
|
running Airflow in a distributed manner and aws_conn_id is None or
|
73
|
-
empty, then
|
74
|
+
empty, then default boto3 configuration would be used (and must be
|
74
75
|
maintained on each worker node).
|
75
|
-
:param
|
76
|
-
|
76
|
+
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
|
77
|
+
:param verify: Whether or not to verify SSL certificates. See:
|
78
|
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
|
77
79
|
"""
|
78
80
|
|
81
|
+
aws_hook_class = EksHook
|
82
|
+
|
79
83
|
def __init__(
|
80
84
|
self,
|
81
85
|
*,
|
82
86
|
cluster_name: str,
|
83
87
|
target_state: ClusterStates | NodegroupStates | FargateProfileStates,
|
84
88
|
target_state_type: type,
|
85
|
-
aws_conn_id: str | None = DEFAULT_CONN_ID,
|
86
|
-
region: str | None = None,
|
87
89
|
**kwargs,
|
88
90
|
):
|
89
91
|
super().__init__(**kwargs)
|
90
92
|
self.cluster_name = cluster_name
|
91
|
-
self.aws_conn_id = aws_conn_id
|
92
|
-
self.region = region
|
93
93
|
self.target_state = (
|
94
94
|
target_state
|
95
95
|
if isinstance(target_state, target_state_type)
|
96
96
|
else target_state_type(str(target_state).upper())
|
97
97
|
)
|
98
98
|
|
99
|
-
@cached_property
|
100
|
-
def hook(self) -> EksHook:
|
101
|
-
return EksHook(
|
102
|
-
aws_conn_id=self.aws_conn_id,
|
103
|
-
region_name=self.region,
|
104
|
-
)
|
105
|
-
|
106
99
|
def poke(self, context: Context) -> bool:
|
107
100
|
state = self.get_state()
|
108
101
|
self.log.info("Current state: %s", state)
|
@@ -130,16 +123,17 @@ class EksClusterStateSensor(EksBaseSensor):
 
     :param cluster_name: The name of the Cluster to watch. (templated)
     :param target_state: Target state of the Cluster. (templated)
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region: Which AWS region the connection should use.
-        If this is None or empty then the default boto3 behaviour is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "target_state", "aws_conn_id", "region")
+    template_fields: Sequence[str] = aws_template_fields("cluster_name", "target_state")
     ui_color = "#ff9900"
     ui_fgcolor = "#232F3E"
 
@@ -147,8 +141,16 @@ class EksClusterStateSensor(EksBaseSensor):
         self,
         *,
         target_state: ClusterStates = ClusterStates.ACTIVE,
+        region: str | None = None,
         **kwargs,
     ):
+        if region is not None:
+            warnings.warn(
+                message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
+                category=AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs["region_name"] = region
         super().__init__(target_state=target_state, target_state_type=ClusterStates, **kwargs)
 
     def get_state(self) -> ClusterStates:
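The shim keeps DAGs that still pass `region=` working while surfacing a deprecation warning. A small sketch of the expected behaviour; the cluster name and region are placeholders, and it assumes `AwsBaseSensor` exposes the forwarded value as a `region_name` attribute:

```python
import warnings

from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sensor = EksClusterStateSensor(
        task_id="wait_for_cluster",
        cluster_name="demo-cluster",  # placeholder
        region="eu-west-1",           # deprecated spelling, forwarded to region_name
    )

assert any(w.category.__name__ == "AirflowProviderDeprecationWarning" for w in caught)
print(sensor.region_name)  # "eu-west-1" (assumed attribute on the AWS base sensor)
```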
@@ -169,21 +171,18 @@ class EksFargateProfileStateSensor(EksBaseSensor):
     :param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated)
     :param fargate_profile_name: The name of the Fargate profile to watch. (templated)
     :param target_state: Target state of the Fargate profile. (templated)
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region: Which AWS region the connection should use.
-        If this is None or empty then the default boto3 behaviour is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """
 
-    template_fields: Sequence[str] = (
-        "cluster_name",
-        "fargate_profile_name",
-        "target_state",
-        "aws_conn_id",
-        "region",
+    template_fields: Sequence[str] = aws_template_fields(
+        "cluster_name", "fargate_profile_name", "target_state"
     )
     ui_color = "#ff9900"
     ui_fgcolor = "#232F3E"
@@ -192,9 +191,17 @@ class EksFargateProfileStateSensor(EksBaseSensor):
         self,
         *,
         fargate_profile_name: str,
+        region: str | None = None,
         target_state: FargateProfileStates = FargateProfileStates.ACTIVE,
         **kwargs,
     ):
+        if region is not None:
+            warnings.warn(
+                message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
+                category=AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs["region_name"] = region
         super().__init__(target_state=target_state, target_state_type=FargateProfileStates, **kwargs)
         self.fargate_profile_name = fargate_profile_name
 
@@ -218,22 +225,17 @@ class EksNodegroupStateSensor(EksBaseSensor):
     :param cluster_name: The name of the Cluster which the Nodegroup is attached to. (templated)
     :param nodegroup_name: The name of the Nodegroup to watch. (templated)
     :param target_state: Target state of the Nodegroup. (templated)
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
-    :param region: Which AWS region the connection should use.
-        If this is None or empty then the default boto3 behaviour is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """
 
-    template_fields: Sequence[str] = (
-        "cluster_name",
-        "nodegroup_name",
-        "target_state",
-        "aws_conn_id",
-        "region",
-    )
+    template_fields: Sequence[str] = aws_template_fields("cluster_name", "nodegroup_name", "target_state")
     ui_color = "#ff9900"
     ui_fgcolor = "#232F3E"
 
@@ -242,8 +244,16 @@ class EksNodegroupStateSensor(EksBaseSensor):
         *,
         nodegroup_name: str,
         target_state: NodegroupStates = NodegroupStates.ACTIVE,
+        region: str | None = None,
         **kwargs,
     ):
+        if region is not None:
+            warnings.warn(
+                message="Parameter `region` is deprecated. Use the parameter `region_name` instead",
+                category=AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs["region_name"] = region
         super().__init__(target_state=target_state, target_state_type=NodegroupStates, **kwargs)
         self.nodegroup_name = nodegroup_name
 
airflow/providers/amazon/aws/sensors/glacier.py

@@ -89,11 +89,10 @@ class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
             self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"])
             self.log.info("Job finished successfully")
             return True
-        elif response["StatusCode"] == JobStatus.IN_PROGRESS.value:
+        if response["StatusCode"] == JobStatus.IN_PROGRESS.value:
             self.log.info("Processing...")
             self.log.warning("Code status: %s", response["StatusCode"])
             return False
-        else:
-            raise AirflowException(
-                f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
-            )
+        raise AirflowException(
+            f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}"
+        )
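A minimal usage sketch for the sensor touched above; `vault_name` and `job_id` are assumed parameter names that do not appear in this hunk:

```python
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor

wait_for_inventory = GlacierJobOperationSensor(
    task_id="wait_for_inventory_job",
    vault_name="my-vault",                              # assumed parameter name
    job_id="{{ ti.xcom_pull(task_ids='start_job') }}",  # assumed parameter name
    poke_interval=3600,  # Glacier jobs typically run for hours; poll sparingly
)
```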