apache-airflow-providers-amazon 8.16.0rc1__py3-none-any.whl → 8.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -0
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +34 -19
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +44 -1
- airflow/providers/amazon/aws/auth_manager/cli/__init__.py +16 -0
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +178 -0
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +62 -0
- airflow/providers/amazon/aws/auth_manager/cli/schema.json +171 -0
- airflow/providers/amazon/aws/auth_manager/constants.py +1 -0
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +77 -23
- airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +17 -0
- airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
- airflow/providers/amazon/aws/executors/utils/__init__.py +16 -0
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +60 -0
- airflow/providers/amazon/aws/hooks/athena_sql.py +168 -0
- airflow/providers/amazon/aws/hooks/base_aws.py +14 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +33 -18
- airflow/providers/amazon/aws/hooks/redshift_data.py +66 -17
- airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -1
- airflow/providers/amazon/aws/hooks/s3.py +18 -4
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
- airflow/providers/amazon/aws/operators/batch.py +33 -15
- airflow/providers/amazon/aws/operators/cloud_formation.py +37 -26
- airflow/providers/amazon/aws/operators/datasync.py +19 -18
- airflow/providers/amazon/aws/operators/dms.py +57 -69
- airflow/providers/amazon/aws/operators/ec2.py +19 -5
- airflow/providers/amazon/aws/operators/emr.py +30 -10
- airflow/providers/amazon/aws/operators/eventbridge.py +57 -80
- airflow/providers/amazon/aws/operators/quicksight.py +17 -24
- airflow/providers/amazon/aws/operators/redshift_data.py +68 -19
- airflow/providers/amazon/aws/operators/s3.py +1 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +42 -12
- airflow/providers/amazon/aws/sensors/cloud_formation.py +30 -25
- airflow/providers/amazon/aws/sensors/dms.py +31 -24
- airflow/providers/amazon/aws/sensors/dynamodb.py +15 -15
- airflow/providers/amazon/aws/sensors/quicksight.py +34 -24
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +41 -3
- airflow/providers/amazon/aws/sensors/s3.py +13 -8
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +113 -0
- airflow/providers/amazon/aws/triggers/s3.py +9 -4
- airflow/providers/amazon/get_provider_info.py +55 -16
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/METADATA +17 -15
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/RECORD +46 -38
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/entry_points.txt +0 -0
`airflow/providers/amazon/aws/executors/ecs/ecs_executor.py`

```diff
@@ -28,7 +28,7 @@ from collections import defaultdict, deque
 from copy import deepcopy
 from typing import TYPE_CHECKING
 
-from botocore.exceptions import ClientError
+from botocore.exceptions import ClientError, NoCredentialsError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -42,6 +42,9 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
     EcsQueuedTask,
     EcsTaskCollection,
 )
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.utils import timezone
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
@@ -51,6 +54,12 @@ if TYPE_CHECKING:
     ExecutorConfigType,
 )
 
+INVALID_CREDENTIALS_EXCEPTIONS = [
+    "ExpiredTokenException",
+    "InvalidClientTokenId",
+    "UnrecognizedClientException",
+]
+
 
 class AwsEcsExecutor(BaseExecutor):
     """
@@ -91,30 +100,15 @@ class AwsEcsExecutor(BaseExecutor):
 
         self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
         self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
-        aws_conn_id = conf.get(
-            CONFIG_GROUP_NAME,
-            AllEcsConfigKeys.AWS_CONN_ID,
-            fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
-        )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
-        from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+        self.attempts_since_last_successful_connection = 0
+
+        self.load_ecs_connection(check_connection=False)
+        self.IS_BOTO_CONNECTION_HEALTHY = False
 
-        self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.run_task_kwargs = self._load_run_kwargs()
 
     def start(self):
-        """
-        Make a test API call to check the health of the ECS Executor.
-
-        Deliberately use an invalid task ID, some potential outcomes in order:
-        1. "AccessDeniedException" is raised if there are insufficient permissions.
-        2. "ClusterNotFoundException" is raised if permissions exist but the cluster does not.
-        3. The API responds with a failure message if the cluster is found and there
-           are permissions, but the cluster itself has issues.
-        4. "InvalidParameterException" is raised if the permissions and cluster exist but the task does not.
-
-        The last one is considered a success state for the purposes of this check.
-        """
+        """This is called by the scheduler when the Executor is being run for the first time."""
         check_health = conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
@@ -123,7 +117,25 @@ class AwsEcsExecutor(BaseExecutor):
             return
 
         self.log.info("Starting ECS Executor and determining health...")
+        try:
+            self.check_health()
+        except AirflowException:
+            self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
+            raise
+
+    def check_health(self):
+        """
+        Make a test API call to check the health of the ECS Executor.
 
+        Deliberately use an invalid task ID, some potential outcomes in order:
+        1. `AccessDeniedException` is raised if there are insufficient permissions.
+        2. `ClusterNotFoundException` is raised if permissions exist but the cluster does not.
+        3. The API responds with a failure message if the cluster is found and there
+           are permissions, but the cluster itself has issues.
+        4. `InvalidParameterException` is raised if the permissions and cluster exist but the task does not.
+
+        The last one is considered a success state for the purposes of this check.
+        """
         success_status = "succeeded."
         status = success_status
 
@@ -151,18 +163,50 @@ class AwsEcsExecutor(BaseExecutor):
         finally:
             msg_prefix = "ECS Executor health check has %s"
             if status == success_status:
+                self.IS_BOTO_CONNECTION_HEALTHY = True
                 self.log.info(msg_prefix, status)
             else:
                 msg_error_suffix = (
-                    "The ECS executor will not be able to run Airflow tasks until the issue is addressed. "
-                    "Stopping the Airflow Scheduler from starting until the issue is resolved."
+                    "The ECS executor will not be able to run Airflow tasks until the issue is addressed."
                 )
                 raise AirflowException(msg_prefix % status + msg_error_suffix)
 
+    def load_ecs_connection(self, check_connection: bool = True):
+        self.log.info("Loading Connection information")
+        aws_conn_id = conf.get(
+            CONFIG_GROUP_NAME,
+            AllEcsConfigKeys.AWS_CONN_ID,
+            fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
+        )
+        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
+        self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+        self.attempts_since_last_successful_connection += 1
+        self.last_connection_reload = timezone.utcnow()
+
+        if check_connection:
+            self.check_health()
+            self.attempts_since_last_successful_connection = 0
+
     def sync(self):
+        if not self.IS_BOTO_CONNECTION_HEALTHY:
+            exponential_backoff_retry(
+                self.last_connection_reload,
+                self.attempts_since_last_successful_connection,
+                self.load_ecs_connection,
+            )
+            if not self.IS_BOTO_CONNECTION_HEALTHY:
+                return
         try:
             self.sync_running_tasks()
             self.attempt_task_runs()
+        except (ClientError, NoCredentialsError) as error:
+            error_code = error.response["Error"]["Code"]
+            if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                self.IS_BOTO_CONNECTION_HEALTHY = False
+                self.log.warning(
+                    f"AWS credentials are either missing or expired: {error}.\nRetrying connection"
+                )
+
         except Exception:
             # We catch any and all exceptions because otherwise they would bubble
             # up and kill the scheduler process
@@ -176,6 +220,7 @@ class AwsEcsExecutor(BaseExecutor):
             return
 
         describe_tasks_response = self.__describe_tasks(all_task_arns)
+
         self.log.debug("Active Workers: %s", describe_tasks_response)
 
         if describe_tasks_response["failures"]:
@@ -288,6 +333,15 @@ class AwsEcsExecutor(BaseExecutor):
             _failure_reasons = []
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
+            except NoCredentialsError:
+                self.pending_tasks.appendleft(ecs_task)
+                raise
+            except ClientError as e:
+                error_code = e.response["Error"]["Code"]
+                if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                    self.pending_tasks.appendleft(ecs_task)
+                    raise
+                _failure_reasons.append(str(e))
             except Exception as e:
                 # Failed to even get a response back from the Boto3 API or something else went
                 # wrong. For any possible failure we want to add the exception reasons to the
```
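The net effect of these `ecs_executor.py` changes is that a credentials problem no longer takes down the scheduler loop: `sync()` classifies the failure, marks the boto connection unhealthy, and retries the connection reload with backoff. A self-contained sketch of that classification, reusing the exception codes from the hunk above (the boto3 call and region are illustrative, not part of the provider):

```python
import boto3
from botocore.exceptions import ClientError, NoCredentialsError

# Error codes the executor now treats as a "credentials problem" (from the diff above).
INVALID_CREDENTIALS_EXCEPTIONS = [
    "ExpiredTokenException",
    "InvalidClientTokenId",
    "UnrecognizedClientException",
]


def is_credentials_problem(error: Exception) -> bool:
    """True when the failure should pause the executor rather than fail the task."""
    if isinstance(error, NoCredentialsError):
        return True
    return isinstance(error, ClientError) and error.response["Error"]["Code"] in INVALID_CREDENTIALS_EXCEPTIONS


try:
    boto3.client("ecs", region_name="us-east-1").list_clusters()
except (ClientError, NoCredentialsError) as err:
    if is_credentials_problem(err):
        print("Credentials missing or expired; mark the connection unhealthy and retry later.")
    else:
        raise
```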
`airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py`

```diff
@@ -40,6 +40,7 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
     camelize_dict_keys,
     parse_assign_public_ip,
 )
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils.helpers import prune_dict
 
 
@@ -60,6 +61,22 @@ def build_task_kwargs() -> dict:
     task_kwargs = _fetch_config_values()
     task_kwargs.update(_fetch_templated_kwargs())
 
+    has_launch_type: bool = "launch_type" in task_kwargs
+    has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs
+
+    if has_capacity_provider and has_launch_type:
+        raise ValueError(
+            "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both."
+        )
+    elif "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type):
+        # Default API behavior if neither is provided is to fall back on the default capacity
+        # provider if it exists. Since it is not a required value, check if there is one
+        # before using it, and if there is not then use the FARGATE launch_type as
+        # the final fallback.
+        cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0]
+        if not cluster.get("defaultCapacityProviderStrategy"):
+            task_kwargs["launch_type"] = "FARGATE"
+
     # There can only be 1 count of these containers
     task_kwargs["count"] = 1  # type: ignore
     # There could be a generic approach to the below, but likely more convoluted then just manually ensuring
```
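In practice this means `launch_type` and `capacity_provider_strategy` are now mutually exclusive, and when neither is configured the cluster's default capacity provider wins, with `FARGATE` as the last resort. A minimal sketch of that resolution logic using a plain dict and a stubbed cluster lookup (the helper name and values are made up for illustration):

```python
# Hypothetical helper mirroring the new build_task_kwargs() logic, using a boolean stub
# in place of the real Airflow config and the EcsHook describe_clusters() call.
def resolve_launch_type(task_kwargs: dict, cluster_has_default_capacity_provider: bool) -> dict:
    has_launch_type = "launch_type" in task_kwargs
    has_capacity_provider = "capacity_provider_strategy" in task_kwargs

    if has_capacity_provider and has_launch_type:
        raise ValueError("capacity_provider_strategy and launch_type are mutually exclusive.")
    if "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type):
        # Prefer the cluster's default capacity provider; otherwise fall back to FARGATE.
        if not cluster_has_default_capacity_provider:
            task_kwargs["launch_type"] = "FARGATE"
    return task_kwargs


print(resolve_launch_type({"cluster": "demo"}, cluster_has_default_capacity_provider=False))
# -> {'cluster': 'demo', 'launch_type': 'FARGATE'}
```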
`airflow/providers/amazon/aws/executors/ecs/utils.py`

```diff
@@ -44,7 +44,6 @@ CONFIG_DEFAULTS = {
     "conn_id": "aws_default",
     "max_run_task_attempts": "3",
     "assign_public_ip": "False",
-    "launch_type": "FARGATE",
     "platform_version": "LATEST",
     "check_health_on_startup": "True",
 }
@@ -81,6 +80,7 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys):
     """Keys loaded into the config which are valid ECS run_task kwargs."""
 
     ASSIGN_PUBLIC_IP = "assign_public_ip"
+    CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
     CLUSTER = "cluster"
     LAUNCH_TYPE = "launch_type"
     PLATFORM_VERSION = "platform_version"
```
`airflow/providers/amazon/aws/executors/utils/__init__.py` (new file)

```diff
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
```
`airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py` (new file)

```diff
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from datetime import datetime, timedelta
+from typing import Callable
+
+from airflow.utils import timezone
+
+log = logging.getLogger(__name__)
+
+
+def exponential_backoff_retry(
+    last_attempt_time: datetime,
+    attempts_since_last_successful: int,
+    callable_function: Callable,
+    max_delay: int = 60 * 2,
+    max_attempts: int = -1,
+    exponent_base: int = 4,
+) -> None:
+    """
+    Retries a callable function with exponential backoff between attempts if it raises an exception.
+
+    :param last_attempt_time: Timestamp of last attempt call.
+    :param attempts_since_last_successful: Number of attempts since last success.
+    :param callable_function: Callable function that will be called if enough time has passed.
+    :param max_delay: Maximum delay in seconds between retries. Default 120.
+    :param max_attempts: Maximum number of attempts before giving up. Default -1 (no limit).
+    :param exponent_base: Exponent base to calculate delay. Default 4.
+    """
+    if max_attempts != -1 and attempts_since_last_successful >= max_attempts:
+        log.error("Max attempts reached. Exiting.")
+        return
+
+    delay = min((exponent_base**attempts_since_last_successful), max_delay)
+    next_retry_time = last_attempt_time + timedelta(seconds=delay)
+    current_time = timezone.utcnow()
+
+    if current_time >= next_retry_time:
+        try:
+            callable_function()
+        except Exception:
+            log.exception("Error calling %r", callable_function.__name__)
+            next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
+            log.info("Waiting for %s seconds before retrying.", next_delay)
```
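Note that the helper neither sleeps nor loops itself; it only invokes the callable when enough time has passed since `last_attempt_time`, so the caller keeps the attempt counter and timestamp and calls it on every heartbeat, as the executor's `sync()` does. A rough usage sketch with a made-up flaky callable standing in for the real connection reload:

```python
from datetime import timedelta

from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
    exponential_backoff_retry,
)
from airflow.utils import timezone

attempts = 0
# Pretend the previous attempt happened long enough ago that a retry is due.
last_attempt_time = timezone.utcnow() - timedelta(minutes=5)


def reconnect() -> None:
    """Made-up stand-in for the executor's load_ecs_connection()."""
    raise RuntimeError("still cannot reach AWS")


# Called from a periodic loop (e.g. each executor sync); the helper only invokes
# reconnect() if the backoff window has elapsed, and it logs rather than re-raises.
exponential_backoff_retry(last_attempt_time, attempts, reconnect, max_attempts=5)
```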
`airflow/providers/amazon/aws/hooks/athena_sql.py` (new file)

```diff
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+import pyathena
+from sqlalchemy.engine.url import URL
+
+from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+if TYPE_CHECKING:
+    from pyathena.connection import Connection as AthenaConnection
+
+
+class AthenaSQLHook(AwsBaseHook, DbApiHook):
+    """Interact with Amazon Athena.
+
+    Provide wrapper around PyAthena library.
+
+    :param athena_conn_id: :ref:`Amazon Athena Connection <howto/connection:athena>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+        are passed down to the underlying AwsBaseHook.
+
+    You can specify ``driver`` in ``extra`` of your connection in order to use
+    a different driver than the default ``rest``.
+
+    Also, aws_domain could be specified in ``extra`` of your connection.
+
+    PyAthena and AWS Authentication parameters could be passed in extra field of ``athena_conn_id`` connection.
+
+    Passing authentication parameters in ``athena_conn_id`` will override those in ``aws_conn_id``.
+
+    .. seealso::
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+
+    .. note::
+        get_uri() depends on SQLAlchemy and PyAthena.
+    """
+
+    conn_name_attr = "athena_conn_id"
+    default_conn_name = "athena_default"
+    conn_type = "athena"
+    hook_name = "Amazon Athena"
+    supports_autocommit = True
+
+    def __init__(self, athena_conn_id: str = default_conn_name, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.athena_conn_id = athena_conn_id
+
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
+        """Return custom UI field behaviour for AWS Athena Connection."""
+        return {
+            "hidden_fields": ["host", "port"],
+            "relabeling": {
+                "login": "AWS Access Key ID",
+                "password": "AWS Secret Access Key",
+            },
+            "placeholders": {
+                "login": "AKIAIOSFODNN7EXAMPLE",
+                "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+                "extra": json.dumps(
+                    {
+                        "aws_domain": "amazonaws.com",
+                        "driver": "rest",
+                        "s3_staging_dir": "s3://bucket_name/staging/",
+                        "work_group": "primary",
+                        "region_name": "us-east-1",
+                        "session_kwargs": {"profile_name": "default"},
+                        "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
+                        "role_arn": "arn:aws:iam::123456789098:role/role-name",
+                        "assume_role_method": "assume_role",
+                        "assume_role_kwargs": {"RoleSessionName": "airflow"},
+                        "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
+                        "endpoint_url": "http://localhost:4566",
+                    },
+                    indent=2,
+                ),
+            },
+        }
+
+    @cached_property
+    def conn_config(self) -> AwsConnectionWrapper:
+        """Get the Airflow Connection object and wrap it in helper (cached)."""
+        athena_conn = self.get_connection(self.athena_conn_id)
+        if self.aws_conn_id:
+            try:
+                connection = self.get_connection(self.aws_conn_id)
+                connection.login = athena_conn.login
+                connection.password = athena_conn.password
+                connection.schema = athena_conn.schema
+                connection.set_extra(json.dumps({**athena_conn.extra_dejson, **connection.extra_dejson}))
+            except AirflowNotFoundException:
+                connection = athena_conn
+                connection.conn_type = "aws"
+                self.log.warning(
+                    "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
+                )
+
+        return AwsConnectionWrapper(
+            conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
+        )
+
+    @property
+    def conn(self) -> AwsConnectionWrapper:
+        """Get Aws Connection Wrapper object."""
+        return self.conn_config
+
+    def _get_conn_params(self) -> dict[str, str | None]:
+        """Retrieve connection parameters."""
+        if not self.conn.region_name:
+            raise AirflowException("region_name must be specified in the connection's extra")
+
+        return dict(
+            driver=self.conn.extra_dejson.get("driver", "rest"),
+            schema_name=self.conn.schema,
+            region_name=self.conn.region_name,
+            aws_domain=self.conn.extra_dejson.get("aws_domain", "amazonaws.com"),
+        )
+
+    def get_uri(self) -> str:
+        """Overridden to use the Athena dialect as driver name."""
+        conn_params = self._get_conn_params()
+        creds = self.get_credentials(region_name=conn_params["region_name"])
+
+        return URL.create(
+            f'awsathena+{conn_params["driver"]}',
+            username=creds.access_key,
+            password=creds.secret_key,
+            host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
+            port=443,
+            database=conn_params["schema_name"],
+            query={"aws_session_token": creds.token, **self.conn.extra_dejson},
+        )
+
+    def get_conn(self) -> AthenaConnection:
+        """Get a ``pyathena.Connection`` object."""
+        conn_params = self._get_conn_params()
+
+        conn_kwargs: dict = {
+            "schema_name": conn_params["schema_name"],
+            "region_name": conn_params["region_name"],
+            "session": self.get_session(region_name=conn_params["region_name"]),
+            **self.conn.extra_dejson,
+        }
+
+        return pyathena.connect(**conn_kwargs)
```
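A hedged usage sketch for the new `AthenaSQLHook`; the `athena_default` connection contents, database, and query below are assumptions for illustration, and the SQL helper methods come from the inherited `DbApiHook`:

```python
from airflow.providers.amazon.aws.hooks.athena_sql import AthenaSQLHook

# Assumed Airflow connection "athena_default" with extra along the lines of
#   {"region_name": "us-east-1", "s3_staging_dir": "s3://my-bucket/athena-results/"}
hook = AthenaSQLHook(athena_conn_id="athena_default")

# DbApiHook-style helpers work on top of the PyAthena connection.
print(hook.get_records("SELECT 1"))

# Or work with the raw pyathena connection directly.
cursor = hook.get_conn().cursor()
cursor.execute("SELECT 1")
print(cursor.fetchall())
```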
`airflow/providers/amazon/aws/hooks/base_aws.py`

```diff
@@ -629,6 +629,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """Verify or not SSL certificates boto3 client/resource read-only property."""
         return self.conn_config.verify
 
+    @cached_property
+    def account_id(self) -> str:
+        """Return associated AWS Account ID."""
+        return (
+            self.get_session(region_name=self.region_name)
+            .client(
+                service_name="sts",
+                endpoint_url=self.conn_config.get_service_endpoint_url("sts"),
+                config=self.config,
+                verify=self.verify,
+            )
+            .get_caller_identity()["Account"]
+        )
+
     def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
         """Get the underlying boto3.session.Session(region_name=region_name)."""
         return SessionFactory(
```
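Because every AWS hook in the provider ultimately derives from this base class, `hook.account_id` is now available wherever a hook is, which is what the QuickSight changes below rely on. A small illustrative sketch (the connection ID is an assumption):

```python
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# "aws_default" is the conventional connection ID; substitute your own.
hook = AwsBaseHook(aws_conn_id="aws_default", client_type="sts")
print(hook.account_id)  # one STS GetCallerIdentity call, then cached on the hook
```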
`airflow/providers/amazon/aws/hooks/quicksight.py`

```diff
@@ -18,13 +18,13 @@
 from __future__ import annotations
 
 import time
+import warnings
 from functools import cached_property
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
 
 
 class QuickSightHook(AwsBaseHook):
@@ -46,10 +46,6 @@ class QuickSightHook(AwsBaseHook):
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="quicksight", *args, **kwargs)
 
-    @cached_property
-    def sts_hook(self):
-        return StsHook(aws_conn_id=self.aws_conn_id)
-
     def create_ingestion(
         self,
         data_set_id: str,
@@ -57,6 +53,7 @@ class QuickSightHook(AwsBaseHook):
         ingestion_type: str,
         wait_for_completion: bool = True,
         check_interval: int = 30,
+        aws_account_id: str | None = None,
     ) -> dict:
         """
         Create and start a new SPICE ingestion for a dataset; refresh the SPICE datasets.
@@ -66,18 +63,18 @@ class QuickSightHook(AwsBaseHook):
 
         :param data_set_id: ID of the dataset used in the ingestion.
         :param ingestion_id: ID for the ingestion.
-        :param ingestion_type: Type of ingestion
+        :param ingestion_type: Type of ingestion: "INCREMENTAL_REFRESH"|"FULL_REFRESH"
         :param wait_for_completion: if the program should keep running until job finishes
         :param check_interval: the time interval in seconds which the operator
             will check the status of QuickSight Ingestion
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
         :return: Returns descriptive information about the created data ingestion
             having Ingestion ARN, HTTP status, ingestion ID and ingestion status.
         """
+        aws_account_id = aws_account_id or self.account_id
         self.log.info("Creating QuickSight Ingestion for data set id %s.", data_set_id)
-        quicksight_client = self.get_conn()
         try:
-
-            create_ingestion_response = quicksight_client.create_ingestion(
+            create_ingestion_response = self.conn.create_ingestion(
                 DataSetId=data_set_id,
                 IngestionId=ingestion_id,
                 IngestionType=ingestion_type,
@@ -97,20 +94,21 @@ class QuickSightHook(AwsBaseHook):
             self.log.error("Failed to run Amazon QuickSight create_ingestion API, error: %s", general_error)
             raise
 
-    def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> str:
+    def get_status(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> str:
         """
         Get the current status of QuickSight Create Ingestion API.
 
         .. seealso::
             - :external+boto3:py:meth:`QuickSight.Client.describe_ingestion`
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
         :return: An QuickSight Ingestion Status
         """
+        aws_account_id = aws_account_id or self.account_id
         try:
-            describe_ingestion_response = self.get_conn().describe_ingestion(
+            describe_ingestion_response = self.conn.describe_ingestion(
                 AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id
             )
             return describe_ingestion_response["Ingestion"]["IngestionStatus"]
@@ -119,17 +117,19 @@ class QuickSightHook(AwsBaseHook):
         except ClientError as e:
             raise AirflowException(f"AWS request failed: {e}")
 
-    def get_error_info(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> dict | None:
+    def get_error_info(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> dict | None:
         """
         Get info about the error if any.
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
        :return: Error info dict containing the error type (key 'Type') and message (key 'Message')
            if available. Else, returns None.
         """
-        describe_ingestion_response = self.get_conn().describe_ingestion(
+        aws_account_id = aws_account_id or self.account_id
+
+        describe_ingestion_response = self.conn.describe_ingestion(
             AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id
         )
         # using .get() to get None if the key is not present, instead of an exception.
@@ -137,7 +137,7 @@ class QuickSightHook(AwsBaseHook):
 
     def wait_for_state(
         self,
-        aws_account_id: str,
+        aws_account_id: str | None,
         data_set_id: str,
         ingestion_id: str,
         target_state: set,
@@ -146,7 +146,7 @@ class QuickSightHook(AwsBaseHook):
         """
         Check status of a QuickSight Create Ingestion API.
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
         :param target_state: Describes the QuickSight Job's Target State
@@ -154,6 +154,8 @@ class QuickSightHook(AwsBaseHook):
             will check the status of QuickSight Ingestion
         :return: response of describe_ingestion call after Ingestion is done
         """
+        aws_account_id = aws_account_id or self.account_id
+
         while True:
             status = self.get_status(aws_account_id, data_set_id, ingestion_id)
             self.log.info("Current status is %s", status)
@@ -168,3 +170,16 @@ class QuickSightHook(AwsBaseHook):
 
         self.log.info("QuickSight Ingestion completed")
         return status
+
+    @cached_property
+    def sts_hook(self):
+        warnings.warn(
+            f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+            "This property used for obtain AWS Account ID, "
+            f"please consider to use `{type(self).__name__}.account_id` instead",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        from airflow.providers.amazon.aws.hooks.sts import StsHook
+
+        return StsHook(aws_conn_id=self.aws_conn_id)
```