apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +69 -97
- airflow/providers/amazon/aws/auth_manager/router/login.py +9 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +1 -2
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +160 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +31 -42
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +45 -4
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/METADATA +38 -31
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/RECORD +55 -48
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook sensor."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
+    SageMakerNotebookHook,
+)
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class SageMakerNotebookSensor(BaseSensorOperator):
+    """
+    Waits for a Sagemaker Workflows Notebook execution to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'COMPLETED'
+
+    :param execution_id: The Sagemaker Workflows Notebook running execution identifier
+    :param execution_name: The Sagemaker Workflows Notebook unique execution name
+    """
+
+    def __init__(self, *, execution_id: str, execution_name: str, **kwargs):
+        super().__init__(**kwargs)
+        self.execution_id = execution_id
+        self.execution_name = execution_name
+        self.success_state = ["COMPLETED"]
+        self.in_progress_states = ["PENDING", "RUNNING"]
+
+    def hook(self):
+        return SageMakerNotebookHook(execution_name=self.execution_name)
+
+    # override from base sensor
+    def poke(self, context=None):
+        status = self.hook().get_execution_status(execution_id=self.execution_id)
+
+        if status in self.success_state:
+            log_info_message = f"Exiting Execution {self.execution_id} State: {status}"
+            self.log.info(log_info_message)
+            return True
+        elif status in self.in_progress_states:
+            return False
+        else:
+            error_message = f"Exiting Execution {self.execution_id} State: {status}"
+            self.log.info(error_message)
+            raise AirflowException(error_message)
+
+    def execute(self, context: Context):
+        # This will invoke poke method in the base sensor
+        log_info_message = f"Polling Sagemaker Workflows Artifact execution: {self.execution_name} and execution id: {self.execution_id}"
+        self.log.info(log_info_message)
+        super().execute(context=context)
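
The new sensor above polls a notebook execution until it leaves the PENDING/RUNNING states. A minimal usage sketch, assuming the SageMaker Unified Studio environment required by `SageMakerNotebookHook` is already configured; the DAG id, execution id, and execution name below are placeholders, not values from this package:

```python
import pendulum

from airflow import DAG
from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio import SageMakerNotebookSensor

with DAG(
    dag_id="example_wait_for_notebook",  # placeholder DAG id
    start_date=pendulum.datetime(2025, 1, 1, tz="UTC"),
    schedule=None,
):
    # Standard BaseSensorOperator kwargs (poke_interval, timeout, mode, ...) still apply.
    wait_for_notebook = SageMakerNotebookSensor(
        task_id="wait_for_notebook",
        execution_id="notebook_job_1234",  # identifier of the running notebook execution
        execution_name="notebook_task",    # unique execution name used to build the hook
        poke_interval=60,
    )
```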

airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -102,7 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
         table: str | None = None,
         select_query: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None = "aws_default",
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         unload_options: list | None = None,
         autocommit: bool = False,
@@ -118,7 +119,6 @@ class RedshiftToS3Operator(BaseOperator):
         self.schema = schema
         self.table = table
         self.redshift_conn_id = redshift_conn_id
-        self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.unload_options = unload_options or []
         self.autocommit = autocommit
@@ -127,6 +127,16 @@ class RedshiftToS3Operator(BaseOperator):
         self.table_as_file_name = table_as_file_name
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
         self.select_query = select_query
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id

     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
@@ -176,11 +186,16 @@ class RedshiftToS3Operator(BaseOperator):
                     raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
         else:
            redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
-        conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)


airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -93,7 +94,7 @@ class S3ToRedshiftOperator(BaseOperator):
         s3_key: str,
         schema: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None = "aws_default",
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         column_list: list[str] | None = None,
         copy_options: list | None = None,
@@ -117,6 +118,16 @@ class S3ToRedshiftOperator(BaseOperator):
         self.method = method
         self.upsert_keys = upsert_keys
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id

         if self.redshift_data_api_kwargs:
             for arg in ["sql", "parameters"]:
@@ -149,14 +160,19 @@ class S3ToRedshiftOperator(BaseOperator):
         else:
             redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

-        conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         region_info = ""
         if conn and conn.extra_dejson.get("region", False):
             region_info = f"region '{conn.extra_dejson['region']}'"
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)

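
Both transfer operators above now use the same sentinel pattern to distinguish "argument not passed" from "explicitly passed None". A minimal sketch of the pattern in isolation, with an illustrative class name that is not part of the provider:

```python
from __future__ import annotations

from airflow.utils.types import NOTSET, ArgNotSet


class ExampleOperator:
    def __init__(self, aws_conn_id: str | None | ArgNotSet = NOTSET):
        if isinstance(aws_conn_id, ArgNotSet):
            # Caller did not pass the argument: fall back to the default connection,
            # but remember it was injected so a missing "aws_default" is not treated as an error.
            self.conn_set = False
            self._aws_conn_id: str | None = "aws_default"
        else:
            # Caller passed a value; None means "build the hook without a connection".
            self.conn_set = True
            self._aws_conn_id = aws_conn_id
```

Later in `execute()`, connection extras are only looked up when `self.conn_set and self._aws_conn_id` is true, which is why the injected fallback does not raise if the default connection is absent.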

airflow/providers/amazon/aws/triggers/README.md
@@ -65,10 +65,10 @@ To call the asynchronous `wait` function, first create a hook for the particular
 self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
 ```

-With this hook, we can use the async_conn property to get access to the aiobotocore client:
+With this hook, we can use the asynchronous get_async_conn method to get access to the aiobotocore client:

 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     await client.get_waiter("cluster_available").wait(
         ClusterIdentifier=self.cluster_identifier,
         WaiterConfig={
@@ -81,7 +81,7 @@ async with self.redshift_hook.async_conn as client:
 In this case, we are using the built-in cluster_available waiter. If we wanted to use a custom waiter, we would change the code slightly to use the `get_waiter` function from the hook, rather than the aiobotocore client:

 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, client=client)
     await waiter.wait(
         ClusterIdentifier=self.cluster_identifier,
@@ -131,7 +131,7 @@ For more information about writing custom waiter, see the [README.md](https://gi
 In some cases, a built-in or custom waiter may not be able to solve the problem. In such cases, the asynchronous method used to poll the boto3 API would need to be defined in the hook of the service being used. This method is essentially the same as the synchronous version of the method, except that it will use the aiobotocore client, and will be awaited. For the Redshift example, the async `describe_clusters` method would look as follows:

 ```python
-async with self.async_conn as client:
+async with await self.get_async_conn() as client:
     response = client.describe_clusters(ClusterIdentifier=self.cluster_identifier)
 ```

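
The README changes above move the examples from the old `async_conn` property to the awaitable `get_async_conn()`. A sketch tying the fragments together into a complete toy trigger, assuming the hook APIs described above (`get_async_conn`, `get_waiter(..., deferrable=True, client=...)`) and the `cluster_paused` custom waiter; the class itself is illustrative and not part of the provider:

```python
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleClusterPausedTrigger(BaseTrigger):
    """Toy trigger combining the README snippets; for illustration only."""

    def __init__(self, cluster_identifier: str, aws_conn_id: str | None = "aws_default"):
        super().__init__()
        self.cluster_identifier = cluster_identifier
        self.aws_conn_id = aws_conn_id

    def serialize(self):
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {"cluster_identifier": self.cluster_identifier, "aws_conn_id": self.aws_conn_id},
        )

    async def run(self):
        hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        # The awaitable get_async_conn() yields an aiobotocore client usable as a context manager.
        async with await hook.get_async_conn() as client:
            waiter = hook.get_waiter("cluster_paused", deferrable=True, client=client)
            await waiter.wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={"Delay": 30, "MaxAttempts": 40},
            )
        yield TriggerEvent({"status": "success", "cluster": self.cluster_identifier})
```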

airflow/providers/amazon/aws/triggers/base.py
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):

     :param waiter_delay: The amount of time in seconds to wait between attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides

         self.return_key = return_key
         self.return_value = return_value
@@ -139,8 +143,13 @@ class AwsBaseWaiterTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
-        async with hook.async_conn as client:
-            waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
+        async with await hook.get_async_conn() as client:
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,

airflow/providers/amazon/aws/triggers/ecs.py
@@ -167,8 +167,12 @@ class TaskDoneTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         async with (
-            EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,
-            AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client,
+            await EcsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as ecs_client,
+            await AwsLogsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as logs_client,
         ):
             waiter = ecs_client.get_waiter("tasks_stopped")
             logs_token = None

airflow/providers/amazon/aws/triggers/eks.py
@@ -70,7 +70,7 @@ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter(self.waiter_name)
             try:
                 await async_wait(
@@ -140,7 +140,7 @@ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter("cluster_deleted")
             if self.force_delete_compute:
                 await self.delete_any_nodegroups(client=client)

airflow/providers/amazon/aws/triggers/glue.py
@@ -157,7 +157,7 @@ class GlueCatalogPartitionTrigger(BaseTrigger):
         return bool(partitions)

     async def run(self) -> AsyncIterator[TriggerEvent]:
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             while True:
                 result = await self.poke(client=client)
                 if result:

airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an MWAA Dag Run is complete.
+
+    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
+    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_dag_run_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of DAG run",
+            status_queries=["RestApiResponse.state"],
+            return_key="dag_run_id",
+            return_value=external_dag_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+def _build_waiter_acceptors(
+    success_states: set[str], failure_states: set[str], in_progress_states: set[str]
+) -> list:
+    acceptors = []
+    for state_set, state_waiter_category in (
+        (success_states, "success"),
+        (failure_states, "failure"),
+        (in_progress_states, "retry"),
+    ):
+        for dag_run_state in state_set:
+            acceptors.append(
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": dag_run_state,
+                    "state": state_waiter_category,
+                }
+            )

+    return acceptors

airflow/providers/amazon/aws/triggers/s3.py
@@ -53,6 +53,9 @@ class S3KeyTrigger(BaseTrigger):
         poke_interval: float = 5.0,
         should_check_fn: bool = False,
         use_regex: bool = False,
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -64,6 +67,9 @@ class S3KeyTrigger(BaseTrigger):
         self.poke_interval = poke_interval
         self.should_check_fn = should_check_fn
         self.use_regex = use_regex
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -78,17 +84,25 @@ class S3KeyTrigger(BaseTrigger):
                 "poke_interval": self.poke_interval,
                 "should_check_fn": self.should_check_fn,
                 "use_regex": self.use_regex,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
             },
         )

     @cached_property
     def hook(self) -> S3Hook:
-        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
+        return S3Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )

     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3HookAsync."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     if await self.hook.check_key_async(
                         client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
@@ -143,7 +157,9 @@ class S3KeysUnchangedTrigger(BaseTrigger):
         allow_delete: bool = True,
         aws_conn_id: str | None = "aws_default",
         last_activity_time: datetime | None = None,
+        region_name: str | None = None,
         verify: bool | str | None = None,
+        botocore_config: dict | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -160,8 +176,10 @@ class S3KeysUnchangedTrigger(BaseTrigger):
         self.allow_delete = allow_delete
         self.aws_conn_id = aws_conn_id
         self.last_activity_time = last_activity_time
-        self.verify = verify
         self.polling_period_seconds = 0
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config
         self.hook_params = hook_params

     def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -179,19 +197,26 @@ class S3KeysUnchangedTrigger(BaseTrigger):
                 "aws_conn_id": self.aws_conn_id,
                 "last_activity_time": self.last_activity_time,
                 "hook_params": self.hook_params,
-                "verify": self.verify,
                 "polling_period_seconds": self.polling_period_seconds,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
             },
         )

     @cached_property
     def hook(self) -> S3Hook:
-        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+        return S3Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )

     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3Hook."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     result = await self.hook.is_keys_unchanged_async(
                         client=client,
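
With the changes above, `S3KeyTrigger` and `S3KeysUnchangedTrigger` carry `region_name`, `verify`, and `botocore_config` through serialization, so the hook rebuilt on the triggerer matches the one used at submission time. An illustrative construction with placeholder bucket and key values; passing `bucket_key` as a list mirrors how the S3 key sensor hands keys over and is an assumption here, not taken from this diff:

```python
from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger

trigger = S3KeyTrigger(
    bucket_name="example-bucket",
    bucket_key=["incoming/data.csv"],
    aws_conn_id="aws_default",
    region_name="eu-west-1",                            # now serialized and forwarded to S3Hook
    verify=False,                                        # TLS verification toggle, also forwarded
    botocore_config={"retries": {"max_attempts": 3}},    # becomes S3Hook(config=...)
    poke_interval=30,
)

# serialize() now round-trips the three new fields as well.
classpath, kwargs = trigger.serialize()
```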

airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -108,7 +108,7 @@ class SageMakerTrigger(BaseTrigger):

     async def run(self):
         self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             waiter = self.hook.get_waiter(
                 self._get_job_type_waiter(self.job_type), deferrable=True, client=client
             )
@@ -166,7 +166,7 @@ class SageMakerPipelineTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-        async with hook.async_conn as conn:
+        async with await hook.get_async_conn() as conn:
             waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn)
             for _ in range(self.waiter_max_attempts):
                 try:

airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook job trigger."""
+
+from __future__ import annotations
+
+from airflow.triggers.base import BaseTrigger
+
+
+class SageMakerNotebookJobTrigger(BaseTrigger):
+    """
+    Watches for a notebook job, triggers when it finishes.
+
+    Examples:
+    .. code-block:: python
+
+        from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger
+
+        notebook_trigger = SageMakerNotebookJobTrigger(
+            execution_id="notebook_job_1234",
+            execution_name="notebook_task",
+            waiter_delay=10,
+            waiter_max_attempts=1440,
+        )
+
+    :param execution_id: A unique, meaningful id for the task.
+    :param execution_name: A unique, meaningful name for the task.
+    :param waiter_delay: Interval in seconds to check the notebook execution status.
+    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+    """
+
+    def __init__(self, execution_id, execution_name, waiter_delay, waiter_max_attempts, **kwargs):
+        super().__init__(**kwargs)
+        self.execution_id = execution_id
+        self.execution_name = execution_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def serialize(self):
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "execution_id": self.execution_id,
+                "execution_name": self.execution_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self):
+        pass

airflow/providers/amazon/aws/triggers/sqs.py
@@ -23,14 +23,22 @@ from typing import TYPE_CHECKING, Any
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.utils.sqs import process_response
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+else:
+    from airflow.triggers.base import (  # type: ignore
+        BaseTrigger as BaseEventTrigger,
+        TriggerEvent,
+    )

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
     from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType


-class SqsSensorTrigger(BaseTrigger):
+class SqsSensorTrigger(BaseEventTrigger):
     """
     Asynchronously get messages from an Amazon SQS queue and then delete the messages from the queue.

@@ -176,7 +184,7 @@ class SqsSensorTrigger(BaseTrigger):
         while True:
             # This loop will run indefinitely until the timeout, which is set in the self.defer
             # method, is reached.
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 result = await self.poke(client=client)
                 if result:
                     yield TriggerEvent({"status": "success", "message_batch": result})

airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py}
@@ -14,3 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+"""This module contains utils for the Amazon SageMaker Unified Studio Notebook plugin."""
+
+from __future__ import annotations
+
+import os
+
+workflows_env_key = "WORKFLOWS_ENV"
+
+
+def is_local_runner():
+    return os.getenv(workflows_env_key, "") == "Local"
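
The relocated utils module exposes a single helper keyed off the `WORKFLOWS_ENV` environment variable. A tiny usage sketch:

```python
import os

from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner

# The helper returns True only when WORKFLOWS_ENV is exactly "Local".
os.environ["WORKFLOWS_ENV"] = "Local"
assert is_local_runner()

os.environ["WORKFLOWS_ENV"] = "Remote"
assert not is_local_runner()
```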