apache-airflow-providers-amazon 8.12.0rc1__py3-none-any.whl → 8.13.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. airflow/providers/amazon/__init__.py +3 -3
  2. airflow/providers/amazon/aws/auth_manager/avp/__init__.py +16 -0
  3. airflow/providers/amazon/aws/auth_manager/avp/entities.py +64 -0
  4. airflow/providers/amazon/aws/auth_manager/avp/facade.py +126 -0
  5. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +47 -6
  6. airflow/providers/amazon/aws/auth_manager/constants.py +3 -1
  7. airflow/providers/amazon/aws/auth_manager/user.py +3 -0
  8. airflow/providers/amazon/aws/fs/s3.py +6 -6
  9. airflow/providers/amazon/aws/hooks/athena.py +10 -17
  10. airflow/providers/amazon/aws/hooks/ec2.py +10 -5
  11. airflow/providers/amazon/aws/hooks/emr.py +6 -13
  12. airflow/providers/amazon/aws/hooks/redshift_sql.py +41 -18
  13. airflow/providers/amazon/aws/hooks/s3.py +3 -3
  14. airflow/providers/amazon/aws/hooks/verified_permissions.py +44 -0
  15. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +3 -3
  16. airflow/providers/amazon/aws/notifications/chime.py +1 -7
  17. airflow/providers/amazon/aws/notifications/sns.py +1 -8
  18. airflow/providers/amazon/aws/notifications/sqs.py +1 -8
  19. airflow/providers/amazon/aws/operators/eks.py +22 -8
  20. airflow/providers/amazon/aws/operators/redshift_cluster.py +29 -0
  21. airflow/providers/amazon/aws/sensors/batch.py +1 -1
  22. airflow/providers/amazon/aws/sensors/dynamodb.py +6 -5
  23. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  24. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -4
  25. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
  26. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +4 -1
  27. airflow/providers/amazon/aws/triggers/eks.py +30 -6
  28. airflow/providers/amazon/aws/triggers/emr.py +7 -3
  29. airflow/providers/amazon/aws/triggers/rds.py +5 -1
  30. airflow/providers/amazon/get_provider_info.py +28 -3
  31. {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/METADATA +9 -9
  32. {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/RECORD +34 -30
  33. {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/WHEEL +0 -0
  34. {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/entry_points.txt +0 -0
@@ -27,7 +27,7 @@ import packaging.version
27
27
 
28
28
  __all__ = ["__version__"]
29
29
 
30
- __version__ = "8.12.0"
30
+ __version__ = "8.13.0"
31
31
 
32
32
  try:
33
33
  from airflow import __version__ as airflow_version
@@ -35,8 +35,8 @@ except ImportError:
35
35
  from airflow.version import version as airflow_version
36
36
 
37
37
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
38
- "2.5.0"
38
+ "2.6.0"
39
39
  ):
40
40
  raise RuntimeError(
41
- f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.5.0+"
41
+ f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.6.0+"
42
42
  )
@@ -0,0 +1,16 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
@@ -0,0 +1,64 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ from __future__ import annotations
18
+
19
+ from enum import Enum
20
+ from typing import TYPE_CHECKING
21
+
22
+ if TYPE_CHECKING:
23
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
24
+
25
+ AVP_PREFIX_ENTITIES = "Airflow::"
26
+
27
+
28
+ class AvpEntities(Enum):
29
+ """Enum of Amazon Verified Permissions entities."""
30
+
31
+ ACTION = "Action"
32
+ ROLE = "Role"
33
+ USER = "User"
34
+
35
+ # Resource types
36
+ CONFIGURATION = "Configuration"
37
+ CONNECTION = "Connection"
38
+ DATASET = "Dataset"
39
+ POOL = "Pool"
40
+ VARIABLE = "Variable"
41
+ VIEW = "View"
42
+
43
+
44
+ def get_entity_type(resource_type: AvpEntities) -> str:
45
+ """
46
+ Return entity type.
47
+
48
+ :param resource_type: Resource type.
49
+
50
+ Example: Airflow::Action, Airflow::Role, Airflow::Variable, Airflow::User.
51
+ """
52
+ return AVP_PREFIX_ENTITIES + resource_type.value
53
+
54
+
55
+ def get_action_id(resource_type: AvpEntities, method: ResourceMethod):
56
+ """
57
+ Return action id.
58
+
59
+ Convention for action ID is <resource_type>::<method>. Example: Variable::GET.
60
+
61
+ :param resource_type: Resource type.
62
+ :param method: Resource method.
63
+ """
64
+ return f"{resource_type.value}::{method}"
@@ -0,0 +1,126 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ from __future__ import annotations
18
+
19
+ from functools import cached_property
20
+ from typing import TYPE_CHECKING, Callable
21
+
22
+ from airflow.configuration import conf
23
+ from airflow.exceptions import AirflowException
24
+ from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities, get_action_id, get_entity_type
25
+ from airflow.providers.amazon.aws.auth_manager.constants import (
26
+ CONF_AVP_POLICY_STORE_ID_KEY,
27
+ CONF_CONN_ID_KEY,
28
+ CONF_SECTION_NAME,
29
+ )
30
+ from airflow.providers.amazon.aws.hooks.verified_permissions import VerifiedPermissionsHook
31
+ from airflow.utils.log.logging_mixin import LoggingMixin
32
+
33
+ if TYPE_CHECKING:
34
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
35
+ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
36
+
37
+
38
+ class AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
39
+ """
40
+ Facade for Amazon Verified Permissions.
41
+
42
+ Used as an intermediate layer between AWS auth manager and Amazon Verified Permissions.
43
+ """
44
+
45
+ @cached_property
46
+ def avp_client(self):
47
+ """Build Amazon Verified Permissions client."""
48
+ aws_conn_id = conf.get(CONF_SECTION_NAME, CONF_CONN_ID_KEY)
49
+ return VerifiedPermissionsHook(aws_conn_id=aws_conn_id).conn
50
+
51
+ @cached_property
52
+ def avp_policy_store_id(self):
53
+ """Get the Amazon Verified Permission policy store ID from config."""
54
+ return conf.get_mandatory_value(CONF_SECTION_NAME, CONF_AVP_POLICY_STORE_ID_KEY)
55
+
56
+ def is_authorized(
57
+ self,
58
+ *,
59
+ method: ResourceMethod,
60
+ entity_type: AvpEntities,
61
+ user: AwsAuthManagerUser,
62
+ entity_id: str | None = None,
63
+ entity_fetcher: Callable | None = None,
64
+ ) -> bool:
65
+ """
66
+ Make an authorization decision against Amazon Verified Permissions.
67
+
68
+ Check whether the user has permissions to access given resource.
69
+
70
+ :param method: the method to perform
71
+ :param entity_type: the entity type the user accesses
72
+ :param user: the user
73
+ :param entity_id: the entity ID the user accesses. If not provided, all entities of the type will be
74
+ considered.
75
+ :param entity_fetcher: function that returns list of entities to be passed to Amazon Verified
76
+ Permissions. Only needed if some resource properties are used in the policies (e.g. DAG folder).
77
+ """
78
+ entity_list = self._get_user_role_entities(user)
79
+ if entity_fetcher and entity_id:
80
+ # If no entity ID is provided, there is no need to fetch entities.
81
+ # We just need to know whether the user has permissions to access all resources from this type
82
+ entity_list += entity_fetcher()
83
+
84
+ self.log.debug(
85
+ "Making authorization request for user=%s, method=%s, entity_type=%s, entity_id=%s",
86
+ user.get_id(),
87
+ method,
88
+ entity_type,
89
+ entity_id,
90
+ )
91
+
92
+ resp = self.avp_client.is_authorized(
93
+ policyStoreId=self.avp_policy_store_id,
94
+ principal={"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
95
+ action={
96
+ "actionType": get_entity_type(AvpEntities.ACTION),
97
+ "actionId": get_action_id(entity_type, method),
98
+ },
99
+ resource={"entityType": get_entity_type(entity_type), "entityId": entity_id or "*"},
100
+ entities={"entityList": entity_list},
101
+ )
102
+
103
+ self.log.debug("Authorization response: %s", resp)
104
+
105
+ if len(resp.get("errors", [])) > 0:
106
+ self.log.error(
107
+ "Error occurred while making an authorization decision. Errors: %s", resp["errors"]
108
+ )
109
+ raise AirflowException("Error occurred while making an authorization decision.")
110
+
111
+ return resp["decision"] == "ALLOW"
112
+
113
+ @staticmethod
114
+ def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
115
+ user_entity = {
116
+ "identifier": {"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
117
+ "parents": [
118
+ {"entityType": get_entity_type(AvpEntities.ROLE), "entityId": group}
119
+ for group in user.get_groups()
120
+ ],
121
+ }
122
+ role_entities = [
123
+ {"identifier": {"entityType": get_entity_type(AvpEntities.ROLE), "entityId": group}}
124
+ for group in user.get_groups()
125
+ ]
126
+ return [user_entity, *role_entities]
@@ -23,6 +23,8 @@ from flask import session, url_for
23
23
 
24
24
  from airflow.configuration import conf
25
25
  from airflow.exceptions import AirflowOptionalProviderFeatureException
26
+ from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
27
+ from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade
26
28
  from airflow.providers.amazon.aws.auth_manager.constants import (
27
29
  CONF_ENABLE_KEY,
28
30
  CONF_SECTION_NAME,
@@ -72,6 +74,10 @@ class AwsAuthManager(BaseAuthManager):
72
74
  "The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet."
73
75
  )
74
76
 
77
+ @cached_property
78
+ def avp_facade(self):
79
+ return AwsAuthManagerAmazonVerifiedPermissionsFacade()
80
+
75
81
  def get_user(self) -> AwsAuthManagerUser | None:
76
82
  return session["aws_user"] if self.is_logged_in() else None
77
83
 
@@ -85,7 +91,13 @@ class AwsAuthManager(BaseAuthManager):
85
91
  details: ConfigurationDetails | None = None,
86
92
  user: BaseUser | None = None,
87
93
  ) -> bool:
88
- return self.is_logged_in()
94
+ config_section = details.section if details else None
95
+ return self.avp_facade.is_authorized(
96
+ method=method,
97
+ entity_type=AvpEntities.CONFIGURATION,
98
+ user=user or self.get_user(),
99
+ entity_id=config_section,
100
+ )
89
101
 
90
102
  def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool:
91
103
  return self.is_logged_in()
@@ -97,7 +109,13 @@ class AwsAuthManager(BaseAuthManager):
97
109
  details: ConnectionDetails | None = None,
98
110
  user: BaseUser | None = None,
99
111
  ) -> bool:
100
- return self.is_logged_in()
112
+ connection_id = details.conn_id if details else None
113
+ return self.avp_facade.is_authorized(
114
+ method=method,
115
+ entity_type=AvpEntities.CONNECTION,
116
+ user=user or self.get_user(),
117
+ entity_id=connection_id,
118
+ )
101
119
 
102
120
  def is_authorized_dag(
103
121
  self,
@@ -112,17 +130,35 @@ class AwsAuthManager(BaseAuthManager):
112
130
  def is_authorized_dataset(
113
131
  self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None
114
132
  ) -> bool:
115
- return self.is_logged_in()
133
+ dataset_uri = details.uri if details else None
134
+ return self.avp_facade.is_authorized(
135
+ method=method,
136
+ entity_type=AvpEntities.DATASET,
137
+ user=user or self.get_user(),
138
+ entity_id=dataset_uri,
139
+ )
116
140
 
117
141
  def is_authorized_pool(
118
142
  self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
119
143
  ) -> bool:
120
- return self.is_logged_in()
144
+ pool_name = details.name if details else None
145
+ return self.avp_facade.is_authorized(
146
+ method=method,
147
+ entity_type=AvpEntities.POOL,
148
+ user=user or self.get_user(),
149
+ entity_id=pool_name,
150
+ )
121
151
 
122
152
  def is_authorized_variable(
123
153
  self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None
124
154
  ) -> bool:
125
- return self.is_logged_in()
155
+ variable_key = details.key if details else None
156
+ return self.avp_facade.is_authorized(
157
+ method=method,
158
+ entity_type=AvpEntities.VARIABLE,
159
+ user=user or self.get_user(),
160
+ entity_id=variable_key,
161
+ )
126
162
 
127
163
  def is_authorized_view(
128
164
  self,
@@ -130,7 +166,12 @@ class AwsAuthManager(BaseAuthManager):
130
166
  access_view: AccessView,
131
167
  user: BaseUser | None = None,
132
168
  ) -> bool:
133
- return self.is_logged_in()
169
+ return self.avp_facade.is_authorized(
170
+ method="GET",
171
+ entity_type=AvpEntities.VIEW,
172
+ user=user or self.get_user(),
173
+ entity_id=access_view.value,
174
+ )
134
175
 
135
176
  def get_url_login(self, **kwargs) -> str:
136
177
  return url_for("AwsAuthManagerAuthenticationViews.login")
@@ -18,6 +18,8 @@
18
18
  # Configuration keys
19
19
  from __future__ import annotations
20
20
 
21
+ CONF_ENABLE_KEY = "enable"
21
22
  CONF_SECTION_NAME = "aws_auth_manager"
23
+ CONF_CONN_ID_KEY = "conn_id"
22
24
  CONF_SAML_METADATA_URL_KEY = "saml_metadata_url"
23
- CONF_ENABLE_KEY = "enable"
25
+ CONF_AVP_POLICY_STORE_ID_KEY = "avp_policy_store_id"
@@ -49,3 +49,6 @@ class AwsAuthManagerUser(BaseUser):
49
49
 
50
50
  def get_name(self) -> str:
51
51
  return self.username or self.email or self.user_id
52
+
53
+ def get_groups(self):
54
+ return self.groups
@@ -25,7 +25,7 @@ import requests
25
25
  from botocore import UNSIGNED
26
26
  from requests import HTTPError
27
27
 
28
- from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
28
+ from airflow.providers.amazon.aws.hooks.s3 import S3Hook
29
29
 
30
30
  if TYPE_CHECKING:
31
31
  from botocore.awsrequest import AWSRequest
@@ -55,14 +55,14 @@ def get_fs(conn_id: str | None) -> AbstractFileSystem:
55
55
  "pip install apache-airflow-providers-amazon[s3fs]"
56
56
  )
57
57
 
58
- aws: AwsGenericHook = AwsGenericHook(aws_conn_id=conn_id, client_type="s3")
59
- session = aws.get_session(deferrable=True)
60
- endpoint_url = aws.conn_config.get_service_endpoint_url(service_name="s3")
58
+ s3_hook = S3Hook(aws_conn_id=conn_id)
59
+ session = s3_hook.get_session(deferrable=True)
60
+ endpoint_url = s3_hook.conn_config.get_service_endpoint_url(service_name="s3")
61
61
 
62
- config_kwargs: dict[str, Any] = aws.conn_config.extra_config.get("config_kwargs", {})
62
+ config_kwargs: dict[str, Any] = s3_hook.conn_config.extra_config.get("config_kwargs", {})
63
63
  register_events: dict[str, Callable[[Properties], None]] = {}
64
64
 
65
- s3_service_config = aws.service_config
65
+ s3_service_config = s3_hook.service_config
66
66
  if signer := s3_service_config.get("signer", None):
67
67
  log.info("Loading signer %s", signer)
68
68
  if singer_func := SIGNERS.get(signer):
@@ -292,24 +292,17 @@ class AthenaHook(AwsBaseHook):
292
292
 
293
293
  :param query_execution_id: Id of submitted athena query
294
294
  """
295
- output_location = None
296
- if query_execution_id:
297
- response = self.get_query_info(query_execution_id=query_execution_id, use_cache=True)
298
-
299
- if response:
300
- try:
301
- output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
302
- except KeyError:
303
- self.log.error(
304
- "Error retrieving OutputLocation. Query execution id: %s", query_execution_id
305
- )
306
- raise
307
- else:
308
- raise
309
- else:
310
- raise ValueError("Invalid Query execution id. Query execution id: %s", query_execution_id)
295
+ if not query_execution_id:
296
+ raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")
297
+
298
+ if not (response := self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
299
+ raise ValueError(f"Unable to get query information for execution id: {query_execution_id}")
311
300
 
312
- return output_location
301
+ try:
302
+ return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
303
+ except KeyError:
304
+ self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
305
+ raise
313
306
 
314
307
  def stop_query(self, query_execution_id: str) -> dict:
315
308
  """Cancel the submitted query.
@@ -19,16 +19,21 @@ from __future__ import annotations
19
19
 
20
20
  import functools
21
21
  import time
22
+ from typing import Callable, TypeVar
22
23
 
23
24
  from airflow.exceptions import AirflowException
24
25
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
26
+ from airflow.typing_compat import ParamSpec
25
27
 
28
+ PS = ParamSpec("PS")
29
+ RT = TypeVar("RT")
26
30
 
27
- def only_client_type(func):
31
+
32
+ def only_client_type(func: Callable[PS, RT]) -> Callable[PS, RT]:
28
33
  @functools.wraps(func)
29
- def checker(self, *args, **kwargs):
30
- if self._api_type == "client_type":
31
- return func(self, *args, **kwargs)
34
+ def checker(*args, **kwargs) -> RT:
35
+ if args[0]._api_type == "client_type":
36
+ return func(*args, **kwargs)
32
37
 
33
38
  ec2_doc_link = "https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html"
34
39
  raise AirflowException(
@@ -85,7 +90,7 @@ class EC2Hook(AwsBaseHook):
85
90
  :return: Instance object
86
91
  """
87
92
  if self._api_type == "client_type":
88
- return self.get_instances(filters=filters, instance_ids=[instance_id])
93
+ return self.get_instances(filters=filters, instance_ids=[instance_id])[0]
89
94
 
90
95
  return self.conn.Instance(id=instance_id)
91
96
 
@@ -437,8 +437,6 @@ class EmrContainerHook(AwsBaseHook):
437
437
 
438
438
  :param job_id: The ID of the job run request.
439
439
  """
440
- reason = None # We absorb any errors if we can't retrieve the job status
441
-
442
440
  try:
443
441
  response = self.conn.describe_job_run(
444
442
  virtualClusterId=self.virtual_cluster_id,
@@ -446,13 +444,13 @@ class EmrContainerHook(AwsBaseHook):
446
444
  )
447
445
  failure_reason = response["jobRun"]["failureReason"]
448
446
  state_details = response["jobRun"]["stateDetails"]
449
- reason = f"{failure_reason} - {state_details}"
447
+ return f"{failure_reason} - {state_details}"
450
448
  except KeyError:
451
449
  self.log.error("Could not get status of the EMR on EKS job")
452
450
  except ClientError as ex:
453
451
  self.log.error("AWS request failed, check logs for more info: %s", ex)
454
452
 
455
- return reason
453
+ return None
456
454
 
457
455
  def check_query_status(self, job_id: str) -> str | None:
458
456
  """
@@ -491,26 +489,21 @@ class EmrContainerHook(AwsBaseHook):
491
489
  :param max_polling_attempts: Number of times to poll for query state before function exits
492
490
  """
493
491
  try_number = 1
494
- final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
495
-
496
492
  while True:
497
493
  query_state = self.check_query_status(job_id)
494
+ if query_state in self.TERMINAL_STATES:
495
+ self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
496
+ return query_state
498
497
  if query_state is None:
499
498
  self.log.info("Try %s: Invalid query state. Retrying again", try_number)
500
- elif query_state in self.TERMINAL_STATES:
501
- self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
502
- final_query_state = query_state
503
- break
504
499
  else:
505
500
  self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
506
501
  if (
507
502
  max_polling_attempts and try_number >= max_polling_attempts
508
503
  ): # Break loop if max_polling_attempts reached
509
- final_query_state = query_state
510
- break
504
+ return query_state
511
505
  try_number += 1
512
506
  time.sleep(poll_interval)
513
- return final_query_state
514
507
 
515
508
  def stop_query(self, job_id: str) -> dict:
516
509
  """
@@ -102,24 +102,47 @@ class RedshiftSQLHook(DbApiHook):
102
102
  Port is required. If none is provided, default is used for each service.
103
103
  """
104
104
  port = conn.port or 5439
105
- # Pull the custer-identifier from the beginning of the Redshift URL
106
- # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
107
- cluster_identifier = conn.extra_dejson.get("cluster_identifier")
108
- if not cluster_identifier:
109
- if conn.host:
110
- cluster_identifier = conn.host.split(".", 1)[0]
111
- else:
112
- raise AirflowException("Please set cluster_identifier or host in redshift connection.")
113
- redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn
114
- # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
115
- cluster_creds = redshift_client.get_cluster_credentials(
116
- DbUser=conn.login,
117
- DbName=conn.schema,
118
- ClusterIdentifier=cluster_identifier,
119
- AutoCreate=False,
120
- )
121
- token = cluster_creds["DbPassword"]
122
- login = cluster_creds["DbUser"]
105
+ is_serverless = conn.extra_dejson.get("is_serverless", False)
106
+ if is_serverless:
107
+ serverless_work_group = conn.extra_dejson.get("serverless_work_group")
108
+ if not serverless_work_group:
109
+ raise AirflowException(
110
+ "Please set serverless_work_group in redshift connection to use IAM with"
111
+ " Redshift Serverless."
112
+ )
113
+ serverless_token_duration_seconds = conn.extra_dejson.get(
114
+ "serverless_token_duration_seconds", 3600
115
+ )
116
+ redshift_serverless_client = AwsBaseHook(
117
+ aws_conn_id=self.aws_conn_id, client_type="redshift-serverless"
118
+ ).conn
119
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#get-credentials
120
+ cluster_creds = redshift_serverless_client.get_credentials(
121
+ dbName=conn.schema,
122
+ workgroupName=serverless_work_group,
123
+ durationSeconds=serverless_token_duration_seconds,
124
+ )
125
+ token = cluster_creds["dbPassword"]
126
+ login = cluster_creds["dbUser"]
127
+ else:
128
+ # Pull the cluster-identifier from the beginning of the Redshift URL
129
+ # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
130
+ cluster_identifier = conn.extra_dejson.get("cluster_identifier")
131
+ if not cluster_identifier:
132
+ if conn.host:
133
+ cluster_identifier = conn.host.split(".", 1)[0]
134
+ else:
135
+ raise AirflowException("Please set cluster_identifier or host in redshift connection.")
136
+ redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn
137
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
138
+ cluster_creds = redshift_client.get_cluster_credentials(
139
+ DbUser=conn.login,
140
+ DbName=conn.schema,
141
+ ClusterIdentifier=cluster_identifier,
142
+ AutoCreate=False,
143
+ )
144
+ token = cluster_creds["DbPassword"]
145
+ login = cluster_creds["DbUser"]
123
146
  return login, token, port
124
147
 
125
148
  def get_uri(self) -> str:
@@ -805,7 +805,7 @@ class S3Hook(AwsBaseHook):
805
805
  _prefix = _original_prefix.split("*", 1)[0] if _apply_wildcard else _original_prefix
806
806
  delimiter = delimiter or ""
807
807
  start_after_key = start_after_key or ""
808
- self.object_filter_usr = object_filter
808
+ object_filter_usr = object_filter
809
809
  config = {
810
810
  "PageSize": page_size,
811
811
  "MaxItems": max_items,
@@ -827,8 +827,8 @@ class S3Hook(AwsBaseHook):
827
827
  if _apply_wildcard:
828
828
  new_keys = (k for k in new_keys if fnmatch.fnmatch(k["Key"], _original_prefix))
829
829
  keys.extend(new_keys)
830
- if self.object_filter_usr is not None:
831
- return self.object_filter_usr(keys, from_datetime, to_datetime)
830
+ if object_filter_usr is not None:
831
+ return object_filter_usr(keys, from_datetime, to_datetime)
832
832
 
833
833
  return self._list_key_object_filter(keys, from_datetime, to_datetime)
834
834
 
@@ -0,0 +1,44 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ from __future__ import annotations
18
+
19
+ from typing import TYPE_CHECKING
20
+
21
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
22
+
23
+ if TYPE_CHECKING:
24
+ from mypy_boto3_verifiedpermissions.client import VerifiedPermissionsClient # noqa
25
+
26
+
27
+ class VerifiedPermissionsHook(AwsGenericHook["VerifiedPermissionsClient"]):
28
+ """
29
+ Interact with Amazon Verified Permissions.
30
+
31
+ Provide thin wrapper around :external+boto3:py:class:`boto3.client("verifiedpermissions")
32
+ <VerifiedPermissions.Client>`.
33
+
34
+ Additional arguments (such as ``aws_conn_id``) may be specified and
35
+ are passed down to the underlying AwsBaseHook.
36
+
37
+ .. seealso::
38
+ - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
39
+ - `Amazon Verified Permissions API Reference <https://docs.aws.amazon.com/verifiedpermissions/latest/apireference/Welcome.html>`__
40
+ """
41
+
42
+ def __init__(self, *args, **kwargs) -> None:
43
+ kwargs["client_type"] = "verifiedpermissions"
44
+ super().__init__(*args, **kwargs)
@@ -96,15 +96,15 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
96
96
  # Replace unsupported log group name characters
97
97
  return super()._render_filename(ti, try_number).replace(":", "_")
98
98
 
99
- def set_context(self, ti):
99
+ def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
100
100
  super().set_context(ti)
101
- self.json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer")
101
+ _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer")
102
102
  self.handler = watchtower.CloudWatchLogHandler(
103
103
  log_group_name=self.log_group,
104
104
  log_stream_name=self._render_filename(ti, ti.try_number),
105
105
  use_queues=not getattr(ti, "is_trigger_log_context", False),
106
106
  boto3_client=self.hook.get_conn(),
107
- json_serialize_default=self.json_serialize,
107
+ json_serialize_default=_json_serialize,
108
108
  )
109
109
 
110
110
  def close(self):