apache-airflow-providers-amazon 8.12.0rc1__py3-none-any.whl → 8.13.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +3 -3
- airflow/providers/amazon/aws/auth_manager/avp/__init__.py +16 -0
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +64 -0
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +126 -0
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +47 -6
- airflow/providers/amazon/aws/auth_manager/constants.py +3 -1
- airflow/providers/amazon/aws/auth_manager/user.py +3 -0
- airflow/providers/amazon/aws/fs/s3.py +6 -6
- airflow/providers/amazon/aws/hooks/athena.py +10 -17
- airflow/providers/amazon/aws/hooks/ec2.py +10 -5
- airflow/providers/amazon/aws/hooks/emr.py +6 -13
- airflow/providers/amazon/aws/hooks/redshift_sql.py +41 -18
- airflow/providers/amazon/aws/hooks/s3.py +3 -3
- airflow/providers/amazon/aws/hooks/verified_permissions.py +44 -0
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +3 -3
- airflow/providers/amazon/aws/notifications/chime.py +1 -7
- airflow/providers/amazon/aws/notifications/sns.py +1 -8
- airflow/providers/amazon/aws/notifications/sqs.py +1 -8
- airflow/providers/amazon/aws/operators/eks.py +22 -8
- airflow/providers/amazon/aws/operators/redshift_cluster.py +29 -0
- airflow/providers/amazon/aws/sensors/batch.py +1 -1
- airflow/providers/amazon/aws/sensors/dynamodb.py +6 -5
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -4
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +4 -1
- airflow/providers/amazon/aws/triggers/eks.py +30 -6
- airflow/providers/amazon/aws/triggers/emr.py +7 -3
- airflow/providers/amazon/aws/triggers/rds.py +5 -1
- airflow/providers/amazon/get_provider_info.py +28 -3
- {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/METADATA +9 -9
- {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/RECORD +34 -30
- {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.12.0rc1.dist-info → apache_airflow_providers_amazon-8.13.0rc1.dist-info}/entry_points.txt +0 -0
@@ -27,7 +27,7 @@ import packaging.version
|
|
27
27
|
|
28
28
|
__all__ = ["__version__"]
|
29
29
|
|
30
|
-
__version__ = "8.
|
30
|
+
__version__ = "8.13.0"
|
31
31
|
|
32
32
|
try:
|
33
33
|
from airflow import __version__ as airflow_version
|
@@ -35,8 +35,8 @@ except ImportError:
|
|
35
35
|
from airflow.version import version as airflow_version
|
36
36
|
|
37
37
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
38
|
-
"2.
|
38
|
+
"2.6.0"
|
39
39
|
):
|
40
40
|
raise RuntimeError(
|
41
|
-
f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.
|
41
|
+
f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.6.0+"
|
42
42
|
)
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from enum import Enum
|
20
|
+
from typing import TYPE_CHECKING
|
21
|
+
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
24
|
+
|
25
|
+
AVP_PREFIX_ENTITIES = "Airflow::"
|
26
|
+
|
27
|
+
|
28
|
+
class AvpEntities(Enum):
|
29
|
+
"""Enum of Amazon Verified Permissions entities."""
|
30
|
+
|
31
|
+
ACTION = "Action"
|
32
|
+
ROLE = "Role"
|
33
|
+
USER = "User"
|
34
|
+
|
35
|
+
# Resource types
|
36
|
+
CONFIGURATION = "Configuration"
|
37
|
+
CONNECTION = "Connection"
|
38
|
+
DATASET = "Dataset"
|
39
|
+
POOL = "Pool"
|
40
|
+
VARIABLE = "Variable"
|
41
|
+
VIEW = "View"
|
42
|
+
|
43
|
+
|
44
|
+
def get_entity_type(resource_type: AvpEntities) -> str:
|
45
|
+
"""
|
46
|
+
Return entity type.
|
47
|
+
|
48
|
+
:param resource_type: Resource type.
|
49
|
+
|
50
|
+
Example: Airflow::Action, Airflow::Role, Airflow::Variable, Airflow::User.
|
51
|
+
"""
|
52
|
+
return AVP_PREFIX_ENTITIES + resource_type.value
|
53
|
+
|
54
|
+
|
55
|
+
def get_action_id(resource_type: AvpEntities, method: ResourceMethod):
|
56
|
+
"""
|
57
|
+
Return action id.
|
58
|
+
|
59
|
+
Convention for action ID is <resource_type>::<method>. Example: Variable::GET.
|
60
|
+
|
61
|
+
:param resource_type: Resource type.
|
62
|
+
:param method: Resource method.
|
63
|
+
"""
|
64
|
+
return f"{resource_type.value}::{method}"
|
@@ -0,0 +1,126 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from functools import cached_property
|
20
|
+
from typing import TYPE_CHECKING, Callable
|
21
|
+
|
22
|
+
from airflow.configuration import conf
|
23
|
+
from airflow.exceptions import AirflowException
|
24
|
+
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities, get_action_id, get_entity_type
|
25
|
+
from airflow.providers.amazon.aws.auth_manager.constants import (
|
26
|
+
CONF_AVP_POLICY_STORE_ID_KEY,
|
27
|
+
CONF_CONN_ID_KEY,
|
28
|
+
CONF_SECTION_NAME,
|
29
|
+
)
|
30
|
+
from airflow.providers.amazon.aws.hooks.verified_permissions import VerifiedPermissionsHook
|
31
|
+
from airflow.utils.log.logging_mixin import LoggingMixin
|
32
|
+
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
35
|
+
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
|
36
|
+
|
37
|
+
|
38
|
+
class AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
|
39
|
+
"""
|
40
|
+
Facade for Amazon Verified Permissions.
|
41
|
+
|
42
|
+
Used as an intermediate layer between AWS auth manager and Amazon Verified Permissions.
|
43
|
+
"""
|
44
|
+
|
45
|
+
@cached_property
|
46
|
+
def avp_client(self):
|
47
|
+
"""Build Amazon Verified Permissions client."""
|
48
|
+
aws_conn_id = conf.get(CONF_SECTION_NAME, CONF_CONN_ID_KEY)
|
49
|
+
return VerifiedPermissionsHook(aws_conn_id=aws_conn_id).conn
|
50
|
+
|
51
|
+
@cached_property
|
52
|
+
def avp_policy_store_id(self):
|
53
|
+
"""Get the Amazon Verified Permission policy store ID from config."""
|
54
|
+
return conf.get_mandatory_value(CONF_SECTION_NAME, CONF_AVP_POLICY_STORE_ID_KEY)
|
55
|
+
|
56
|
+
def is_authorized(
|
57
|
+
self,
|
58
|
+
*,
|
59
|
+
method: ResourceMethod,
|
60
|
+
entity_type: AvpEntities,
|
61
|
+
user: AwsAuthManagerUser,
|
62
|
+
entity_id: str | None = None,
|
63
|
+
entity_fetcher: Callable | None = None,
|
64
|
+
) -> bool:
|
65
|
+
"""
|
66
|
+
Make an authorization decision against Amazon Verified Permissions.
|
67
|
+
|
68
|
+
Check whether the user has permissions to access given resource.
|
69
|
+
|
70
|
+
:param method: the method to perform
|
71
|
+
:param entity_type: the entity type the user accesses
|
72
|
+
:param user: the user
|
73
|
+
:param entity_id: the entity ID the user accesses. If not provided, all entities of the type will be
|
74
|
+
considered.
|
75
|
+
:param entity_fetcher: function that returns list of entities to be passed to Amazon Verified
|
76
|
+
Permissions. Only needed if some resource properties are used in the policies (e.g. DAG folder).
|
77
|
+
"""
|
78
|
+
entity_list = self._get_user_role_entities(user)
|
79
|
+
if entity_fetcher and entity_id:
|
80
|
+
# If no entity ID is provided, there is no need to fetch entities.
|
81
|
+
# We just need to know whether the user has permissions to access all resources from this type
|
82
|
+
entity_list += entity_fetcher()
|
83
|
+
|
84
|
+
self.log.debug(
|
85
|
+
"Making authorization request for user=%s, method=%s, entity_type=%s, entity_id=%s",
|
86
|
+
user.get_id(),
|
87
|
+
method,
|
88
|
+
entity_type,
|
89
|
+
entity_id,
|
90
|
+
)
|
91
|
+
|
92
|
+
resp = self.avp_client.is_authorized(
|
93
|
+
policyStoreId=self.avp_policy_store_id,
|
94
|
+
principal={"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
|
95
|
+
action={
|
96
|
+
"actionType": get_entity_type(AvpEntities.ACTION),
|
97
|
+
"actionId": get_action_id(entity_type, method),
|
98
|
+
},
|
99
|
+
resource={"entityType": get_entity_type(entity_type), "entityId": entity_id or "*"},
|
100
|
+
entities={"entityList": entity_list},
|
101
|
+
)
|
102
|
+
|
103
|
+
self.log.debug("Authorization response: %s", resp)
|
104
|
+
|
105
|
+
if len(resp.get("errors", [])) > 0:
|
106
|
+
self.log.error(
|
107
|
+
"Error occurred while making an authorization decision. Errors: %s", resp["errors"]
|
108
|
+
)
|
109
|
+
raise AirflowException("Error occurred while making an authorization decision.")
|
110
|
+
|
111
|
+
return resp["decision"] == "ALLOW"
|
112
|
+
|
113
|
+
@staticmethod
|
114
|
+
def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
|
115
|
+
user_entity = {
|
116
|
+
"identifier": {"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
|
117
|
+
"parents": [
|
118
|
+
{"entityType": get_entity_type(AvpEntities.ROLE), "entityId": group}
|
119
|
+
for group in user.get_groups()
|
120
|
+
],
|
121
|
+
}
|
122
|
+
role_entities = [
|
123
|
+
{"identifier": {"entityType": get_entity_type(AvpEntities.ROLE), "entityId": group}}
|
124
|
+
for group in user.get_groups()
|
125
|
+
]
|
126
|
+
return [user_entity, *role_entities]
|
@@ -23,6 +23,8 @@ from flask import session, url_for
|
|
23
23
|
|
24
24
|
from airflow.configuration import conf
|
25
25
|
from airflow.exceptions import AirflowOptionalProviderFeatureException
|
26
|
+
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
|
27
|
+
from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade
|
26
28
|
from airflow.providers.amazon.aws.auth_manager.constants import (
|
27
29
|
CONF_ENABLE_KEY,
|
28
30
|
CONF_SECTION_NAME,
|
@@ -72,6 +74,10 @@ class AwsAuthManager(BaseAuthManager):
|
|
72
74
|
"The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet."
|
73
75
|
)
|
74
76
|
|
77
|
+
@cached_property
|
78
|
+
def avp_facade(self):
|
79
|
+
return AwsAuthManagerAmazonVerifiedPermissionsFacade()
|
80
|
+
|
75
81
|
def get_user(self) -> AwsAuthManagerUser | None:
|
76
82
|
return session["aws_user"] if self.is_logged_in() else None
|
77
83
|
|
@@ -85,7 +91,13 @@ class AwsAuthManager(BaseAuthManager):
|
|
85
91
|
details: ConfigurationDetails | None = None,
|
86
92
|
user: BaseUser | None = None,
|
87
93
|
) -> bool:
|
88
|
-
|
94
|
+
config_section = details.section if details else None
|
95
|
+
return self.avp_facade.is_authorized(
|
96
|
+
method=method,
|
97
|
+
entity_type=AvpEntities.CONFIGURATION,
|
98
|
+
user=user or self.get_user(),
|
99
|
+
entity_id=config_section,
|
100
|
+
)
|
89
101
|
|
90
102
|
def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool:
|
91
103
|
return self.is_logged_in()
|
@@ -97,7 +109,13 @@ class AwsAuthManager(BaseAuthManager):
|
|
97
109
|
details: ConnectionDetails | None = None,
|
98
110
|
user: BaseUser | None = None,
|
99
111
|
) -> bool:
|
100
|
-
|
112
|
+
connection_id = details.conn_id if details else None
|
113
|
+
return self.avp_facade.is_authorized(
|
114
|
+
method=method,
|
115
|
+
entity_type=AvpEntities.CONNECTION,
|
116
|
+
user=user or self.get_user(),
|
117
|
+
entity_id=connection_id,
|
118
|
+
)
|
101
119
|
|
102
120
|
def is_authorized_dag(
|
103
121
|
self,
|
@@ -112,17 +130,35 @@ class AwsAuthManager(BaseAuthManager):
|
|
112
130
|
def is_authorized_dataset(
|
113
131
|
self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None
|
114
132
|
) -> bool:
|
115
|
-
|
133
|
+
dataset_uri = details.uri if details else None
|
134
|
+
return self.avp_facade.is_authorized(
|
135
|
+
method=method,
|
136
|
+
entity_type=AvpEntities.DATASET,
|
137
|
+
user=user or self.get_user(),
|
138
|
+
entity_id=dataset_uri,
|
139
|
+
)
|
116
140
|
|
117
141
|
def is_authorized_pool(
|
118
142
|
self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
|
119
143
|
) -> bool:
|
120
|
-
|
144
|
+
pool_name = details.name if details else None
|
145
|
+
return self.avp_facade.is_authorized(
|
146
|
+
method=method,
|
147
|
+
entity_type=AvpEntities.POOL,
|
148
|
+
user=user or self.get_user(),
|
149
|
+
entity_id=pool_name,
|
150
|
+
)
|
121
151
|
|
122
152
|
def is_authorized_variable(
|
123
153
|
self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None
|
124
154
|
) -> bool:
|
125
|
-
|
155
|
+
variable_key = details.key if details else None
|
156
|
+
return self.avp_facade.is_authorized(
|
157
|
+
method=method,
|
158
|
+
entity_type=AvpEntities.VARIABLE,
|
159
|
+
user=user or self.get_user(),
|
160
|
+
entity_id=variable_key,
|
161
|
+
)
|
126
162
|
|
127
163
|
def is_authorized_view(
|
128
164
|
self,
|
@@ -130,7 +166,12 @@ class AwsAuthManager(BaseAuthManager):
|
|
130
166
|
access_view: AccessView,
|
131
167
|
user: BaseUser | None = None,
|
132
168
|
) -> bool:
|
133
|
-
return self.
|
169
|
+
return self.avp_facade.is_authorized(
|
170
|
+
method="GET",
|
171
|
+
entity_type=AvpEntities.VIEW,
|
172
|
+
user=user or self.get_user(),
|
173
|
+
entity_id=access_view.value,
|
174
|
+
)
|
134
175
|
|
135
176
|
def get_url_login(self, **kwargs) -> str:
|
136
177
|
return url_for("AwsAuthManagerAuthenticationViews.login")
|
@@ -18,6 +18,8 @@
|
|
18
18
|
# Configuration keys
|
19
19
|
from __future__ import annotations
|
20
20
|
|
21
|
+
CONF_ENABLE_KEY = "enable"
|
21
22
|
CONF_SECTION_NAME = "aws_auth_manager"
|
23
|
+
CONF_CONN_ID_KEY = "conn_id"
|
22
24
|
CONF_SAML_METADATA_URL_KEY = "saml_metadata_url"
|
23
|
-
|
25
|
+
CONF_AVP_POLICY_STORE_ID_KEY = "avp_policy_store_id"
|
@@ -25,7 +25,7 @@ import requests
|
|
25
25
|
from botocore import UNSIGNED
|
26
26
|
from requests import HTTPError
|
27
27
|
|
28
|
-
from airflow.providers.amazon.aws.hooks.
|
28
|
+
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
|
29
29
|
|
30
30
|
if TYPE_CHECKING:
|
31
31
|
from botocore.awsrequest import AWSRequest
|
@@ -55,14 +55,14 @@ def get_fs(conn_id: str | None) -> AbstractFileSystem:
|
|
55
55
|
"pip install apache-airflow-providers-amazon[s3fs]"
|
56
56
|
)
|
57
57
|
|
58
|
-
|
59
|
-
session =
|
60
|
-
endpoint_url =
|
58
|
+
s3_hook = S3Hook(aws_conn_id=conn_id)
|
59
|
+
session = s3_hook.get_session(deferrable=True)
|
60
|
+
endpoint_url = s3_hook.conn_config.get_service_endpoint_url(service_name="s3")
|
61
61
|
|
62
|
-
config_kwargs: dict[str, Any] =
|
62
|
+
config_kwargs: dict[str, Any] = s3_hook.conn_config.extra_config.get("config_kwargs", {})
|
63
63
|
register_events: dict[str, Callable[[Properties], None]] = {}
|
64
64
|
|
65
|
-
s3_service_config =
|
65
|
+
s3_service_config = s3_hook.service_config
|
66
66
|
if signer := s3_service_config.get("signer", None):
|
67
67
|
log.info("Loading signer %s", signer)
|
68
68
|
if singer_func := SIGNERS.get(signer):
|
@@ -292,24 +292,17 @@ class AthenaHook(AwsBaseHook):
|
|
292
292
|
|
293
293
|
:param query_execution_id: Id of submitted athena query
|
294
294
|
"""
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
try:
|
301
|
-
output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
|
302
|
-
except KeyError:
|
303
|
-
self.log.error(
|
304
|
-
"Error retrieving OutputLocation. Query execution id: %s", query_execution_id
|
305
|
-
)
|
306
|
-
raise
|
307
|
-
else:
|
308
|
-
raise
|
309
|
-
else:
|
310
|
-
raise ValueError("Invalid Query execution id. Query execution id: %s", query_execution_id)
|
295
|
+
if not query_execution_id:
|
296
|
+
raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")
|
297
|
+
|
298
|
+
if not (response := self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
|
299
|
+
raise ValueError(f"Unable to get query information for execution id: {query_execution_id}")
|
311
300
|
|
312
|
-
|
301
|
+
try:
|
302
|
+
return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
|
303
|
+
except KeyError:
|
304
|
+
self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
|
305
|
+
raise
|
313
306
|
|
314
307
|
def stop_query(self, query_execution_id: str) -> dict:
|
315
308
|
"""Cancel the submitted query.
|
@@ -19,16 +19,21 @@ from __future__ import annotations
|
|
19
19
|
|
20
20
|
import functools
|
21
21
|
import time
|
22
|
+
from typing import Callable, TypeVar
|
22
23
|
|
23
24
|
from airflow.exceptions import AirflowException
|
24
25
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
26
|
+
from airflow.typing_compat import ParamSpec
|
25
27
|
|
28
|
+
PS = ParamSpec("PS")
|
29
|
+
RT = TypeVar("RT")
|
26
30
|
|
27
|
-
|
31
|
+
|
32
|
+
def only_client_type(func: Callable[PS, RT]) -> Callable[PS, RT]:
|
28
33
|
@functools.wraps(func)
|
29
|
-
def checker(
|
30
|
-
if
|
31
|
-
return func(
|
34
|
+
def checker(*args, **kwargs) -> RT:
|
35
|
+
if args[0]._api_type == "client_type":
|
36
|
+
return func(*args, **kwargs)
|
32
37
|
|
33
38
|
ec2_doc_link = "https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html"
|
34
39
|
raise AirflowException(
|
@@ -85,7 +90,7 @@ class EC2Hook(AwsBaseHook):
|
|
85
90
|
:return: Instance object
|
86
91
|
"""
|
87
92
|
if self._api_type == "client_type":
|
88
|
-
return self.get_instances(filters=filters, instance_ids=[instance_id])
|
93
|
+
return self.get_instances(filters=filters, instance_ids=[instance_id])[0]
|
89
94
|
|
90
95
|
return self.conn.Instance(id=instance_id)
|
91
96
|
|
@@ -437,8 +437,6 @@ class EmrContainerHook(AwsBaseHook):
|
|
437
437
|
|
438
438
|
:param job_id: The ID of the job run request.
|
439
439
|
"""
|
440
|
-
reason = None # We absorb any errors if we can't retrieve the job status
|
441
|
-
|
442
440
|
try:
|
443
441
|
response = self.conn.describe_job_run(
|
444
442
|
virtualClusterId=self.virtual_cluster_id,
|
@@ -446,13 +444,13 @@ class EmrContainerHook(AwsBaseHook):
|
|
446
444
|
)
|
447
445
|
failure_reason = response["jobRun"]["failureReason"]
|
448
446
|
state_details = response["jobRun"]["stateDetails"]
|
449
|
-
|
447
|
+
return f"{failure_reason} - {state_details}"
|
450
448
|
except KeyError:
|
451
449
|
self.log.error("Could not get status of the EMR on EKS job")
|
452
450
|
except ClientError as ex:
|
453
451
|
self.log.error("AWS request failed, check logs for more info: %s", ex)
|
454
452
|
|
455
|
-
return
|
453
|
+
return None
|
456
454
|
|
457
455
|
def check_query_status(self, job_id: str) -> str | None:
|
458
456
|
"""
|
@@ -491,26 +489,21 @@ class EmrContainerHook(AwsBaseHook):
|
|
491
489
|
:param max_polling_attempts: Number of times to poll for query state before function exits
|
492
490
|
"""
|
493
491
|
try_number = 1
|
494
|
-
final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
|
495
|
-
|
496
492
|
while True:
|
497
493
|
query_state = self.check_query_status(job_id)
|
494
|
+
if query_state in self.TERMINAL_STATES:
|
495
|
+
self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
|
496
|
+
return query_state
|
498
497
|
if query_state is None:
|
499
498
|
self.log.info("Try %s: Invalid query state. Retrying again", try_number)
|
500
|
-
elif query_state in self.TERMINAL_STATES:
|
501
|
-
self.log.info("Try %s: Query execution completed. Final state is %s", try_number, query_state)
|
502
|
-
final_query_state = query_state
|
503
|
-
break
|
504
499
|
else:
|
505
500
|
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
|
506
501
|
if (
|
507
502
|
max_polling_attempts and try_number >= max_polling_attempts
|
508
503
|
): # Break loop if max_polling_attempts reached
|
509
|
-
|
510
|
-
break
|
504
|
+
return query_state
|
511
505
|
try_number += 1
|
512
506
|
time.sleep(poll_interval)
|
513
|
-
return final_query_state
|
514
507
|
|
515
508
|
def stop_query(self, job_id: str) -> dict:
|
516
509
|
"""
|
@@ -102,24 +102,47 @@ class RedshiftSQLHook(DbApiHook):
|
|
102
102
|
Port is required. If none is provided, default is used for each service.
|
103
103
|
"""
|
104
104
|
port = conn.port or 5439
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
105
|
+
is_serverless = conn.extra_dejson.get("is_serverless", False)
|
106
|
+
if is_serverless:
|
107
|
+
serverless_work_group = conn.extra_dejson.get("serverless_work_group")
|
108
|
+
if not serverless_work_group:
|
109
|
+
raise AirflowException(
|
110
|
+
"Please set serverless_work_group in redshift connection to use IAM with"
|
111
|
+
" Redshift Serverless."
|
112
|
+
)
|
113
|
+
serverless_token_duration_seconds = conn.extra_dejson.get(
|
114
|
+
"serverless_token_duration_seconds", 3600
|
115
|
+
)
|
116
|
+
redshift_serverless_client = AwsBaseHook(
|
117
|
+
aws_conn_id=self.aws_conn_id, client_type="redshift-serverless"
|
118
|
+
).conn
|
119
|
+
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#get-credentials
|
120
|
+
cluster_creds = redshift_serverless_client.get_credentials(
|
121
|
+
dbName=conn.schema,
|
122
|
+
workgroupName=serverless_work_group,
|
123
|
+
durationSeconds=serverless_token_duration_seconds,
|
124
|
+
)
|
125
|
+
token = cluster_creds["dbPassword"]
|
126
|
+
login = cluster_creds["dbUser"]
|
127
|
+
else:
|
128
|
+
# Pull the custer-identifier from the beginning of the Redshift URL
|
129
|
+
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
|
130
|
+
cluster_identifier = conn.extra_dejson.get("cluster_identifier")
|
131
|
+
if not cluster_identifier:
|
132
|
+
if conn.host:
|
133
|
+
cluster_identifier = conn.host.split(".", 1)[0]
|
134
|
+
else:
|
135
|
+
raise AirflowException("Please set cluster_identifier or host in redshift connection.")
|
136
|
+
redshift_client = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="redshift").conn
|
137
|
+
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
|
138
|
+
cluster_creds = redshift_client.get_cluster_credentials(
|
139
|
+
DbUser=conn.login,
|
140
|
+
DbName=conn.schema,
|
141
|
+
ClusterIdentifier=cluster_identifier,
|
142
|
+
AutoCreate=False,
|
143
|
+
)
|
144
|
+
token = cluster_creds["DbPassword"]
|
145
|
+
login = cluster_creds["DbUser"]
|
123
146
|
return login, token, port
|
124
147
|
|
125
148
|
def get_uri(self) -> str:
|
@@ -805,7 +805,7 @@ class S3Hook(AwsBaseHook):
|
|
805
805
|
_prefix = _original_prefix.split("*", 1)[0] if _apply_wildcard else _original_prefix
|
806
806
|
delimiter = delimiter or ""
|
807
807
|
start_after_key = start_after_key or ""
|
808
|
-
|
808
|
+
object_filter_usr = object_filter
|
809
809
|
config = {
|
810
810
|
"PageSize": page_size,
|
811
811
|
"MaxItems": max_items,
|
@@ -827,8 +827,8 @@ class S3Hook(AwsBaseHook):
|
|
827
827
|
if _apply_wildcard:
|
828
828
|
new_keys = (k for k in new_keys if fnmatch.fnmatch(k["Key"], _original_prefix))
|
829
829
|
keys.extend(new_keys)
|
830
|
-
if
|
831
|
-
return
|
830
|
+
if object_filter_usr is not None:
|
831
|
+
return object_filter_usr(keys, from_datetime, to_datetime)
|
832
832
|
|
833
833
|
return self._list_key_object_filter(keys, from_datetime, to_datetime)
|
834
834
|
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
3
|
+
# distributed with this work for additional information
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
6
|
+
# "License"); you may not use this file except in compliance
|
7
|
+
# with the License. You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
14
|
+
# KIND, either express or implied. See the License for the
|
15
|
+
# specific language governing permissions and limitations
|
16
|
+
# under the License.
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
from typing import TYPE_CHECKING
|
20
|
+
|
21
|
+
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
22
|
+
|
23
|
+
if TYPE_CHECKING:
|
24
|
+
from mypy_boto3_verifiedpermissions.client import VerifiedPermissionsClient # noqa
|
25
|
+
|
26
|
+
|
27
|
+
class VerifiedPermissionsHook(AwsGenericHook["VerifiedPermissionsClient"]):
|
28
|
+
"""
|
29
|
+
Interact with Amazon Verified Permissions.
|
30
|
+
|
31
|
+
Provide thin wrapper around :external+boto3:py:class:`boto3.client("verifiedpermissions")
|
32
|
+
<VerifiedPermissions.Client>`.
|
33
|
+
|
34
|
+
Additional arguments (such as ``aws_conn_id``) may be specified and
|
35
|
+
are passed down to the underlying AwsBaseHook.
|
36
|
+
|
37
|
+
.. seealso::
|
38
|
+
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
|
39
|
+
- `Amazon Appflow API Reference <https://docs.aws.amazon.com/verifiedpermissions/latest/apireference/Welcome.html>`__
|
40
|
+
"""
|
41
|
+
|
42
|
+
def __init__(self, *args, **kwargs) -> None:
|
43
|
+
kwargs["client_type"] = "verifiedpermissions"
|
44
|
+
super().__init__(*args, **kwargs)
|
@@ -96,15 +96,15 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
|
|
96
96
|
# Replace unsupported log group name characters
|
97
97
|
return super()._render_filename(ti, try_number).replace(":", "_")
|
98
98
|
|
99
|
-
def set_context(self, ti):
|
99
|
+
def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
|
100
100
|
super().set_context(ti)
|
101
|
-
|
101
|
+
_json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer")
|
102
102
|
self.handler = watchtower.CloudWatchLogHandler(
|
103
103
|
log_group_name=self.log_group,
|
104
104
|
log_stream_name=self._render_filename(ti, ti.try_number),
|
105
105
|
use_queues=not getattr(ti, "is_trigger_log_context", False),
|
106
106
|
boto3_client=self.hook.get_conn(),
|
107
|
-
json_serialize_default=
|
107
|
+
json_serialize_default=_json_serialize,
|
108
108
|
)
|
109
109
|
|
110
110
|
def close(self):
|