apache-airflow-providers-amazon 8.16.0rc1__py3-none-any.whl → 8.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -0
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +34 -19
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +44 -1
  5. airflow/providers/amazon/aws/auth_manager/cli/__init__.py +16 -0
  6. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +178 -0
  7. airflow/providers/amazon/aws/auth_manager/cli/definition.py +62 -0
  8. airflow/providers/amazon/aws/auth_manager/cli/schema.json +171 -0
  9. airflow/providers/amazon/aws/auth_manager/constants.py +1 -0
  10. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +77 -23
  11. airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +17 -0
  12. airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
  13. airflow/providers/amazon/aws/executors/utils/__init__.py +16 -0
  14. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +60 -0
  15. airflow/providers/amazon/aws/hooks/athena_sql.py +168 -0
  16. airflow/providers/amazon/aws/hooks/base_aws.py +14 -0
  17. airflow/providers/amazon/aws/hooks/quicksight.py +33 -18
  18. airflow/providers/amazon/aws/hooks/redshift_data.py +66 -17
  19. airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -1
  20. airflow/providers/amazon/aws/hooks/s3.py +18 -4
  21. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  22. airflow/providers/amazon/aws/operators/batch.py +33 -15
  23. airflow/providers/amazon/aws/operators/cloud_formation.py +37 -26
  24. airflow/providers/amazon/aws/operators/datasync.py +19 -18
  25. airflow/providers/amazon/aws/operators/dms.py +57 -69
  26. airflow/providers/amazon/aws/operators/ec2.py +19 -5
  27. airflow/providers/amazon/aws/operators/emr.py +30 -10
  28. airflow/providers/amazon/aws/operators/eventbridge.py +57 -80
  29. airflow/providers/amazon/aws/operators/quicksight.py +17 -24
  30. airflow/providers/amazon/aws/operators/redshift_data.py +68 -19
  31. airflow/providers/amazon/aws/operators/s3.py +1 -1
  32. airflow/providers/amazon/aws/operators/sagemaker.py +42 -12
  33. airflow/providers/amazon/aws/sensors/cloud_formation.py +30 -25
  34. airflow/providers/amazon/aws/sensors/dms.py +31 -24
  35. airflow/providers/amazon/aws/sensors/dynamodb.py +15 -15
  36. airflow/providers/amazon/aws/sensors/quicksight.py +34 -24
  37. airflow/providers/amazon/aws/sensors/redshift_cluster.py +41 -3
  38. airflow/providers/amazon/aws/sensors/s3.py +13 -8
  39. airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
  40. airflow/providers/amazon/aws/triggers/redshift_data.py +113 -0
  41. airflow/providers/amazon/aws/triggers/s3.py +9 -4
  42. airflow/providers/amazon/get_provider_info.py +55 -16
  43. {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/METADATA +17 -15
  44. {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/RECORD +46 -38
  45. {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/WHEEL +0 -0
  46. {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -28,7 +28,7 @@ from collections import defaultdict, deque
  from copy import deepcopy
  from typing import TYPE_CHECKING
 
- from botocore.exceptions import ClientError
+ from botocore.exceptions import ClientError, NoCredentialsError
 
  from airflow.configuration import conf
  from airflow.exceptions import AirflowException
@@ -42,6 +42,9 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
      EcsQueuedTask,
      EcsTaskCollection,
  )
+ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
+ from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+ from airflow.utils import timezone
  from airflow.utils.state import State
 
  if TYPE_CHECKING:
@@ -51,6 +54,12 @@ if TYPE_CHECKING:
      ExecutorConfigType,
  )
 
+ INVALID_CREDENTIALS_EXCEPTIONS = [
+     "ExpiredTokenException",
+     "InvalidClientTokenId",
+     "UnrecognizedClientException",
+ ]
+
 
  class AwsEcsExecutor(BaseExecutor):
      """
@@ -91,30 +100,15 @@ class AwsEcsExecutor(BaseExecutor):
 
          self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
          self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
-         aws_conn_id = conf.get(
-             CONFIG_GROUP_NAME,
-             AllEcsConfigKeys.AWS_CONN_ID,
-             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
-         )
-         region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
-         from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+         self.attempts_since_last_successful_connection = 0
+
+         self.load_ecs_connection(check_connection=False)
+         self.IS_BOTO_CONNECTION_HEALTHY = False
 
-         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
          self.run_task_kwargs = self._load_run_kwargs()
 
      def start(self):
-         """
-         Make a test API call to check the health of the ECS Executor.
-
-         Deliberately use an invalid task ID, some potential outcomes in order:
-         1. "AccessDeniedException" is raised if there are insufficient permissions.
-         2. "ClusterNotFoundException" is raised if permissions exist but the cluster does not.
-         3. The API responds with a failure message if the cluster is found and there
-         are permissions, but the cluster itself has issues.
-         4. "InvalidParameterException" is raised if the permissions and cluster exist but the task does not.
-
-         The last one is considered a success state for the purposes of this check.
-         """
+         """This is called by the scheduler when the Executor is being run for the first time."""
          check_health = conf.getboolean(
              CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
          )
@@ -123,7 +117,25 @@ class AwsEcsExecutor(BaseExecutor):
              return
 
          self.log.info("Starting ECS Executor and determining health...")
+         try:
+             self.check_health()
+         except AirflowException:
+             self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
+             raise
+
+     def check_health(self):
+         """
+         Make a test API call to check the health of the ECS Executor.
 
+         Deliberately use an invalid task ID, some potential outcomes in order:
+         1. `AccessDeniedException` is raised if there are insufficient permissions.
+         2. `ClusterNotFoundException` is raised if permissions exist but the cluster does not.
+         3. The API responds with a failure message if the cluster is found and there
+         are permissions, but the cluster itself has issues.
+         4. `InvalidParameterException` is raised if the permissions and cluster exist but the task does not.
+
+         The last one is considered a success state for the purposes of this check.
+         """
          success_status = "succeeded."
          status = success_status
 
@@ -151,18 +163,50 @@ class AwsEcsExecutor(BaseExecutor):
          finally:
              msg_prefix = "ECS Executor health check has %s"
              if status == success_status:
+                 self.IS_BOTO_CONNECTION_HEALTHY = True
                  self.log.info(msg_prefix, status)
              else:
                  msg_error_suffix = (
-                     "The ECS executor will not be able to run Airflow tasks until the issue is addressed. "
-                     "Stopping the Airflow Scheduler from starting until the issue is resolved."
+                     "The ECS executor will not be able to run Airflow tasks until the issue is addressed."
                  )
                  raise AirflowException(msg_prefix % status + msg_error_suffix)
 
+     def load_ecs_connection(self, check_connection: bool = True):
+         self.log.info("Loading Connection information")
+         aws_conn_id = conf.get(
+             CONFIG_GROUP_NAME,
+             AllEcsConfigKeys.AWS_CONN_ID,
+             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
+         )
+         region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
+         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+         self.attempts_since_last_successful_connection += 1
+         self.last_connection_reload = timezone.utcnow()
+
+         if check_connection:
+             self.check_health()
+             self.attempts_since_last_successful_connection = 0
+
      def sync(self):
+         if not self.IS_BOTO_CONNECTION_HEALTHY:
+             exponential_backoff_retry(
+                 self.last_connection_reload,
+                 self.attempts_since_last_successful_connection,
+                 self.load_ecs_connection,
+             )
+             if not self.IS_BOTO_CONNECTION_HEALTHY:
+                 return
          try:
              self.sync_running_tasks()
              self.attempt_task_runs()
+         except (ClientError, NoCredentialsError) as error:
+             error_code = error.response["Error"]["Code"]
+             if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                 self.IS_BOTO_CONNECTION_HEALTHY = False
+                 self.log.warning(
+                     f"AWS credentials are either missing or expired: {error}.\nRetrying connection"
+                 )
+
          except Exception:
              # We catch any and all exceptions because otherwise they would bubble
              # up and kill the scheduler process
@@ -176,6 +220,7 @@ class AwsEcsExecutor(BaseExecutor):
              return
 
          describe_tasks_response = self.__describe_tasks(all_task_arns)
+
          self.log.debug("Active Workers: %s", describe_tasks_response)
 
          if describe_tasks_response["failures"]:
@@ -288,6 +333,15 @@ class AwsEcsExecutor(BaseExecutor):
          _failure_reasons = []
          try:
              run_task_response = self._run_task(task_key, cmd, queue, exec_config)
+         except NoCredentialsError:
+             self.pending_tasks.appendleft(ecs_task)
+             raise
+         except ClientError as e:
+             error_code = e.response["Error"]["Code"]
+             if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                 self.pending_tasks.appendleft(ecs_task)
+                 raise
+             _failure_reasons.append(str(e))
          except Exception as e:
              # Failed to even get a response back from the Boto3 API or something else went
              # wrong. For any possible failure we want to add the exception reasons to the
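
Note (not part of the diff): a minimal configuration sketch for exercising the executor changes above. Only the `aws_ecs_executor` section name and key names come from the provider config keys used in this file; every value below is a placeholder.

    import os

    # Hypothetical values; adjust cluster/container/region to your environment.
    os.environ["AIRFLOW__CORE__EXECUTOR"] = (
        "airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor"
    )
    os.environ["AIRFLOW__AWS_ECS_EXECUTOR__REGION_NAME"] = "us-east-1"
    os.environ["AIRFLOW__AWS_ECS_EXECUTOR__CLUSTER"] = "my-airflow-cluster"
    os.environ["AIRFLOW__AWS_ECS_EXECUTOR__CONTAINER_NAME"] = "airflow-worker"
    os.environ["AIRFLOW__AWS_ECS_EXECUTOR__CHECK_HEALTH_ON_STARTUP"] = "True"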

airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -40,6 +40,7 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
      camelize_dict_keys,
      parse_assign_public_ip,
  )
+ from airflow.providers.amazon.aws.hooks.ecs import EcsHook
  from airflow.utils.helpers import prune_dict
 
 
@@ -60,6 +61,22 @@ def build_task_kwargs() -> dict:
      task_kwargs = _fetch_config_values()
      task_kwargs.update(_fetch_templated_kwargs())
 
+     has_launch_type: bool = "launch_type" in task_kwargs
+     has_capacity_provider: bool = "capacity_provider_strategy" in task_kwargs
+
+     if has_capacity_provider and has_launch_type:
+         raise ValueError(
+             "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both."
+         )
+     elif "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type):
+         # Default API behavior if neither is provided is to fall back on the default capacity
+         # provider if it exists. Since it is not a required value, check if there is one
+         # before using it, and if there is not then use the FARGATE launch_type as
+         # the final fallback.
+         cluster = EcsHook().conn.describe_clusters(clusters=[task_kwargs["cluster"]])["clusters"][0]
+         if not cluster.get("defaultCapacityProviderStrategy"):
+             task_kwargs["launch_type"] = "FARGATE"
+
      # There can only be 1 count of these containers
      task_kwargs["count"] = 1  # type: ignore
      # There could be a generic approach to the below, but likely more convoluted then just manually ensuring
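
Note (not part of the diff): for context, a hedged sketch of what the newly supported `capacity_provider_strategy` looks like once it reaches the boto3 `run_task` call; the cluster, task definition and provider names are assumptions, not values taken from this release.

    import boto3

    ecs = boto3.client("ecs", region_name="us-east-1")
    # Either a capacity provider strategy or a launch type may be passed, not both;
    # with neither, the cluster's default capacity provider (or FARGATE) is used.
    ecs.run_task(
        cluster="my-airflow-cluster",      # assumed cluster name
        taskDefinition="airflow-worker",   # assumed task definition
        count=1,
        capacityProviderStrategy=[
            {"capacityProvider": "FARGATE_SPOT", "weight": 1, "base": 0},
        ],
    )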

airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -44,7 +44,6 @@ CONFIG_DEFAULTS = {
      "conn_id": "aws_default",
      "max_run_task_attempts": "3",
      "assign_public_ip": "False",
-     "launch_type": "FARGATE",
      "platform_version": "LATEST",
      "check_health_on_startup": "True",
  }
@@ -81,6 +80,7 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys):
      """Keys loaded into the config which are valid ECS run_task kwargs."""
 
      ASSIGN_PUBLIC_IP = "assign_public_ip"
+     CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
      CLUSTER = "cluster"
      LAUNCH_TYPE = "launch_type"
      PLATFORM_VERSION = "platform_version"

airflow/providers/amazon/aws/executors/utils/__init__.py (new file)
@@ -0,0 +1,16 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.

airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py (new file)
@@ -0,0 +1,60 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import logging
+ from datetime import datetime, timedelta
+ from typing import Callable
+
+ from airflow.utils import timezone
+
+ log = logging.getLogger(__name__)
+
+
+ def exponential_backoff_retry(
+     last_attempt_time: datetime,
+     attempts_since_last_successful: int,
+     callable_function: Callable,
+     max_delay: int = 60 * 2,
+     max_attempts: int = -1,
+     exponent_base: int = 4,
+ ) -> None:
+     """
+     Retries a callable function with exponential backoff between attempts if it raises an exception.
+
+     :param last_attempt_time: Timestamp of last attempt call.
+     :param attempts_since_last_successful: Number of attempts since last success.
+     :param callable_function: Callable function that will be called if enough time has passed.
+     :param max_delay: Maximum delay in seconds between retries. Default 120.
+     :param max_attempts: Maximum number of attempts before giving up. Default -1 (no limit).
+     :param exponent_base: Exponent base to calculate delay. Default 4.
+     """
+     if max_attempts != -1 and attempts_since_last_successful >= max_attempts:
+         log.error("Max attempts reached. Exiting.")
+         return
+
+     delay = min((exponent_base**attempts_since_last_successful), max_delay)
+     next_retry_time = last_attempt_time + timedelta(seconds=delay)
+     current_time = timezone.utcnow()
+
+     if current_time >= next_retry_time:
+         try:
+             callable_function()
+         except Exception:
+             log.exception("Error calling %r", callable_function.__name__)
+             next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
+             log.info("Waiting for %s seconds before retrying.", next_delay)
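
Note (not part of the diff): a small usage sketch of the helper above; the `reconnect` callable is hypothetical. With the default `exponent_base=4`, an attempt is only made once 4, 16, 64, then a capped 120 seconds have elapsed since the previous attempt.

    from datetime import timedelta

    from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
        exponential_backoff_retry,
    )
    from airflow.utils import timezone


    def reconnect() -> None:
        """Hypothetical callable that keeps failing until credentials are refreshed."""
        raise ConnectionError("credentials still missing")


    # Second attempt, last tried five minutes ago: the 16 s backoff window has
    # elapsed, so reconnect() is invoked; the raised error is logged and the
    # helper reports the next (64 s) delay instead of propagating the exception.
    exponential_backoff_retry(
        last_attempt_time=timezone.utcnow() - timedelta(minutes=5),
        attempts_since_last_successful=2,
        callable_function=reconnect,
    )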

airflow/providers/amazon/aws/hooks/athena_sql.py (new file)
@@ -0,0 +1,168 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import json
+ from functools import cached_property
+ from typing import TYPE_CHECKING, Any
+
+ import pyathena
+ from sqlalchemy.engine.url import URL
+
+ from airflow.exceptions import AirflowException, AirflowNotFoundException
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+ from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+ if TYPE_CHECKING:
+     from pyathena.connection import Connection as AthenaConnection
+
+
+ class AthenaSQLHook(AwsBaseHook, DbApiHook):
+     """Interact with Amazon Athena.
+
+     Provide wrapper around PyAthena library.
+
+     :param athena_conn_id: :ref:`Amazon Athena Connection <howto/connection:athena>`.
+
+     Additional arguments (such as ``aws_conn_id``) may be specified and
+     are passed down to the underlying AwsBaseHook.
+
+     You can specify ``driver`` in ``extra`` of your connection in order to use
+     a different driver than the default ``rest``.
+
+     Also, aws_domain could be specified in ``extra`` of your connection.
+
+     PyAthena and AWS Authentication parameters could be passed in extra field of ``athena_conn_id`` connection.
+
+     Passing authentication parameters in ``athena_conn_id`` will override those in ``aws_conn_id``.
+
+     .. seealso::
+         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+
+     .. note::
+         get_uri() depends on SQLAlchemy and PyAthena.
+     """
+
+     conn_name_attr = "athena_conn_id"
+     default_conn_name = "athena_default"
+     conn_type = "athena"
+     hook_name = "Amazon Athena"
+     supports_autocommit = True
+
+     def __init__(self, athena_conn_id: str = default_conn_name, *args, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self.athena_conn_id = athena_conn_id
+
+     @classmethod
+     def get_ui_field_behaviour(cls) -> dict[str, Any]:
+         """Return custom UI field behaviour for AWS Athena Connection."""
+         return {
+             "hidden_fields": ["host", "port"],
+             "relabeling": {
+                 "login": "AWS Access Key ID",
+                 "password": "AWS Secret Access Key",
+             },
+             "placeholders": {
+                 "login": "AKIAIOSFODNN7EXAMPLE",
+                 "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+                 "extra": json.dumps(
+                     {
+                         "aws_domain": "amazonaws.com",
+                         "driver": "rest",
+                         "s3_staging_dir": "s3://bucket_name/staging/",
+                         "work_group": "primary",
+                         "region_name": "us-east-1",
+                         "session_kwargs": {"profile_name": "default"},
+                         "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
+                         "role_arn": "arn:aws:iam::123456789098:role/role-name",
+                         "assume_role_method": "assume_role",
+                         "assume_role_kwargs": {"RoleSessionName": "airflow"},
+                         "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
+                         "endpoint_url": "http://localhost:4566",
+                     },
+                     indent=2,
+                 ),
+             },
+         }
+
+     @cached_property
+     def conn_config(self) -> AwsConnectionWrapper:
+         """Get the Airflow Connection object and wrap it in helper (cached)."""
+         athena_conn = self.get_connection(self.athena_conn_id)
+         if self.aws_conn_id:
+             try:
+                 connection = self.get_connection(self.aws_conn_id)
+                 connection.login = athena_conn.login
+                 connection.password = athena_conn.password
+                 connection.schema = athena_conn.schema
+                 connection.set_extra(json.dumps({**athena_conn.extra_dejson, **connection.extra_dejson}))
+             except AirflowNotFoundException:
+                 connection = athena_conn
+                 connection.conn_type = "aws"
+                 self.log.warning(
+                     "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
+                 )
+
+         return AwsConnectionWrapper(
+             conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
+         )
+
+     @property
+     def conn(self) -> AwsConnectionWrapper:
+         """Get Aws Connection Wrapper object."""
+         return self.conn_config
+
+     def _get_conn_params(self) -> dict[str, str | None]:
+         """Retrieve connection parameters."""
+         if not self.conn.region_name:
+             raise AirflowException("region_name must be specified in the connection's extra")
+
+         return dict(
+             driver=self.conn.extra_dejson.get("driver", "rest"),
+             schema_name=self.conn.schema,
+             region_name=self.conn.region_name,
+             aws_domain=self.conn.extra_dejson.get("aws_domain", "amazonaws.com"),
+         )
+
+     def get_uri(self) -> str:
+         """Overridden to use the Athena dialect as driver name."""
+         conn_params = self._get_conn_params()
+         creds = self.get_credentials(region_name=conn_params["region_name"])
+
+         return URL.create(
+             f'awsathena+{conn_params["driver"]}',
+             username=creds.access_key,
+             password=creds.secret_key,
+             host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
+             port=443,
+             database=conn_params["schema_name"],
+             query={"aws_session_token": creds.token, **self.conn.extra_dejson},
+         )
+
+     def get_conn(self) -> AthenaConnection:
+         """Get a ``pyathena.Connection`` object."""
+         conn_params = self._get_conn_params()
+
+         conn_kwargs: dict = {
+             "schema_name": conn_params["schema_name"],
+             "region_name": conn_params["region_name"],
+             "session": self.get_session(region_name=conn_params["region_name"]),
+             **self.conn.extra_dejson,
+         }
+
+         return pyathena.connect(**conn_kwargs)
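
Note (not part of the diff): a hedged usage sketch for the new hook; the connection id and query are placeholders. The DbApiHook side supplies the SQL entry points on top of the AWS session handling.

    from airflow.providers.amazon.aws.hooks.athena_sql import AthenaSQLHook

    # Assumes an "athena_default" connection whose extra provides region_name
    # plus s3_staging_dir and/or work_group, as in the placeholders above.
    hook = AthenaSQLHook(athena_conn_id="athena_default")

    rows = hook.get_records("SELECT 1")   # standard DbApiHook entry point
    engine_url = hook.get_uri()           # awsathena+rest SQLAlchemy URL
    pyathena_conn = hook.get_conn()       # raw pyathena.Connection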

airflow/providers/amazon/aws/hooks/base_aws.py
@@ -629,6 +629,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
          """Verify or not SSL certificates boto3 client/resource read-only property."""
          return self.conn_config.verify
 
+     @cached_property
+     def account_id(self) -> str:
+         """Return associated AWS Account ID."""
+         return (
+             self.get_session(region_name=self.region_name)
+             .client(
+                 service_name="sts",
+                 endpoint_url=self.conn_config.get_service_endpoint_url("sts"),
+                 config=self.config,
+                 verify=self.verify,
+             )
+             .get_caller_identity()["Account"]
+         )
+
      def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
          """Get the underlying boto3.session.Session(region_name=region_name)."""
          return SessionFactory(
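
Note (not part of the diff): a sketch of the new property in isolation, assuming the stock "aws_default" connection. It issues a single STS GetCallerIdentity call and caches the result on the hook instance.

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="sts")
    print(hook.account_id)  # e.g. "123456789012", resolved via STS GetCallerIdentity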

airflow/providers/amazon/aws/hooks/quicksight.py
@@ -18,13 +18,13 @@
  from __future__ import annotations
 
  import time
+ import warnings
  from functools import cached_property
 
  from botocore.exceptions import ClientError
 
- from airflow.exceptions import AirflowException
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
- from airflow.providers.amazon.aws.hooks.sts import StsHook
 
 
  class QuickSightHook(AwsBaseHook):
@@ -46,10 +46,6 @@ class QuickSightHook(AwsBaseHook):
      def __init__(self, *args, **kwargs):
          super().__init__(client_type="quicksight", *args, **kwargs)
 
-     @cached_property
-     def sts_hook(self):
-         return StsHook(aws_conn_id=self.aws_conn_id)
-
      def create_ingestion(
          self,
          data_set_id: str,
@@ -57,6 +53,7 @@ class QuickSightHook(AwsBaseHook):
          ingestion_type: str,
          wait_for_completion: bool = True,
          check_interval: int = 30,
+         aws_account_id: str | None = None,
      ) -> dict:
          """
          Create and start a new SPICE ingestion for a dataset; refresh the SPICE datasets.
@@ -66,18 +63,18 @@ class QuickSightHook(AwsBaseHook):
 
          :param data_set_id: ID of the dataset used in the ingestion.
          :param ingestion_id: ID for the ingestion.
-         :param ingestion_type: Type of ingestion . "INCREMENTAL_REFRESH"|"FULL_REFRESH"
+         :param ingestion_type: Type of ingestion: "INCREMENTAL_REFRESH"|"FULL_REFRESH"
          :param wait_for_completion: if the program should keep running until job finishes
          :param check_interval: the time interval in seconds which the operator
              will check the status of QuickSight Ingestion
+         :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
          :return: Returns descriptive information about the created data ingestion
              having Ingestion ARN, HTTP status, ingestion ID and ingestion status.
          """
+         aws_account_id = aws_account_id or self.account_id
          self.log.info("Creating QuickSight Ingestion for data set id %s.", data_set_id)
-         quicksight_client = self.get_conn()
          try:
-             aws_account_id = self.sts_hook.get_account_number()
-             create_ingestion_response = quicksight_client.create_ingestion(
+             create_ingestion_response = self.conn.create_ingestion(
                  DataSetId=data_set_id,
                  IngestionId=ingestion_id,
                  IngestionType=ingestion_type,
@@ -97,20 +94,21 @@ class QuickSightHook(AwsBaseHook):
              self.log.error("Failed to run Amazon QuickSight create_ingestion API, error: %s", general_error)
              raise
 
-     def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> str:
+     def get_status(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> str:
          """
          Get the current status of QuickSight Create Ingestion API.
 
          .. seealso::
              - :external+boto3:py:meth:`QuickSight.Client.describe_ingestion`
 
-         :param aws_account_id: An AWS Account ID
+         :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
          :param data_set_id: QuickSight Data Set ID
          :param ingestion_id: QuickSight Ingestion ID
          :return: An QuickSight Ingestion Status
          """
+         aws_account_id = aws_account_id or self.account_id
          try:
-             describe_ingestion_response = self.get_conn().describe_ingestion(
+             describe_ingestion_response = self.conn.describe_ingestion(
                  AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id
              )
              return describe_ingestion_response["Ingestion"]["IngestionStatus"]
@@ -119,17 +117,19 @@ class QuickSightHook(AwsBaseHook):
          except ClientError as e:
              raise AirflowException(f"AWS request failed: {e}")
 
-     def get_error_info(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> dict | None:
+     def get_error_info(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> dict | None:
          """
          Get info about the error if any.
 
-         :param aws_account_id: An AWS Account ID
+         :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
          :param data_set_id: QuickSight Data Set ID
          :param ingestion_id: QuickSight Ingestion ID
          :return: Error info dict containing the error type (key 'Type') and message (key 'Message')
              if available. Else, returns None.
          """
-         describe_ingestion_response = self.get_conn().describe_ingestion(
+         aws_account_id = aws_account_id or self.account_id
+
+         describe_ingestion_response = self.conn.describe_ingestion(
              AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id
          )
          # using .get() to get None if the key is not present, instead of an exception.
@@ -137,7 +137,7 @@ class QuickSightHook(AwsBaseHook):
 
      def wait_for_state(
          self,
-         aws_account_id: str,
+         aws_account_id: str | None,
          data_set_id: str,
          ingestion_id: str,
          target_state: set,
@@ -146,7 +146,7 @@ class QuickSightHook(AwsBaseHook):
          """
          Check status of a QuickSight Create Ingestion API.
 
-         :param aws_account_id: An AWS Account ID
+         :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID.
          :param data_set_id: QuickSight Data Set ID
          :param ingestion_id: QuickSight Ingestion ID
          :param target_state: Describes the QuickSight Job's Target State
@@ -154,6 +154,8 @@ class QuickSightHook(AwsBaseHook):
              will check the status of QuickSight Ingestion
          :return: response of describe_ingestion call after Ingestion is done
          """
+         aws_account_id = aws_account_id or self.account_id
+
          while True:
              status = self.get_status(aws_account_id, data_set_id, ingestion_id)
              self.log.info("Current status is %s", status)
@@ -168,3 +170,16 @@ class QuickSightHook(AwsBaseHook):
 
          self.log.info("QuickSight Ingestion completed")
          return status
+
+     @cached_property
+     def sts_hook(self):
+         warnings.warn(
+             f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+             "This property used for obtain AWS Account ID, "
+             f"please consider to use `{type(self).__name__}.account_id` instead",
+             AirflowProviderDeprecationWarning,
+             stacklevel=2,
+         )
+         from airflow.providers.amazon.aws.hooks.sts import StsHook
+
+         return StsHook(aws_conn_id=self.aws_conn_id)
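
Note (not part of the diff): with the change above, `aws_account_id` can be omitted from QuickSightHook calls and is resolved from the hook's own credentials via the new `account_id` property. A short sketch with placeholder dataset and ingestion ids:

    from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook

    hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1")
    # aws_account_id is now optional; it falls back to hook.account_id.
    response = hook.create_ingestion(
        data_set_id="my-dataset-id",
        ingestion_id="my-ingestion-2024-01-01",
        ingestion_type="FULL_REFRESH",
        wait_for_completion=False,
    )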