apache-airflow-providers-amazon 9.4.0__py3-none-any.whl → 9.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +21 -100
- airflow/providers/amazon/aws/auth_manager/router/login.py +3 -2
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/hooks/base_aws.py +25 -0
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/base_aws.py +7 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +113 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +31 -42
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +1 -1
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/get_provider_info.py +36 -1
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/METADATA +33 -28
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/RECORD +40 -35
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
|
|
29
29
|
|
30
30
|
__all__ = ["__version__"]
|
31
31
|
|
32
|
-
__version__ = "9.4.0"
|
32
|
+
__version__ = "9.5.0"
|
33
33
|
|
34
34
|
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
|
35
35
|
"2.9.0"
|
@@ -20,7 +20,7 @@ from enum import Enum
|
|
20
20
|
from typing import TYPE_CHECKING
|
21
21
|
|
22
22
|
if TYPE_CHECKING:
|
23
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
23
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
24
24
|
|
25
25
|
AVP_PREFIX_ENTITIES = "Airflow::"
|
26
26
|
|
@@ -36,7 +36,7 @@ from airflow.utils.helpers import prune_dict
|
|
36
36
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
37
37
|
|
38
38
|
if TYPE_CHECKING:
|
39
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
39
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
40
40
|
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
|
41
41
|
|
42
42
|
|
@@ -18,15 +18,14 @@ from __future__ import annotations
|
|
18
18
|
|
19
19
|
import argparse
|
20
20
|
from collections import defaultdict
|
21
|
-
from collections.abc import
|
21
|
+
from collections.abc import Sequence
|
22
22
|
from functools import cached_property
|
23
23
|
from typing import TYPE_CHECKING, Any, cast
|
24
24
|
|
25
25
|
from fastapi import FastAPI
|
26
|
-
from flask import session
|
27
26
|
|
28
|
-
from airflow.auth.managers.base_auth_manager import BaseAuthManager
|
29
|
-
from airflow.auth.managers.models.resource_details import (
|
27
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
|
28
|
+
from airflow.api_fastapi.auth.managers.models.resource_details import (
|
30
29
|
AccessView,
|
31
30
|
ConnectionDetails,
|
32
31
|
DagAccessEntity,
|
@@ -49,16 +48,14 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
|
|
49
48
|
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
|
50
49
|
|
51
50
|
if TYPE_CHECKING:
|
52
|
-
from
|
53
|
-
|
54
|
-
from airflow.auth.managers.base_auth_manager import ResourceMethod
|
55
|
-
from airflow.auth.managers.models.batch_apis import (
|
51
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
|
52
|
+
from airflow.api_fastapi.auth.managers.models.batch_apis import (
|
56
53
|
IsAuthorizedConnectionRequest,
|
57
54
|
IsAuthorizedDagRequest,
|
58
55
|
IsAuthorizedPoolRequest,
|
59
56
|
IsAuthorizedVariableRequest,
|
60
57
|
)
|
61
|
-
from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
|
58
|
+
from airflow.api_fastapi.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
|
62
59
|
|
63
60
|
|
64
61
|
class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
@@ -83,14 +80,8 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
83
80
|
return AwsAuthManagerAmazonVerifiedPermissionsFacade()
|
84
81
|
|
85
82
|
@cached_property
|
86
|
-
def
|
87
|
-
return conf.get("
|
88
|
-
|
89
|
-
def get_user(self) -> AwsAuthManagerUser | None:
|
90
|
-
return session["aws_user"] if self.is_logged_in() else None
|
91
|
-
|
92
|
-
def is_logged_in(self) -> bool:
|
93
|
-
return "aws_user" in session
|
83
|
+
def apiserver_endpoint(self) -> str:
|
84
|
+
return conf.get("api", "base_url")
|
94
85
|
|
95
86
|
def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
|
96
87
|
return AwsAuthManagerUser(**token)
|
@@ -162,9 +153,9 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
162
153
|
def is_authorized_asset(
|
163
154
|
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
|
164
155
|
) -> bool:
|
165
|
-
|
156
|
+
asset_id = details.id if details else None
|
166
157
|
return self.avp_facade.is_authorized(
|
167
|
-
method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=
|
158
|
+
method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
|
168
159
|
)
|
169
160
|
|
170
161
|
def is_authorized_pool(
|
@@ -204,7 +195,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
204
195
|
|
205
196
|
def is_authorized_custom_view(
|
206
197
|
self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser
|
207
|
-
):
|
198
|
+
) -> bool:
|
208
199
|
return self.avp_facade.is_authorized(
|
209
200
|
method=method,
|
210
201
|
entity_type=AvpEntities.CUSTOM,
|
@@ -292,23 +283,18 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
292
283
|
*,
|
293
284
|
dag_ids: set[str],
|
294
285
|
user: AwsAuthManagerUser,
|
295
|
-
|
286
|
+
method: ResourceMethod = "GET",
|
296
287
|
):
|
297
|
-
if not methods:
|
298
|
-
methods = ["PUT", "GET"]
|
299
|
-
|
300
288
|
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
|
301
289
|
requests_list: list[IsAuthorizedRequest] = []
|
302
290
|
for dag_id in dag_ids:
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
requests[dag_id][cast("ResourceMethod", method)] = request
|
311
|
-
requests_list.append(request)
|
291
|
+
request: IsAuthorizedRequest = {
|
292
|
+
"method": method,
|
293
|
+
"entity_type": AvpEntities.DAG,
|
294
|
+
"entity_id": dag_id,
|
295
|
+
}
|
296
|
+
requests[dag_id][method] = request
|
297
|
+
requests_list.append(request)
|
312
298
|
|
313
299
|
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
314
300
|
requests=requests_list, user=user
|
@@ -320,67 +306,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
320
306
|
)
|
321
307
|
return result["decision"] == "ALLOW"
|
322
308
|
|
323
|
-
return {
|
324
|
-
dag_id
|
325
|
-
for dag_id in dag_ids
|
326
|
-
if (
|
327
|
-
"GET" in methods
|
328
|
-
and _has_access_to_dag(requests[dag_id]["GET"])
|
329
|
-
or "PUT" in methods
|
330
|
-
and _has_access_to_dag(requests[dag_id]["PUT"])
|
331
|
-
)
|
332
|
-
}
|
333
|
-
|
334
|
-
def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
|
335
|
-
"""
|
336
|
-
Filter menu items based on user permissions.
|
337
|
-
|
338
|
-
:param menu_items: list of all menu items
|
339
|
-
"""
|
340
|
-
user = self.get_user()
|
341
|
-
if not user:
|
342
|
-
return []
|
343
|
-
|
344
|
-
requests: dict[str, IsAuthorizedRequest] = {}
|
345
|
-
for menu_item in menu_items:
|
346
|
-
if menu_item.childs:
|
347
|
-
for child in menu_item.childs:
|
348
|
-
requests[child.name] = self._get_menu_item_request(child.name)
|
349
|
-
else:
|
350
|
-
requests[menu_item.name] = self._get_menu_item_request(menu_item.name)
|
351
|
-
|
352
|
-
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
353
|
-
requests=list(requests.values()), user=user
|
354
|
-
)
|
355
|
-
|
356
|
-
def _has_access_to_menu_item(request: IsAuthorizedRequest):
|
357
|
-
result = self.avp_facade.get_batch_is_authorized_single_result(
|
358
|
-
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
|
359
|
-
)
|
360
|
-
return result["decision"] == "ALLOW"
|
361
|
-
|
362
|
-
accessible_items = []
|
363
|
-
for menu_item in menu_items:
|
364
|
-
if menu_item.childs:
|
365
|
-
accessible_children = []
|
366
|
-
for child in menu_item.childs:
|
367
|
-
if _has_access_to_menu_item(requests[child.name]):
|
368
|
-
accessible_children.append(child)
|
369
|
-
menu_item.childs = accessible_children
|
370
|
-
|
371
|
-
# Display the menu if the user has access to at least one sub item
|
372
|
-
if len(accessible_children) > 0:
|
373
|
-
accessible_items.append(menu_item)
|
374
|
-
elif _has_access_to_menu_item(requests[menu_item.name]):
|
375
|
-
accessible_items.append(menu_item)
|
376
|
-
|
377
|
-
return accessible_items
|
309
|
+
return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
|
378
310
|
|
379
311
|
def get_url_login(self, **kwargs) -> str:
|
380
|
-
return f"{self.
|
381
|
-
|
382
|
-
def get_url_logout(self) -> str:
|
383
|
-
raise NotImplementedError()
|
312
|
+
return f"{self.apiserver_endpoint}/auth/login"
|
384
313
|
|
385
314
|
@staticmethod
|
386
315
|
def get_cli_commands() -> list[CLICommand]:
|
@@ -408,14 +337,6 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
408
337
|
|
409
338
|
return app
|
410
339
|
|
411
|
-
@staticmethod
|
412
|
-
def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
|
413
|
-
return {
|
414
|
-
"method": "MENU",
|
415
|
-
"entity_type": AvpEntities.MENU,
|
416
|
-
"entity_id": resource_name,
|
417
|
-
}
|
418
|
-
|
419
340
|
def _check_avp_schema_version(self):
|
420
341
|
if not self.avp_facade.is_policy_store_schema_up_to_date():
|
421
342
|
self.log.warning(
|
@@ -79,12 +79,13 @@ def login_callback(request: Request):
|
|
79
79
|
username=saml_auth.get_nameid(),
|
80
80
|
email=attributes["email"][0] if "email" in attributes else None,
|
81
81
|
)
|
82
|
-
|
82
|
+
url = f"{conf.get('api', 'base_url')}/?token={get_auth_manager().get_jwt_token(user)}"
|
83
|
+
return RedirectResponse(url=url, status_code=303)
|
83
84
|
|
84
85
|
|
85
86
|
def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
|
86
87
|
request_data = _prepare_request(request)
|
87
|
-
base_url = conf.get(section="
|
88
|
+
base_url = conf.get(section="api", key="base_url")
|
88
89
|
settings = {
|
89
90
|
# We want to keep this flag on in case of errors.
|
90
91
|
# It provides the error reasons; if turned off, it does not
|
@@ -19,11 +19,14 @@ from __future__ import annotations
|
|
19
19
|
from airflow.exceptions import AirflowOptionalProviderFeatureException
|
20
20
|
|
21
21
|
try:
|
22
|
-
from airflow.auth.managers.models.base_user import BaseUser
|
22
|
+
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
|
23
23
|
except ImportError:
|
24
|
-
|
25
|
-
|
26
|
-
|
24
|
+
try:
|
25
|
+
from airflow.auth.managers.models.base_user import BaseUser # type: ignore[no-redef]
|
26
|
+
except ImportError:
|
27
|
+
raise AirflowOptionalProviderFeatureException(
|
28
|
+
"Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
|
29
|
+
) from None
|
27
30
|
|
28
31
|
|
29
32
|
class AwsAuthManagerUser(BaseUser):
|
@@ -30,6 +30,7 @@ import inspect
|
|
30
30
|
import json
|
31
31
|
import logging
|
32
32
|
import os
|
33
|
+
import warnings
|
33
34
|
from copy import deepcopy
|
34
35
|
from functools import cached_property, wraps
|
35
36
|
from pathlib import Path
|
@@ -41,6 +42,7 @@ import botocore.session
|
|
41
42
|
import jinja2
|
42
43
|
import requests
|
43
44
|
import tenacity
|
45
|
+
from asgiref.sync import sync_to_async
|
44
46
|
from botocore.config import Config
|
45
47
|
from botocore.waiter import Waiter, WaiterModel
|
46
48
|
from dateutil.tz import tzlocal
|
@@ -50,6 +52,7 @@ from airflow.configuration import conf
|
|
50
52
|
from airflow.exceptions import (
|
51
53
|
AirflowException,
|
52
54
|
AirflowNotFoundException,
|
55
|
+
AirflowProviderDeprecationWarning,
|
53
56
|
)
|
54
57
|
from airflow.hooks.base import BaseHook
|
55
58
|
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
|
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
747
750
|
|
748
751
|
@property
|
749
752
|
def async_conn(self):
|
753
|
+
"""
|
754
|
+
[DEPRECATED] Get an aiobotocore client to use for async operations.
|
755
|
+
|
756
|
+
This property is deprecated. Accessing it in an async context will cause the event loop to block.
|
757
|
+
Use the async method `get_async_conn` instead.
|
758
|
+
"""
|
759
|
+
warnings.warn(
|
760
|
+
"The property `async_conn` is deprecated. Accessing it in an async context will cause the event loop to block. "
|
761
|
+
"Use the async method `get_async_conn` instead.",
|
762
|
+
AirflowProviderDeprecationWarning,
|
763
|
+
stacklevel=2,
|
764
|
+
)
|
765
|
+
|
766
|
+
return self._get_async_conn()
|
767
|
+
|
768
|
+
async def get_async_conn(self):
|
750
769
|
"""Get an aiobotocore client to use for async operations."""
|
770
|
+
# We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
|
771
|
+
# because one of its arguments `self.region_name` is a `@property` decorated function
|
772
|
+
# calling the cached property `self.conn_config` at the end.
|
773
|
+
return await sync_to_async(self._get_async_conn)()
|
774
|
+
|
775
|
+
def _get_async_conn(self):
|
751
776
|
if not self.client_type:
|
752
777
|
raise ValueError("client_type must be specified.")
|
753
778
|
|
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
|
|
173
173
|
return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]
|
174
174
|
|
175
175
|
async def get_instance_state_async(self, instance_id: str) -> str:
|
176
|
-
async with self.async_conn as client:
|
176
|
+
async with await self.get_async_conn() as client:
|
177
177
|
response = await client.describe_instances(InstanceIds=[instance_id])
|
178
178
|
return response["Reservations"][0]["Instances"][0]["State"]["Name"]
|
179
179
|
|
@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):
|
|
211
211
|
|
212
212
|
The async version of get_job_state.
|
213
213
|
"""
|
214
|
-
async with self.async_conn as client:
|
214
|
+
async with await self.get_async_conn() as client:
|
215
215
|
job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
|
216
216
|
return job_run["JobRun"]["JobRunState"]
|
217
217
|
|
@@ -236,6 +236,9 @@ class GlueJobHook(AwsBaseHook):
|
|
236
236
|
"""
|
237
237
|
log_client = self.logs_hook.get_conn()
|
238
238
|
paginator = log_client.get_paginator("filter_log_events")
|
239
|
+
job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]
|
240
|
+
# StartTime needs to be an int and is Epoch time in milliseconds
|
241
|
+
start_time = int(job_run["StartedOn"].timestamp() * 1000)
|
239
242
|
|
240
243
|
def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
|
241
244
|
"""Mutualize iteration over the 2 different log streams glue jobs write to."""
|
@@ -245,6 +248,7 @@ class GlueJobHook(AwsBaseHook):
|
|
245
248
|
for response in paginator.paginate(
|
246
249
|
logGroupName=log_group,
|
247
250
|
logStreamNames=[run_id],
|
251
|
+
startTime=start_time,
|
248
252
|
PaginationConfig={"StartingToken": continuation_token},
|
249
253
|
):
|
250
254
|
fetched_logs.extend([event["message"] for event in response["events"]])
|
@@ -270,7 +274,7 @@ class GlueJobHook(AwsBaseHook):
|
|
270
274
|
self.log.info("No new log from the Glue Job in %s", log_group)
|
271
275
|
return next_token
|
272
276
|
|
273
|
-
log_group_prefix =
|
277
|
+
log_group_prefix = job_run["LogGroupName"]
|
274
278
|
log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
|
275
279
|
log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
|
276
280
|
# one would think that the error log group would contain only errors, but it actually contains
|
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
|
|
152
152
|
If the value is LastEventTime, the results are ordered by the event time. The default value is LogStreamName.
|
153
153
|
:param count: The maximum number of items returned
|
154
154
|
"""
|
155
|
-
async with self.async_conn as client:
|
155
|
+
async with await self.get_async_conn() as client:
|
156
156
|
try:
|
157
157
|
response: dict[str, Any] = await client.describe_log_streams(
|
158
158
|
logGroupName=log_group,
|
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
|
|
194
194
|
else:
|
195
195
|
token_arg = {}
|
196
196
|
|
197
|
-
async with self.async_conn as client:
|
197
|
+
async with await self.get_async_conn() as client:
|
198
198
|
response = await client.get_log_events(
|
199
199
|
logGroupName=log_group,
|
200
200
|
logStreamName=log_stream_name,
|
@@ -18,6 +18,7 @@
|
|
18
18
|
|
19
19
|
from __future__ import annotations
|
20
20
|
|
21
|
+
import requests
|
21
22
|
from botocore.exceptions import ClientError
|
22
23
|
|
23
24
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
|
|
29
30
|
|
30
31
|
Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
|
31
32
|
|
33
|
+
If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
|
34
|
+
that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
|
35
|
+
make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
|
36
|
+
setting `generate_local_token` to True. Learn more here:
|
37
|
+
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
|
38
|
+
|
32
39
|
Additional arguments (such as ``aws_conn_id``) may be specified and
|
33
40
|
are passed down to the underlying AwsBaseHook.
|
34
41
|
|
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
|
|
47
54
|
method: str,
|
48
55
|
body: dict | None = None,
|
49
56
|
query_params: dict | None = None,
|
57
|
+
generate_local_token: bool = False,
|
50
58
|
) -> dict:
|
51
59
|
"""
|
52
60
|
Invoke the REST API on the Airflow webserver with the specified inputs.
|
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):
|
|
56
64
|
|
57
65
|
:param env_name: name of the MWAA environment
|
58
66
|
:param path: Apache Airflow REST API endpoint path to be called
|
59
|
-
:param method: HTTP method used for making Airflow REST API calls
|
67
|
+
:param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
|
60
68
|
:param body: Request body for the Apache Airflow REST API call
|
61
69
|
:param query_params: Query parameters to be included in the Apache Airflow REST API call
|
70
|
+
:param generate_local_token: If True, only the local web token method is used without trying boto's
|
71
|
+
`invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
|
72
|
+
boto's `invoke_rest_api`
|
62
73
|
"""
|
63
|
-
|
74
|
+
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
|
75
|
+
body = {k: v for k, v in body.items() if v is not None} if body else {}
|
76
|
+
query_params = query_params or {}
|
64
77
|
api_kwargs = {
|
65
78
|
"Name": env_name,
|
66
79
|
"Path": path,
|
67
80
|
"Method": method,
|
68
|
-
|
69
|
-
"
|
70
|
-
"QueryParameters": query_params if query_params else {},
|
81
|
+
"Body": body,
|
82
|
+
"QueryParameters": query_params,
|
71
83
|
}
|
84
|
+
|
85
|
+
if generate_local_token:
|
86
|
+
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
|
87
|
+
|
72
88
|
try:
|
73
|
-
|
89
|
+
response = self.conn.invoke_rest_api(**api_kwargs)
|
74
90
|
# ResponseMetadata is removed because it contains data that is either very unlikely to be useful
|
75
91
|
# in XComs and logs, or redundant given the data already included in the response
|
76
|
-
|
77
|
-
return
|
92
|
+
response.pop("ResponseMetadata", None)
|
93
|
+
return response
|
94
|
+
|
78
95
|
except ClientError as e:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
96
|
+
if (
|
97
|
+
e.response["Error"]["Code"] == "AccessDeniedException"
|
98
|
+
and "Airflow role" in e.response["Error"]["Message"]
|
99
|
+
):
|
100
|
+
self.log.info(
|
101
|
+
"Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
|
102
|
+
)
|
103
|
+
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
|
104
|
+
else:
|
105
|
+
to_log = e.response
|
106
|
+
# ResponseMetadata is removed because it contains data that is either very unlikely to be
|
107
|
+
# useful in XComs and logs, or redundant given the data already included in the response
|
108
|
+
to_log.pop("ResponseMetadata", None)
|
109
|
+
self.log.error(to_log)
|
110
|
+
raise
|
111
|
+
|
112
|
+
def _invoke_rest_api_using_local_session_token(
|
113
|
+
self,
|
114
|
+
**api_kwargs,
|
115
|
+
) -> dict:
|
116
|
+
try:
|
117
|
+
session, hostname = self._get_session_conn(api_kwargs["Name"])
|
118
|
+
|
119
|
+
response = session.request(
|
120
|
+
method=api_kwargs["Method"],
|
121
|
+
url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
|
122
|
+
params=api_kwargs["QueryParameters"],
|
123
|
+
json=api_kwargs["Body"],
|
124
|
+
timeout=10,
|
125
|
+
)
|
126
|
+
response.raise_for_status()
|
127
|
+
|
128
|
+
except requests.HTTPError as e:
|
129
|
+
self.log.error(e.response.json())
|
130
|
+
raise
|
131
|
+
|
132
|
+
return {
|
133
|
+
"RestApiStatusCode": response.status_code,
|
134
|
+
"RestApiResponse": response.json(),
|
135
|
+
}
|
136
|
+
|
137
|
+
# Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
|
138
|
+
def _get_session_conn(self, env_name: str) -> tuple:
|
139
|
+
create_token_response = self.conn.create_web_login_token(Name=env_name)
|
140
|
+
web_server_hostname = create_token_response["WebServerHostname"]
|
141
|
+
web_token = create_token_response["WebToken"]
|
142
|
+
|
143
|
+
login_url = f"https://{web_server_hostname}/aws_mwaa/login"
|
144
|
+
login_payload = {"token": web_token}
|
145
|
+
session = requests.Session()
|
146
|
+
login_response = session.post(login_url, data=login_payload, timeout=10)
|
147
|
+
login_response.raise_for_status()
|
148
|
+
|
149
|
+
return session, web_server_hostname
|
@@ -93,7 +93,7 @@ class RedshiftHook(AwsBaseHook):
|
|
93
93
|
return "cluster_not_found"
|
94
94
|
|
95
95
|
async def cluster_status_async(self, cluster_identifier: str) -> str:
|
96
|
-
async with self.async_conn as client:
|
96
|
+
async with await self.get_async_conn() as client:
|
97
97
|
response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
|
98
98
|
return response["Clusters"][0]["ClusterStatus"] if response else None
|
99
99
|
|
@@ -275,7 +275,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
|
|
275
275
|
|
276
276
|
:param statement_id: the UUID of the statement
|
277
277
|
"""
|
278
|
-
async with self.async_conn as client:
|
278
|
+
async with await self.get_async_conn() as client:
|
279
279
|
desc = await client.describe_statement(Id=statement_id)
|
280
280
|
return desc["Status"] in RUNNING_STATES
|
281
281
|
|
@@ -288,6 +288,6 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
|
|
288
288
|
|
289
289
|
:param statement_id: the UUID of the statement
|
290
290
|
"""
|
291
|
-
async with self.async_conn as client:
|
291
|
+
async with await self.get_async_conn() as client:
|
292
292
|
resp = await client.describe_statement(Id=statement_id)
|
293
293
|
return self.parse_statement_response(resp)
|
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):
|
|
1318
1318
|
|
1319
1319
|
:param job_name: the name of the training job
|
1320
1320
|
"""
|
1321
|
-
async with self.async_conn as client:
|
1321
|
+
async with await self.get_async_conn() as client:
|
1322
1322
|
response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
|
1323
1323
|
return response
|
1324
1324
|
|