apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +69 -97
- airflow/providers/amazon/aws/auth_manager/router/login.py +9 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +1 -2
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +160 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +31 -42
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +45 -4
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/METADATA +38 -31
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/RECORD +55 -48
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
|
|
29
29
|
|
30
30
|
__all__ = ["__version__"]
|
31
31
|
|
32
|
-
__version__ = "9.
|
32
|
+
__version__ = "9.5.0"
|
33
33
|
|
34
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
35
35
|
"2.9.0"
|
@@ -20,7 +20,7 @@ from enum import Enum
|
|
20
20
|
from typing import TYPE_CHECKING
|
21
21
|
|
22
22
|
if TYPE_CHECKING:
|
23
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
23
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
24
24
|
|
25
25
|
AVP_PREFIX_ENTITIES = "Airflow::"
|
26
26
|
|
@@ -34,6 +34,8 @@ class AvpEntities(Enum):
|
|
34
34
|
|
35
35
|
# Resource types
|
36
36
|
ASSET = "Asset"
|
37
|
+
ASSET_ALIAS = "AssetAlias"
|
38
|
+
BACKFILL = "Backfills"
|
37
39
|
CONFIGURATION = "Configuration"
|
38
40
|
CONNECTION = "Connection"
|
39
41
|
CUSTOM = "Custom"
|
@@ -36,7 +36,7 @@ from airflow.utils.helpers import prune_dict
|
|
36
36
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
37
37
|
|
38
38
|
if TYPE_CHECKING:
|
39
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
39
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
40
40
|
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
|
41
41
|
|
42
42
|
|
@@ -18,16 +18,18 @@ from __future__ import annotations
|
|
18
18
|
|
19
19
|
import argparse
|
20
20
|
from collections import defaultdict
|
21
|
-
from collections.abc import
|
21
|
+
from collections.abc import Sequence
|
22
22
|
from functools import cached_property
|
23
23
|
from typing import TYPE_CHECKING, Any, cast
|
24
|
+
from urllib.parse import urljoin
|
24
25
|
|
25
26
|
from fastapi import FastAPI
|
26
|
-
from flask import session
|
27
27
|
|
28
|
-
from airflow.
|
29
|
-
from airflow.auth.managers.
|
28
|
+
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
|
29
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
|
30
|
+
from airflow.api_fastapi.auth.managers.models.resource_details import (
|
30
31
|
AccessView,
|
32
|
+
BackfillDetails,
|
31
33
|
ConnectionDetails,
|
32
34
|
DagAccessEntity,
|
33
35
|
DagDetails,
|
@@ -49,16 +51,19 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
|
|
49
51
|
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
|
50
52
|
|
51
53
|
if TYPE_CHECKING:
|
52
|
-
from
|
53
|
-
|
54
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
55
|
-
from airflow.auth.managers.models.batch_apis import (
|
54
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
55
|
+
from airflow.api_fastapi.auth.managers.models.batch_apis import (
|
56
56
|
IsAuthorizedConnectionRequest,
|
57
57
|
IsAuthorizedDagRequest,
|
58
58
|
IsAuthorizedPoolRequest,
|
59
59
|
IsAuthorizedVariableRequest,
|
60
60
|
)
|
61
|
-
from airflow.auth.managers.models.resource_details import
|
61
|
+
from airflow.api_fastapi.auth.managers.models.resource_details import (
|
62
|
+
AssetAliasDetails,
|
63
|
+
AssetDetails,
|
64
|
+
ConfigurationDetails,
|
65
|
+
)
|
66
|
+
from airflow.api_fastapi.common.types import MenuItem
|
62
67
|
|
63
68
|
|
64
69
|
class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
@@ -83,21 +88,15 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
83
88
|
return AwsAuthManagerAmazonVerifiedPermissionsFacade()
|
84
89
|
|
85
90
|
@cached_property
|
86
|
-
def
|
87
|
-
return conf.get("
|
88
|
-
|
89
|
-
def get_user(self) -> AwsAuthManagerUser | None:
|
90
|
-
return session["aws_user"] if self.is_logged_in() else None
|
91
|
-
|
92
|
-
def is_logged_in(self) -> bool:
|
93
|
-
return "aws_user" in session
|
91
|
+
def apiserver_endpoint(self) -> str:
|
92
|
+
return conf.get("api", "base_url")
|
94
93
|
|
95
94
|
def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
|
96
|
-
return AwsAuthManagerUser(**token)
|
95
|
+
return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
|
97
96
|
|
98
97
|
def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
|
99
98
|
return {
|
100
|
-
"
|
99
|
+
"sub": user.get_id(),
|
101
100
|
"groups": user.get_groups(),
|
102
101
|
"username": user.username,
|
103
102
|
"email": user.email,
|
@@ -159,12 +158,28 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
159
158
|
context=context,
|
160
159
|
)
|
161
160
|
|
161
|
+
def is_authorized_backfill(
|
162
|
+
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
|
163
|
+
) -> bool:
|
164
|
+
backfill_id = details.id if details else None
|
165
|
+
return self.avp_facade.is_authorized(
|
166
|
+
method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
|
167
|
+
)
|
168
|
+
|
162
169
|
def is_authorized_asset(
|
163
170
|
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
|
164
171
|
) -> bool:
|
165
|
-
|
172
|
+
asset_id = details.id if details else None
|
166
173
|
return self.avp_facade.is_authorized(
|
167
|
-
method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=
|
174
|
+
method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
|
175
|
+
)
|
176
|
+
|
177
|
+
def is_authorized_asset_alias(
|
178
|
+
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetAliasDetails | None = None
|
179
|
+
) -> bool:
|
180
|
+
asset_alias_id = details.id if details else None
|
181
|
+
return self.avp_facade.is_authorized(
|
182
|
+
method=method, entity_type=AvpEntities.ASSET_ALIAS, user=user, entity_id=asset_alias_id
|
168
183
|
)
|
169
184
|
|
170
185
|
def is_authorized_pool(
|
@@ -204,7 +219,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
204
219
|
|
205
220
|
def is_authorized_custom_view(
|
206
221
|
self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser
|
207
|
-
):
|
222
|
+
) -> bool:
|
208
223
|
return self.avp_facade.is_authorized(
|
209
224
|
method=method,
|
210
225
|
entity_type=AvpEntities.CUSTOM,
|
@@ -212,6 +227,25 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
212
227
|
entity_id=resource_name,
|
213
228
|
)
|
214
229
|
|
230
|
+
def filter_authorized_menu_items(
|
231
|
+
self, menu_items: list[MenuItem], *, user: AwsAuthManagerUser
|
232
|
+
) -> list[MenuItem]:
|
233
|
+
requests: dict[str, IsAuthorizedRequest] = {}
|
234
|
+
for menu_item in menu_items:
|
235
|
+
requests[menu_item.value] = self._get_menu_item_request(menu_item.value)
|
236
|
+
|
237
|
+
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
238
|
+
requests=list(requests.values()), user=user
|
239
|
+
)
|
240
|
+
|
241
|
+
def _has_access_to_menu_item(request: IsAuthorizedRequest):
|
242
|
+
result = self.avp_facade.get_batch_is_authorized_single_result(
|
243
|
+
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
|
244
|
+
)
|
245
|
+
return result["decision"] == "ALLOW"
|
246
|
+
|
247
|
+
return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
|
248
|
+
|
215
249
|
def batch_is_authorized_connection(
|
216
250
|
self,
|
217
251
|
requests: Sequence[IsAuthorizedConnectionRequest],
|
@@ -287,28 +321,23 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
287
321
|
]
|
288
322
|
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
|
289
323
|
|
290
|
-
def
|
324
|
+
def filter_authorized_dag_ids(
|
291
325
|
self,
|
292
326
|
*,
|
293
327
|
dag_ids: set[str],
|
294
328
|
user: AwsAuthManagerUser,
|
295
|
-
|
329
|
+
method: ResourceMethod = "GET",
|
296
330
|
):
|
297
|
-
if not methods:
|
298
|
-
methods = ["PUT", "GET"]
|
299
|
-
|
300
331
|
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
|
301
332
|
requests_list: list[IsAuthorizedRequest] = []
|
302
333
|
for dag_id in dag_ids:
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
requests[dag_id][cast("ResourceMethod", method)] = request
|
311
|
-
requests_list.append(request)
|
334
|
+
request: IsAuthorizedRequest = {
|
335
|
+
"method": method,
|
336
|
+
"entity_type": AvpEntities.DAG,
|
337
|
+
"entity_id": dag_id,
|
338
|
+
}
|
339
|
+
requests[dag_id][method] = request
|
340
|
+
requests_list.append(request)
|
312
341
|
|
313
342
|
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
314
343
|
requests=requests_list, user=user
|
@@ -320,67 +349,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
320
349
|
)
|
321
350
|
return result["decision"] == "ALLOW"
|
322
351
|
|
323
|
-
return {
|
324
|
-
dag_id
|
325
|
-
for dag_id in dag_ids
|
326
|
-
if (
|
327
|
-
"GET" in methods
|
328
|
-
and _has_access_to_dag(requests[dag_id]["GET"])
|
329
|
-
or "PUT" in methods
|
330
|
-
and _has_access_to_dag(requests[dag_id]["PUT"])
|
331
|
-
)
|
332
|
-
}
|
333
|
-
|
334
|
-
def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
|
335
|
-
"""
|
336
|
-
Filter menu items based on user permissions.
|
337
|
-
|
338
|
-
:param menu_items: list of all menu items
|
339
|
-
"""
|
340
|
-
user = self.get_user()
|
341
|
-
if not user:
|
342
|
-
return []
|
343
|
-
|
344
|
-
requests: dict[str, IsAuthorizedRequest] = {}
|
345
|
-
for menu_item in menu_items:
|
346
|
-
if menu_item.childs:
|
347
|
-
for child in menu_item.childs:
|
348
|
-
requests[child.name] = self._get_menu_item_request(child.name)
|
349
|
-
else:
|
350
|
-
requests[menu_item.name] = self._get_menu_item_request(menu_item.name)
|
351
|
-
|
352
|
-
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
353
|
-
requests=list(requests.values()), user=user
|
354
|
-
)
|
355
|
-
|
356
|
-
def _has_access_to_menu_item(request: IsAuthorizedRequest):
|
357
|
-
result = self.avp_facade.get_batch_is_authorized_single_result(
|
358
|
-
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
|
359
|
-
)
|
360
|
-
return result["decision"] == "ALLOW"
|
361
|
-
|
362
|
-
accessible_items = []
|
363
|
-
for menu_item in menu_items:
|
364
|
-
if menu_item.childs:
|
365
|
-
accessible_children = []
|
366
|
-
for child in menu_item.childs:
|
367
|
-
if _has_access_to_menu_item(requests[child.name]):
|
368
|
-
accessible_children.append(child)
|
369
|
-
menu_item.childs = accessible_children
|
370
|
-
|
371
|
-
# Display the menu if the user has access to at least one sub item
|
372
|
-
if len(accessible_children) > 0:
|
373
|
-
accessible_items.append(menu_item)
|
374
|
-
elif _has_access_to_menu_item(requests[menu_item.name]):
|
375
|
-
accessible_items.append(menu_item)
|
376
|
-
|
377
|
-
return accessible_items
|
352
|
+
return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
|
378
353
|
|
379
354
|
def get_url_login(self, **kwargs) -> str:
|
380
|
-
return f"{
|
381
|
-
|
382
|
-
def get_url_logout(self) -> str:
|
383
|
-
raise NotImplementedError()
|
355
|
+
return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
|
384
356
|
|
385
357
|
@staticmethod
|
386
358
|
def get_cli_commands() -> list[CLICommand]:
|
@@ -409,11 +381,11 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
409
381
|
return app
|
410
382
|
|
411
383
|
@staticmethod
|
412
|
-
def _get_menu_item_request(
|
384
|
+
def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
|
413
385
|
return {
|
414
386
|
"method": "MENU",
|
415
387
|
"entity_type": AvpEntities.MENU,
|
416
|
-
"entity_id":
|
388
|
+
"entity_id": menu_item_text,
|
417
389
|
}
|
418
390
|
|
419
391
|
def _check_avp_schema_version(self):
|
@@ -25,7 +25,8 @@ from fastapi import HTTPException, Request
|
|
25
25
|
from starlette import status
|
26
26
|
from starlette.responses import RedirectResponse
|
27
27
|
|
28
|
-
from airflow.api_fastapi.app import get_auth_manager
|
28
|
+
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
|
29
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
|
29
30
|
from airflow.api_fastapi.common.router import AirflowRouter
|
30
31
|
from airflow.configuration import conf
|
31
32
|
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
|
@@ -79,12 +80,16 @@ def login_callback(request: Request):
|
|
79
80
|
username=saml_auth.get_nameid(),
|
80
81
|
email=attributes["email"][0] if "email" in attributes else None,
|
81
82
|
)
|
82
|
-
|
83
|
+
url = conf.get("api", "base_url")
|
84
|
+
token = get_auth_manager().generate_jwt(user)
|
85
|
+
response = RedirectResponse(url=url, status_code=303)
|
86
|
+
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=True)
|
87
|
+
return response
|
83
88
|
|
84
89
|
|
85
90
|
def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
|
86
91
|
request_data = _prepare_request(request)
|
87
|
-
base_url = conf.get(section="
|
92
|
+
base_url = conf.get(section="api", key="base_url")
|
88
93
|
settings = {
|
89
94
|
# We want to keep this flag on in case of errors.
|
90
95
|
# It provides an error reasons, if turned off, it does not
|
@@ -92,7 +97,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
|
|
92
97
|
"sp": {
|
93
98
|
"entityId": "aws-auth-manager-saml-client",
|
94
99
|
"assertionConsumerService": {
|
95
|
-
"url": f"{base_url}/
|
100
|
+
"url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
|
96
101
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
97
102
|
},
|
98
103
|
},
|
@@ -19,11 +19,14 @@ from __future__ import annotations
|
|
19
19
|
from airflow.exceptions import AirflowOptionalProviderFeatureException
|
20
20
|
|
21
21
|
try:
|
22
|
-
from airflow.auth.managers.models.base_user import BaseUser
|
22
|
+
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
|
23
23
|
except ImportError:
|
24
|
-
|
25
|
-
|
26
|
-
|
24
|
+
try:
|
25
|
+
from airflow.auth.managers.models.base_user import BaseUser # type: ignore[no-redef]
|
26
|
+
except ImportError:
|
27
|
+
raise AirflowOptionalProviderFeatureException(
|
28
|
+
"Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
|
29
|
+
) from None
|
27
30
|
|
28
31
|
|
29
32
|
class AwsAuthManagerUser(BaseUser):
|
@@ -16,15 +16,7 @@
|
|
16
16
|
# under the License.
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from
|
20
|
-
from typing import TYPE_CHECKING, cast
|
21
|
-
|
22
|
-
from mypy_boto3_appflow.type_defs import (
|
23
|
-
DestinationFlowConfigTypeDef,
|
24
|
-
SourceFlowConfigTypeDef,
|
25
|
-
TaskTypeDef,
|
26
|
-
TriggerConfigTypeDef,
|
27
|
-
)
|
19
|
+
from typing import TYPE_CHECKING
|
28
20
|
|
29
21
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
30
22
|
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
|
@@ -125,11 +117,9 @@ class AppflowHook(AwsGenericHook["AppflowClient"]):
|
|
125
117
|
|
126
118
|
self.conn.update_flow(
|
127
119
|
flowName=response["flowName"],
|
128
|
-
destinationFlowConfigList=
|
129
|
-
|
130
|
-
|
131
|
-
sourceFlowConfig=cast(SourceFlowConfigTypeDef, response["sourceFlowConfig"]),
|
132
|
-
triggerConfig=cast(TriggerConfigTypeDef, response["triggerConfig"]),
|
120
|
+
destinationFlowConfigList=response["destinationFlowConfigList"],
|
121
|
+
sourceFlowConfig=response["sourceFlowConfig"],
|
122
|
+
triggerConfig=response["triggerConfig"],
|
133
123
|
description=response.get("description", "Flow description."),
|
134
|
-
tasks=
|
124
|
+
tasks=tasks,
|
135
125
|
)
|
@@ -30,6 +30,7 @@ import inspect
|
|
30
30
|
import json
|
31
31
|
import logging
|
32
32
|
import os
|
33
|
+
import warnings
|
33
34
|
from copy import deepcopy
|
34
35
|
from functools import cached_property, wraps
|
35
36
|
from pathlib import Path
|
@@ -41,6 +42,7 @@ import botocore.session
|
|
41
42
|
import jinja2
|
42
43
|
import requests
|
43
44
|
import tenacity
|
45
|
+
from asgiref.sync import sync_to_async
|
44
46
|
from botocore.config import Config
|
45
47
|
from botocore.waiter import Waiter, WaiterModel
|
46
48
|
from dateutil.tz import tzlocal
|
@@ -50,6 +52,7 @@ from airflow.configuration import conf
|
|
50
52
|
from airflow.exceptions import (
|
51
53
|
AirflowException,
|
52
54
|
AirflowNotFoundException,
|
55
|
+
AirflowProviderDeprecationWarning,
|
53
56
|
)
|
54
57
|
from airflow.hooks.base import BaseHook
|
55
58
|
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
|
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
747
750
|
|
748
751
|
@property
|
749
752
|
def async_conn(self):
|
753
|
+
"""
|
754
|
+
[DEPRECATED] Get an aiobotocore client to use for async operations.
|
755
|
+
|
756
|
+
This property is deprecated. Accessing it in an async context will cause the event loop to block.
|
757
|
+
Use the async method `get_async_conn` instead.
|
758
|
+
"""
|
759
|
+
warnings.warn(
|
760
|
+
"The property `async_conn` is deprecated. Accessing it in an async context will cause the event loop to block. "
|
761
|
+
"Use the async method `get_async_conn` instead.",
|
762
|
+
AirflowProviderDeprecationWarning,
|
763
|
+
stacklevel=2,
|
764
|
+
)
|
765
|
+
|
766
|
+
return self._get_async_conn()
|
767
|
+
|
768
|
+
async def get_async_conn(self):
|
750
769
|
"""Get an aiobotocore client to use for async operations."""
|
770
|
+
# We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
|
771
|
+
# because one of it's arguments `self.region_name` is a `@property` decorated function
|
772
|
+
# calling the cached property `self.conn_config` at the end.
|
773
|
+
return await sync_to_async(self._get_async_conn)()
|
774
|
+
|
775
|
+
def _get_async_conn(self):
|
751
776
|
if not self.client_type:
|
752
777
|
raise ValueError("client_type must be specified.")
|
753
778
|
|
@@ -918,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
918
943
|
self,
|
919
944
|
waiter_name: str,
|
920
945
|
parameters: dict[str, str] | None = None,
|
946
|
+
config_overrides: dict[str, Any] | None = None,
|
921
947
|
deferrable: bool = False,
|
922
948
|
client=None,
|
923
949
|
) -> Waiter:
|
@@ -937,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
937
963
|
:param parameters: will scan the waiter config for the keys of that dict,
|
938
964
|
and replace them with the corresponding value. If a custom waiter has
|
939
965
|
such keys to be expanded, they need to be provided here.
|
966
|
+
Note: cannot be used if parameters are included in config_overrides
|
967
|
+
:param config_overrides: will update values of provided keys in the waiter's
|
968
|
+
config. Only specified keys will be updated.
|
940
969
|
:param deferrable: If True, the waiter is going to be an async custom waiter.
|
941
970
|
An async client must be provided in that case.
|
942
971
|
:param client: The client to use for the waiter's operations
|
@@ -945,14 +974,18 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
945
974
|
|
946
975
|
if deferrable and not client:
|
947
976
|
raise ValueError("client must be provided for a deferrable waiter.")
|
977
|
+
if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
|
978
|
+
raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
|
948
979
|
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
|
949
980
|
client = client or self._client
|
950
981
|
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
|
951
982
|
# Technically if waiter_name is in custom_waiters then self.waiter_path must
|
952
983
|
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
|
953
984
|
with open(self.waiter_path) as config_file:
|
954
|
-
config = json.loads(config_file.read())
|
985
|
+
config: dict = json.loads(config_file.read())
|
955
986
|
|
987
|
+
if config_overrides is not None:
|
988
|
+
config["waiters"][waiter_name].update(config_overrides)
|
956
989
|
config = self._apply_parameters_value(config, waiter_name, parameters)
|
957
990
|
return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
|
958
991
|
waiter_name
|
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
|
|
173
173
|
return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]
|
174
174
|
|
175
175
|
async def get_instance_state_async(self, instance_id: str) -> str:
|
176
|
-
async with self.
|
176
|
+
async with await self.get_async_conn() as client:
|
177
177
|
response = await client.describe_instances(InstanceIds=[instance_id])
|
178
178
|
return response["Reservations"][0]["Instances"][0]["State"]["Name"]
|
179
179
|
|
@@ -35,7 +35,6 @@ from botocore.signers import RequestSigner
|
|
35
35
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
36
36
|
from airflow.providers.amazon.aws.hooks.sts import StsHook
|
37
37
|
from airflow.utils import yaml
|
38
|
-
from airflow.utils.json import AirflowJsonEncoder
|
39
38
|
|
40
39
|
DEFAULT_PAGINATION_TOKEN = ""
|
41
40
|
STS_TOKEN_EXPIRES_IN = 60
|
@@ -315,7 +314,7 @@ class EksHook(AwsBaseHook):
|
|
315
314
|
)
|
316
315
|
if verbose:
|
317
316
|
cluster_data = response.get("cluster")
|
318
|
-
self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data,
|
317
|
+
self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, default=repr))
|
319
318
|
return response
|
320
319
|
|
321
320
|
def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
|
@@ -343,7 +342,7 @@ class EksHook(AwsBaseHook):
|
|
343
342
|
nodegroup_data = response.get("nodegroup")
|
344
343
|
self.log.info(
|
345
344
|
"Amazon EKS managed node group details: %s",
|
346
|
-
json.dumps(nodegroup_data,
|
345
|
+
json.dumps(nodegroup_data, default=repr),
|
347
346
|
)
|
348
347
|
return response
|
349
348
|
|
@@ -374,9 +373,7 @@ class EksHook(AwsBaseHook):
|
|
374
373
|
)
|
375
374
|
if verbose:
|
376
375
|
fargate_profile_data = response.get("fargateProfile")
|
377
|
-
self.log.info(
|
378
|
-
"AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
|
379
|
-
)
|
376
|
+
self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, default=repr))
|
380
377
|
return response
|
381
378
|
|
382
379
|
def get_cluster_state(self, clusterName: str) -> ClusterStates:
|
@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):
|
|
211
211
|
|
212
212
|
The async version of get_job_state.
|
213
213
|
"""
|
214
|
-
async with self.
|
214
|
+
async with await self.get_async_conn() as client:
|
215
215
|
job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
|
216
216
|
return job_run["JobRun"]["JobRunState"]
|
217
217
|
|
@@ -236,6 +236,9 @@ class GlueJobHook(AwsBaseHook):
|
|
236
236
|
"""
|
237
237
|
log_client = self.logs_hook.get_conn()
|
238
238
|
paginator = log_client.get_paginator("filter_log_events")
|
239
|
+
job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]
|
240
|
+
# StartTime needs to be an int and is Epoch time in milliseconds
|
241
|
+
start_time = int(job_run["StartedOn"].timestamp() * 1000)
|
239
242
|
|
240
243
|
def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
|
241
244
|
"""Mutualize iteration over the 2 different log streams glue jobs write to."""
|
@@ -245,6 +248,7 @@ class GlueJobHook(AwsBaseHook):
|
|
245
248
|
for response in paginator.paginate(
|
246
249
|
logGroupName=log_group,
|
247
250
|
logStreamNames=[run_id],
|
251
|
+
startTime=start_time,
|
248
252
|
PaginationConfig={"StartingToken": continuation_token},
|
249
253
|
):
|
250
254
|
fetched_logs.extend([event["message"] for event in response["events"]])
|
@@ -270,7 +274,7 @@ class GlueJobHook(AwsBaseHook):
|
|
270
274
|
self.log.info("No new log from the Glue Job in %s", log_group)
|
271
275
|
return next_token
|
272
276
|
|
273
|
-
log_group_prefix =
|
277
|
+
log_group_prefix = job_run["LogGroupName"]
|
274
278
|
log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
|
275
279
|
log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
|
276
280
|
# one would think that the error log group would contain only errors, but it actually contains
|
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
|
|
152
152
|
If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
|
153
153
|
:param count: The maximum number of items returned
|
154
154
|
"""
|
155
|
-
async with self.
|
155
|
+
async with await self.get_async_conn() as client:
|
156
156
|
try:
|
157
157
|
response: dict[str, Any] = await client.describe_log_streams(
|
158
158
|
logGroupName=log_group,
|
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
|
|
194
194
|
else:
|
195
195
|
token_arg = {}
|
196
196
|
|
197
|
-
async with self.
|
197
|
+
async with await self.get_async_conn() as client:
|
198
198
|
response = await client.get_log_events(
|
199
199
|
logGroupName=log_group,
|
200
200
|
logStreamName=log_stream_name,
|