apache-airflow-providers-amazon 8.26.0rc1__py3-none-any.whl → 8.27.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
- airflow/providers/amazon/aws/datasets/__init__.py +16 -0
- airflow/providers/amazon/aws/datasets/s3.py +45 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +27 -17
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +31 -13
- airflow/providers/amazon/aws/hooks/kinesis_analytics.py +65 -0
- airflow/providers/amazon/aws/hooks/rds.py +3 -3
- airflow/providers/amazon/aws/hooks/s3.py +26 -1
- airflow/providers/amazon/aws/hooks/step_function.py +18 -0
- airflow/providers/amazon/aws/operators/athena.py +16 -17
- airflow/providers/amazon/aws/operators/emr.py +23 -23
- airflow/providers/amazon/aws/operators/kinesis_analytics.py +348 -0
- airflow/providers/amazon/aws/operators/rds.py +17 -20
- airflow/providers/amazon/aws/operators/redshift_cluster.py +71 -53
- airflow/providers/amazon/aws/operators/s3.py +18 -12
- airflow/providers/amazon/aws/operators/sagemaker.py +12 -27
- airflow/providers/amazon/aws/operators/step_function.py +12 -2
- airflow/providers/amazon/aws/sensors/kinesis_analytics.py +234 -0
- airflow/providers/amazon/aws/sensors/s3.py +11 -5
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
- airflow/providers/amazon/aws/triggers/emr.py +3 -1
- airflow/providers/amazon/aws/triggers/kinesis_analytics.py +69 -0
- airflow/providers/amazon/aws/triggers/sagemaker.py +9 -1
- airflow/providers/amazon/aws/waiters/kinesisanalyticsv2.json +151 -0
- airflow/providers/amazon/aws/waiters/rds.json +253 -0
- airflow/providers/amazon/get_provider_info.py +35 -2
- {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/METADATA +32 -25
- {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/RECORD +32 -24
- {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/entry_points.txt +0 -0
```diff
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -324,8 +324,7 @@ class S3CopyObjectOperator(BaseOperator):
         )
 
     def get_openlineage_facets_on_start(self):
-        from openlineage.
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
         from airflow.providers.openlineage.extractors import OperatorLineage
 
         dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
@@ -439,8 +438,7 @@ class S3CreateObjectOperator(BaseOperator):
             s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)
 
     def get_openlineage_facets_on_start(self):
-        from openlineage.
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
         from airflow.providers.openlineage.extractors import OperatorLineage
 
         bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
@@ -546,13 +544,12 @@ class S3DeleteObjectsOperator(BaseOperator):
 
     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement _on_complete because object keys are resolved in execute()."""
-        from openlineage.
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
             LifecycleStateChange,
             LifecycleStateChangeDatasetFacet,
-
+            PreviousIdentifier,
         )
-        from openlineage.client.run import Dataset
-
         from airflow.providers.openlineage.extractors import OperatorLineage
 
         if not self._keys:
@@ -570,7 +567,7 @@ class S3DeleteObjectsOperator(BaseOperator):
                 facets={
                     "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
                         lifecycleStateChange=LifecycleStateChange.DROP.value,
-                        previousIdentifier=
+                        previousIdentifier=PreviousIdentifier(
                             namespace=bucket_url,
                             name=key,
                         ),
@@ -610,6 +607,7 @@ class S3FileTransformOperator(BaseOperator):
     :param dest_s3_key: The key to be written from S3. (templated)
     :param transform_script: location of the executable transformation script
     :param select_expression: S3 Select expression
+    :param select_expr_serialization_config: A dictionary that contains input and output serialization configurations for S3 Select.
     :param script_args: arguments for transformation script (templated)
     :param source_aws_conn_id: source s3 connection
     :param source_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -641,6 +639,7 @@ class S3FileTransformOperator(BaseOperator):
         dest_s3_key: str,
         transform_script: str | None = None,
         select_expression=None,
+        select_expr_serialization_config: dict[str, dict[str, dict]] | None = None,
         script_args: Sequence[str] | None = None,
         source_aws_conn_id: str | None = "aws_default",
         source_verify: bool | str | None = None,
@@ -659,6 +658,7 @@ class S3FileTransformOperator(BaseOperator):
         self.replace = replace
         self.transform_script = transform_script
         self.select_expression = select_expression
+        self.select_expr_serialization_config = select_expr_serialization_config or {}
         self.script_args = script_args or []
         self.output_encoding = sys.getdefaultencoding()
 
@@ -678,7 +678,14 @@ class S3FileTransformOperator(BaseOperator):
             self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name)
 
             if self.select_expression is not None:
-
+                input_serialization = self.select_expr_serialization_config.get("input_serialization")
+                output_serialization = self.select_expr_serialization_config.get("output_serialization")
+                content = source_s3.select_key(
+                    key=self.source_s3_key,
+                    expression=self.select_expression,
+                    input_serialization=input_serialization,
+                    output_serialization=output_serialization,
+                )
                 f_source.write(content.encode("utf-8"))
             else:
                 source_s3_key_object.download_fileobj(Fileobj=f_source)
@@ -715,8 +722,7 @@ class S3FileTransformOperator(BaseOperator):
         self.log.info("Upload successful")
 
     def get_openlineage_facets_on_start(self):
-        from openlineage.
-
+        from airflow.providers.common.compat.openlineage.facet import Dataset
        from airflow.providers.openlineage.extractors import OperatorLineage
 
         dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
```
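A minimal usage sketch of the new `select_expr_serialization_config` parameter on `S3FileTransformOperator`. Bucket names, keys, the expression, and the transform script below are placeholders; the two dictionary keys are the ones the diff above forwards to `select_key`.

```python
from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator

# Hypothetical task: run an S3 Select query against a CSV object, passing explicit
# input/output serialization settings through the new parameter.
filter_csv = S3FileTransformOperator(
    task_id="filter_csv",
    source_s3_key="s3://example-bucket/incoming/data.csv",
    dest_s3_key="s3://example-bucket/filtered/data.csv",
    select_expression="SELECT * FROM s3object s WHERE s._1 = 'OK'",
    select_expr_serialization_config={
        "input_serialization": {"CSV": {"FileHeaderInfo": "NONE"}},
        "output_serialization": {"CSV": {}},
    },
    transform_script="/bin/cp",  # pass-through transform, as in the upstream docs example
)
```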
```diff
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -36,7 +36,6 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
 )
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
-    SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
@@ -46,8 +45,7 @@ from airflow.utils.helpers import prune_dict
 from airflow.utils.json import AirflowJsonEncoder
 
 if TYPE_CHECKING:
-    from openlineage.
-
+    from airflow.providers.common.compat.openlineage.facet import Dataset
     from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
@@ -208,7 +206,7 @@ class SageMakerBaseOperator(BaseOperator):
 
     @staticmethod
     def path_to_s3_dataset(path) -> Dataset:
-        from openlineage.
+        from airflow.providers.common.compat.openlineage.facet import Dataset
 
         path = path.replace("s3://", "")
         split_path = path.split("/")
@@ -361,7 +359,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")
 
         self.log.info(event["message"])
-        self.serialized_job = serialize(self.hook.describe_processing_job(
+        self.serialized_job = serialize(self.hook.describe_processing_job(event["job_name"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}
 
@@ -612,12 +610,11 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-
+
+        response = self.hook.describe_endpoint(event["job_name"])
         return {
-            "EndpointConfig": serialize(
-
-            ),
-            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
+            "EndpointConfig": serialize(self.hook.describe_endpoint_config(response["EndpointConfigName"])),
+            "Endpoint": serialize(self.hook.describe_endpoint(response["EndpointName"])),
         }
 
 
@@ -997,9 +994,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
 
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        return {
-            "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
-        }
+        return {"Tuning": serialize(self.hook.describe_tuning_job(event["job_name"]))}
 
 
 class SageMakerModelOperator(SageMakerBaseOperator):
@@ -1199,25 +1194,15 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         if self.max_ingestion_time:
             timeout = datetime.timedelta(seconds=self.max_ingestion_time)
 
-
-
-            trigger
-                job_name=self.config["TrainingJobName"],
-                poke_interval=self.check_interval,
-                aws_conn_id=self.aws_conn_id,
-            )
-        else:
-            trigger = SageMakerTrigger(
+        self.defer(
+            timeout=timeout,
+            trigger=SageMakerTrigger(
                 job_name=self.config["TrainingJobName"],
                 job_type="Training",
                 poke_interval=self.check_interval,
                 max_attempts=self.max_attempts,
                 aws_conn_id=self.aws_conn_id,
-            )
-
-        self.defer(
-            timeout=timeout,
-            trigger=trigger,
+            ),
             method_name="execute_complete",
        )
 
```
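A minimal, hypothetical sketch of a deferrable training task after this change: the operator now always defers on `SageMakerTrigger` (the log-printing trigger path was removed above), so the task resumes with the job's terminal status rather than streaming CloudWatch logs. The task id, job name, and intervals are placeholders.

```python
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

train = SageMakerTrainingOperator(
    task_id="train_model",
    # Only the job name is shown; a real config carries the full CreateTrainingJob request.
    config={"TrainingJobName": "my-training-job"},
    deferrable=True,
    check_interval=60,        # becomes the trigger's poke_interval
    max_ingestion_time=3600,  # becomes the defer timeout
)
```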
```diff
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -48,6 +48,8 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
 
     :param state_machine_arn: ARN of the Step Function State Machine
     :param name: The name of the execution.
+    :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
+        complete successfully in the last 14 days.
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is None or empty then the default boto3 behaviour is used. If
@@ -73,7 +75,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
     """
 
     aws_hook_class = StepFunctionHook
-    template_fields: Sequence[str] = aws_template_fields(
+    template_fields: Sequence[str] = aws_template_fields(
+        "state_machine_arn", "name", "input", "is_redrive_execution"
+    )
     ui_color = "#f9c915"
     operator_extra_links = (StateMachineDetailsLink(), StateMachineExecutionsDetailsLink())
 
@@ -82,6 +86,7 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         *,
         state_machine_arn: str,
         name: str | None = None,
+        is_redrive_execution: bool = False,
         state_machine_input: dict | str | None = None,
         waiter_max_attempts: int = 30,
         waiter_delay: int = 60,
@@ -91,6 +96,7 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
         super().__init__(**kwargs)
         self.state_machine_arn = state_machine_arn
         self.name = name
+        self.is_redrive_execution = is_redrive_execution
         self.input = state_machine_input
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
@@ -105,7 +111,11 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
             state_machine_arn=self.state_machine_arn,
         )
 
-        if not (
+        if not (
+            execution_arn := self.hook.start_execution(
+                self.state_machine_arn, self.name, self.input, self.is_redrive_execution
+            )
+        ):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
 
         StateMachineExecutionsDetailsLink.persist(
```
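A short, hypothetical usage sketch of the new `is_redrive_execution` flag: redrive a Standard workflow execution that previously failed. The state machine ARN is a placeholder.

```python
from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator

redrive_state_machine = StepFunctionStartExecutionOperator(
    task_id="redrive_state_machine",
    state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:example-machine",
    is_redrive_execution=True,  # restart the unsuccessful execution instead of starting a new one
)
```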
```diff
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/kinesis_analytics.py
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.kinesis_analytics import (
+    KinesisAnalyticsV2ApplicationOperationCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class KinesisAnalyticsV2BaseSensor(AwsBaseSensor[KinesisAnalyticsV2Hook]):
+    """
+    General sensor behaviour for AWS Managed Service for Apache Flink.
+
+    Subclasses must set the following fields:
+    - ``INTERMEDIATE_STATES``
+    - ``FAILURE_STATES``
+    - ``SUCCESS_STATES``
+    - ``FAILURE_MESSAGE``
+    - ``SUCCESS_MESSAGE``
+
+    :param application_name: Application name.
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+
+    """
+
+    aws_hook_class = KinesisAnalyticsV2Hook
+    ui_color = "#66c3ff"
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+    SUCCESS_MESSAGE = ""
+
+    def __init__(
+        self,
+        application_name: str,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.application_name = application_name
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        status = self.hook.conn.describe_application(ApplicationName=self.application_name)[
+            "ApplicationDetail"
+        ]["ApplicationStatus"]
+
+        self.log.info(
+            "Poking for AWS Managed Service for Apache Flink application: %s status: %s",
+            self.application_name,
+            status,
+        )
+
+        if status in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if status in self.SUCCESS_STATES:
+            self.log.info(
+                "%s `%s`.",
+                self.SUCCESS_MESSAGE,
+                self.application_name,
+            )
+            return True
+
+        return False
+
+
+class KinesisAnalyticsV2StartApplicationCompletedSensor(KinesisAnalyticsV2BaseSensor):
+    """
+    Waits for AWS Managed Service for Apache Flink application to start.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:KinesisAnalyticsV2StartApplicationCompletedSensor`
+
+    :param application_name: Application name.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_INTERMEDIATE_STATES
+    FAILURE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_FAILURE_STATES
+    SUCCESS_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_START_SUCCESS_STATES
+
+    FAILURE_MESSAGE = "AWS Managed Service for Apache Flink application start failed."
+    SUCCESS_MESSAGE = "AWS Managed Service for Apache Flink application started successfully"
+
+    template_fields: Sequence[str] = aws_template_fields("application_name")
+
+    def __init__(
+        self,
+        *,
+        application_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(application_name=application_name, **kwargs)
+        self.application_name = application_name
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=KinesisAnalyticsV2ApplicationOperationCompleteTrigger(
+                    application_name=self.application_name,
+                    waiter_name="application_start_complete",
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+
+class KinesisAnalyticsV2StopApplicationCompletedSensor(KinesisAnalyticsV2BaseSensor):
+    """
+    Waits for AWS Managed Service for Apache Flink application to stop.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:KinesisAnalyticsV2StopApplicationCompletedSensor`
+
+    :param application_name: Application name.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_INTERMEDIATE_STATES
+    FAILURE_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_FAILURE_STATES
+    SUCCESS_STATES: tuple[str, ...] = KinesisAnalyticsV2Hook.APPLICATION_STOP_SUCCESS_STATES
+
+    FAILURE_MESSAGE = "AWS Managed Service for Apache Flink application stop failed."
+    SUCCESS_MESSAGE = "AWS Managed Service for Apache Flink application stopped successfully"
+
+    template_fields: Sequence[str] = aws_template_fields("application_name")
+
+    def __init__(
+        self,
+        *,
+        application_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(application_name=application_name, **kwargs)
+        self.application_name = application_name
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=KinesisAnalyticsV2ApplicationOperationCompleteTrigger(
+                    application_name=self.application_name,
+                    waiter_name="application_stop_complete",
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
```
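A short usage sketch of one of the new sensors: wait for a Managed Service for Apache Flink application to reach a started state. The task id and application name are placeholders.

```python
from airflow.providers.amazon.aws.sensors.kinesis_analytics import (
    KinesisAnalyticsV2StartApplicationCompletedSensor,
)

wait_for_flink_app_start = KinesisAnalyticsV2StartApplicationCompletedSensor(
    task_id="wait_for_flink_app_start",
    application_name="example-flink-application",
    deferrable=True,    # resumes via the waiter-based trigger added later in this diff
    poke_interval=120,
    max_retries=75,
)
```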
```diff
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import fnmatch
+import inspect
 import os
 import re
 from datetime import datetime, timedelta
@@ -57,13 +58,13 @@ class S3KeySensor(BaseSensorOperator):
         refers to this bucket
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
-    :param check_fn: Function that receives the list of the S3 objects,
+    :param check_fn: Function that receives the list of the S3 objects with the context values,
         and returns a boolean:
         - ``True``: the criteria is met
        - ``False``: the criteria isn't met
         **Example**: Wait for any S3 object size more than 1 megabyte ::
 
-            def check_fn(files: List) -> bool:
+            def check_fn(files: List, **kwargs) -> bool:
                 return any(f.get('Size', 0) > 1048576 for f in files)
     :param aws_conn_id: a reference to the s3 connection
     :param verify: Whether to verify SSL certificates for S3 connection.
@@ -112,7 +113,7 @@ class S3KeySensor(BaseSensorOperator):
         self.use_regex = use_regex
         self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
 
-    def _check_key(self, key):
+    def _check_key(self, key, context: Context):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
 
@@ -167,15 +168,20 @@ class S3KeySensor(BaseSensorOperator):
             files = [metadata]
 
         if self.check_fn is not None:
+            # For backwards compatibility, check if the function takes a context argument
+            signature = inspect.signature(self.check_fn)
+            if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
+                return self.check_fn(files, **context)
+            # Otherwise, just pass the files
             return self.check_fn(files)
 
         return True
 
     def poke(self, context: Context):
         if isinstance(self.bucket_key, str):
-            return self._check_key(self.bucket_key)
+            return self._check_key(self.bucket_key, context=context)
         else:
-            return all(self._check_key(key) for key in self.bucket_key)
+            return all(self._check_key(key, context=context) for key in self.bucket_key)
 
     def execute(self, context: Context) -> None:
         """Airflow runs this method on the worker and defers using the trigger."""
```
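A minimal sketch of the new `check_fn` behaviour: when the callable accepts `**kwargs`, `S3KeySensor` now passes the task context alongside the object list, as the hunk above shows. Bucket, key pattern, and task id are placeholders.

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def check_fn(files: list, **kwargs) -> bool:
    # `files` holds the matched S3 object metadata dicts; the Airflow context
    # (ti, data_interval_start, ...) arrives through **kwargs.
    return any(f.get("Size", 0) > 1_048_576 for f in files)


wait_for_large_object = S3KeySensor(
    task_id="wait_for_large_object",
    bucket_key="s3://example-bucket/incoming/*.csv",
    wildcard_match=True,
    check_fn=check_fn,
)
```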
```diff
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import sys
 import warnings
 from typing import TYPE_CHECKING
 
@@ -174,6 +175,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
     :param job_id: job_id to check the state
     :param aws_conn_id: Reference to AWS connection id
     :param waiter_delay: polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made. Defaults to an infinite wait.
     """
 
     def __init__(
@@ -183,7 +185,7 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
         aws_conn_id: str | None = "aws_default",
         poll_interval: int | None = None,  # deprecated
         waiter_delay: int = 30,
-        waiter_max_attempts: int =
+        waiter_max_attempts: int = sys.maxsize,
     ):
         if poll_interval is not None:
             warnings.warn(
```
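A hedged sketch of the practical effect: the trigger now waits effectively forever by default (`sys.maxsize` attempts), so cap the attempts explicitly if a bounded wait is wanted. The identifiers are placeholders, and `virtual_cluster_id` is assumed from the existing constructor rather than shown in the hunk above.

```python
from airflow.providers.amazon.aws.triggers.emr import EmrContainerTrigger

trigger = EmrContainerTrigger(
    virtual_cluster_id="vc-0123456789abcdef0",  # assumed constructor argument, placeholder value
    job_id="0000000123456789abc",               # placeholder job id
    waiter_delay=30,
    waiter_max_attempts=240,  # override the new sys.maxsize default for a ~2 hour ceiling
)
```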
```diff
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/kinesis_analytics.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class KinesisAnalyticsV2ApplicationOperationCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Managed Service for Apache Flink application Start or Stop is complete.
+
+    :param application_name: Application name.
+    :param waiter_name: The name of the waiter for stop or start application.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        application_name: str,
+        waiter_name: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"application_name": application_name, "waiter_name": waiter_name},
+            waiter_name=waiter_name,
+            waiter_args={"ApplicationName": application_name},
+            failure_message=f"AWS Managed Service for Apache Flink Application {application_name} failed.",
+            status_message=f"Status of AWS Managed Service for Apache Flink Application {application_name} is",
+            status_queries=["ApplicationDetail.ApplicationStatus"],
+            return_key="application_name",
+            return_value=application_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return KinesisAnalyticsV2Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
```
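A minimal sketch of constructing the new trigger directly, using one of the two waiter names referenced by the sensors shown earlier; the application name is a placeholder.

```python
from airflow.providers.amazon.aws.triggers.kinesis_analytics import (
    KinesisAnalyticsV2ApplicationOperationCompleteTrigger,
)

trigger = KinesisAnalyticsV2ApplicationOperationCompleteTrigger(
    application_name="example-flink-application",
    waiter_name="application_start_complete",  # or "application_stop_complete"
    waiter_delay=120,
    waiter_max_attempts=75,
)
```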
```diff
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -25,8 +25,9 @@ from functools import cached_property
 from typing import Any, AsyncIterator
 
 from botocore.exceptions import WaiterError
+from deprecated import deprecated
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -199,6 +200,13 @@ class SageMakerPipelineTrigger(BaseTrigger):
                 raise AirflowException("Waiter error: max attempts reached")
 
 
+@deprecated(
+    reason=(
+        "`airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` "
+        "has been deprecated and will be removed in future. Please use ``SageMakerTrigger`` instead."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class SageMakerTrainingPrintLogTrigger(BaseTrigger):
     """
     SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.
```