apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +21 -100
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +3 -2
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/hooks/base_aws.py +25 -0
  8. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  9. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  10. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  11. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  12. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
  13. airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
  14. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  15. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  16. airflow/providers/amazon/aws/links/base_aws.py +7 -1
  17. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  18. airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
  19. airflow/providers/amazon/aws/operators/s3.py +147 -157
  20. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  21. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  22. airflow/providers/amazon/aws/sensors/mwaa.py +113 -0
  23. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  24. airflow/providers/amazon/aws/sensors/s3.py +31 -42
  25. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  26. airflow/providers/amazon/aws/triggers/README.md +4 -4
  27. airflow/providers/amazon/aws/triggers/base.py +1 -1
  28. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  29. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  30. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  31. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  32. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  33. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  34. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  35. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  36. airflow/providers/amazon/get_provider_info.py +36 -1
  37. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/METADATA +30 -25
  38. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/RECORD +40 -35
  39. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/WHEEL +1 -1
  40. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  41. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc1.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

  __all__ = ["__version__"]

- __version__ = "9.4.0"
+ __version__ = "9.5.0"

  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
      "2.9.0"
@@ -20,7 +20,7 @@ from enum import Enum
  from typing import TYPE_CHECKING

  if TYPE_CHECKING:
-     from airflow.auth.managers.base_auth_manager import ResourceMethod
+     from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod

  AVP_PREFIX_ENTITIES = "Airflow::"

@@ -36,7 +36,7 @@ from airflow.utils.helpers import prune_dict
  from airflow.utils.log.logging_mixin import LoggingMixin

  if TYPE_CHECKING:
-     from airflow.auth.managers.base_auth_manager import ResourceMethod
+     from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
      from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser


@@ -18,15 +18,14 @@ from __future__ import annotations

  import argparse
  from collections import defaultdict
- from collections.abc import Container, Sequence
+ from collections.abc import Sequence
  from functools import cached_property
  from typing import TYPE_CHECKING, Any, cast

  from fastapi import FastAPI
- from flask import session

- from airflow.auth.managers.base_auth_manager import BaseAuthManager
- from airflow.auth.managers.models.resource_details import (
+ from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
+ from airflow.api_fastapi.auth.managers.models.resource_details import (
      AccessView,
      ConnectionDetails,
      DagAccessEntity,
@@ -49,16 +48,14 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
  from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

  if TYPE_CHECKING:
-     from flask_appbuilder.menu import MenuItem
-
-     from airflow.auth.managers.base_auth_manager import ResourceMethod
-     from airflow.auth.managers.models.batch_apis import (
+     from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
+     from airflow.api_fastapi.auth.managers.models.batch_apis import (
          IsAuthorizedConnectionRequest,
          IsAuthorizedDagRequest,
          IsAuthorizedPoolRequest,
          IsAuthorizedVariableRequest,
      )
-     from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
+     from airflow.api_fastapi.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails


  class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
@@ -83,14 +80,8 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
          return AwsAuthManagerAmazonVerifiedPermissionsFacade()

      @cached_property
-     def fastapi_endpoint(self) -> str:
-         return conf.get("fastapi", "base_url")
-
-     def get_user(self) -> AwsAuthManagerUser | None:
-         return session["aws_user"] if self.is_logged_in() else None
-
-     def is_logged_in(self) -> bool:
-         return "aws_user" in session
+     def apiserver_endpoint(self) -> str:
+         return conf.get("api", "base_url")

      def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
          return AwsAuthManagerUser(**token)
@@ -162,9 +153,9 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
      def is_authorized_asset(
          self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
      ) -> bool:
-         asset_uri = details.uri if details else None
+         asset_id = details.id if details else None
          return self.avp_facade.is_authorized(
-             method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_uri
+             method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
          )

      def is_authorized_pool(
@@ -204,7 +195,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):

      def is_authorized_custom_view(
          self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser
-     ):
+     ) -> bool:
          return self.avp_facade.is_authorized(
              method=method,
              entity_type=AvpEntities.CUSTOM,
@@ -292,23 +283,18 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
          *,
          dag_ids: set[str],
          user: AwsAuthManagerUser,
-         methods: Container[ResourceMethod] | None = None,
+         method: ResourceMethod = "GET",
      ):
-         if not methods:
-             methods = ["PUT", "GET"]
-
          requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
          requests_list: list[IsAuthorizedRequest] = []
          for dag_id in dag_ids:
-             for method in ["GET", "PUT"]:
-                 if method in methods:
-                     request: IsAuthorizedRequest = {
-                         "method": cast("ResourceMethod", method),
-                         "entity_type": AvpEntities.DAG,
-                         "entity_id": dag_id,
-                     }
-                     requests[dag_id][cast("ResourceMethod", method)] = request
-                     requests_list.append(request)
+             request: IsAuthorizedRequest = {
+                 "method": method,
+                 "entity_type": AvpEntities.DAG,
+                 "entity_id": dag_id,
+             }
+             requests[dag_id][method] = request
+             requests_list.append(request)

          batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
              requests=requests_list, user=user
@@ -320,67 +306,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
              )
              return result["decision"] == "ALLOW"

-         return {
-             dag_id
-             for dag_id in dag_ids
-             if (
-                 "GET" in methods
-                 and _has_access_to_dag(requests[dag_id]["GET"])
-                 or "PUT" in methods
-                 and _has_access_to_dag(requests[dag_id]["PUT"])
-             )
-         }
-
-     def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
-         """
-         Filter menu items based on user permissions.
-
-         :param menu_items: list of all menu items
-         """
-         user = self.get_user()
-         if not user:
-             return []
-
-         requests: dict[str, IsAuthorizedRequest] = {}
-         for menu_item in menu_items:
-             if menu_item.childs:
-                 for child in menu_item.childs:
-                     requests[child.name] = self._get_menu_item_request(child.name)
-             else:
-                 requests[menu_item.name] = self._get_menu_item_request(menu_item.name)
-
-         batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
-             requests=list(requests.values()), user=user
-         )
-
-         def _has_access_to_menu_item(request: IsAuthorizedRequest):
-             result = self.avp_facade.get_batch_is_authorized_single_result(
-                 batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
-             )
-             return result["decision"] == "ALLOW"
-
-         accessible_items = []
-         for menu_item in menu_items:
-             if menu_item.childs:
-                 accessible_children = []
-                 for child in menu_item.childs:
-                     if _has_access_to_menu_item(requests[child.name]):
-                         accessible_children.append(child)
-                 menu_item.childs = accessible_children
-
-                 # Display the menu if the user has access to at least one sub item
-                 if len(accessible_children) > 0:
-                     accessible_items.append(menu_item)
-             elif _has_access_to_menu_item(requests[menu_item.name]):
-                 accessible_items.append(menu_item)
-
-         return accessible_items
+         return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}

      def get_url_login(self, **kwargs) -> str:
-         return f"{self.fastapi_endpoint}/auth/login"
-
-     def get_url_logout(self) -> str:
-         raise NotImplementedError()
+         return f"{self.apiserver_endpoint}/auth/login"

      @staticmethod
      def get_cli_commands() -> list[CLICommand]:
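
Note: with the two hunks above, `filter_permitted_dag_ids` now checks a single `method` per call (defaulting to `"GET"`) instead of a container of methods, and the Flask-AppBuilder menu filtering helpers are removed. A minimal, hypothetical sketch of the new call shape; in practice the manager is instantiated by Airflow and the user comes from the SAML login flow, so the objects below are illustrative only and assume a configured Amazon Verified Permissions policy store:

    from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager
    from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

    auth_manager = AwsAuthManager()  # normally created by Airflow itself
    user = AwsAuthManagerUser(user_id="jdoe", username="jdoe", groups=["viewers"])

    # Returns only the DAG ids the user may read; pass method="PUT" to check edit access.
    permitted = auth_manager.filter_permitted_dag_ids(
        dag_ids={"sales_daily", "ops_hourly"},
        user=user,
        method="GET",
    )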
@@ -408,14 +337,6 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):

          return app

-     @staticmethod
-     def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
-         return {
-             "method": "MENU",
-             "entity_type": AvpEntities.MENU,
-             "entity_id": resource_name,
-         }
-
      def _check_avp_schema_version(self):
          if not self.avp_facade.is_policy_store_schema_up_to_date():
              self.log.warning(
@@ -79,12 +79,13 @@ def login_callback(request: Request):
          username=saml_auth.get_nameid(),
          email=attributes["email"][0] if "email" in attributes else None,
      )
-     return RedirectResponse(url=f"/webapp?token={get_auth_manager().get_jwt_token(user)}", status_code=303)
+     url = f"{conf.get('api', 'base_url')}/?token={get_auth_manager().get_jwt_token(user)}"
+     return RedirectResponse(url=url, status_code=303)


  def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
      request_data = _prepare_request(request)
-     base_url = conf.get(section="fastapi", key="base_url")
+     base_url = conf.get(section="api", key="base_url")
      settings = {
          # We want to keep this flag on in case of errors.
          # It provides an error reasons, if turned off, it does not
@@ -19,11 +19,14 @@ from __future__ import annotations
  from airflow.exceptions import AirflowOptionalProviderFeatureException

  try:
-     from airflow.auth.managers.models.base_user import BaseUser
+     from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
  except ImportError:
-     raise AirflowOptionalProviderFeatureException(
-         "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
-     )
+     try:
+         from airflow.auth.managers.models.base_user import BaseUser  # type: ignore[no-redef]
+     except ImportError:
+         raise AirflowOptionalProviderFeatureException(
+             "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
+         ) from None


  class AwsAuthManagerUser(BaseUser):
@@ -30,6 +30,7 @@ import inspect
  import json
  import logging
  import os
+ import warnings
  from copy import deepcopy
  from functools import cached_property, wraps
  from pathlib import Path
@@ -41,6 +42,7 @@ import botocore.session
  import jinja2
  import requests
  import tenacity
+ from asgiref.sync import sync_to_async
  from botocore.config import Config
  from botocore.waiter import Waiter, WaiterModel
  from dateutil.tz import tzlocal
@@ -50,6 +52,7 @@ from airflow.configuration import conf
  from airflow.exceptions import (
      AirflowException,
      AirflowNotFoundException,
+     AirflowProviderDeprecationWarning,
  )
  from airflow.hooks.base import BaseHook
  from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):

      @property
      def async_conn(self):
+         """
+         [DEPRECATED] Get an aiobotocore client to use for async operations.
+
+         This property is deprecated. Accessing it in an async context will cause the event loop to block.
+         Use the async method `get_async_conn` instead.
+         """
+         warnings.warn(
+             "The property `async_conn` is deprecated. Accessing it in an async context will cause the event loop to block. "
+             "Use the async method `get_async_conn` instead.",
+             AirflowProviderDeprecationWarning,
+             stacklevel=2,
+         )
+
+         return self._get_async_conn()
+
+     async def get_async_conn(self):
          """Get an aiobotocore client to use for async operations."""
+         # We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
+         # because one of it's arguments `self.region_name` is a `@property` decorated function
+         # calling the cached property `self.conn_config` at the end.
+         return await sync_to_async(self._get_async_conn)()
+
+     def _get_async_conn(self):
          if not self.client_type:
              raise ValueError("client_type must be specified.")

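With the deprecation above, deferrable code should move from the blocking `async_conn` property to awaiting `get_async_conn()`, exactly as the hook hunks below do. A minimal sketch of the migration, assuming the default `aws_default` connection and the CloudWatch Logs client purely for illustration:

    import asyncio

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    async def list_log_groups() -> list[dict]:
        hook = AwsBaseHook(aws_conn_id="aws_default", client_type="logs")
        # Before: `async with hook.async_conn as client:` still works, but now emits
        # AirflowProviderDeprecationWarning and blocks the running event loop.
        async with await hook.get_async_conn() as client:
            response = await client.describe_log_groups()
        return response["logGroups"]

    if __name__ == "__main__":
        print(asyncio.run(list_log_groups()))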
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
          return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]

      async def get_instance_state_async(self, instance_id: str) -> str:
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response = await client.describe_instances(InstanceIds=[instance_id])
              return response["Reservations"][0]["Instances"][0]["State"]["Name"]

@@ -211,7 +211,7 @@ class GlueJobHook(AwsBaseHook):

          The async version of get_job_state.
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
              return job_run["JobRun"]["JobRunState"]

@@ -236,6 +236,9 @@ class GlueJobHook(AwsBaseHook):
          """
          log_client = self.logs_hook.get_conn()
          paginator = log_client.get_paginator("filter_log_events")
+         job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]
+         # StartTime needs to be an int and is Epoch time in milliseconds
+         start_time = int(job_run["StartedOn"].timestamp() * 1000)

          def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
              """Mutualize iteration over the 2 different log streams glue jobs write to."""
@@ -245,6 +248,7 @@ class GlueJobHook(AwsBaseHook):
              for response in paginator.paginate(
                  logGroupName=log_group,
                  logStreamNames=[run_id],
+                 startTime=start_time,
                  PaginationConfig={"StartingToken": continuation_token},
              ):
                  fetched_logs.extend([event["message"] for event in response["events"]])
@@ -270,7 +274,7 @@ class GlueJobHook(AwsBaseHook):
                  self.log.info("No new log from the Glue Job in %s", log_group)
              return next_token

-         log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
+         log_group_prefix = job_run["LogGroupName"]
          log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
          log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
          # one would think that the error log group would contain only errors, but it actually contains
@@ -152,7 +152,7 @@ class AwsLogsHook(AwsBaseHook):
          If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
          :param count: The maximum number of items returned
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              try:
                  response: dict[str, Any] = await client.describe_log_streams(
                      logGroupName=log_group,
@@ -194,7 +194,7 @@ class AwsLogsHook(AwsBaseHook):
          else:
              token_arg = {}

-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response = await client.get_log_events(
                  logGroupName=log_group,
                  logStreamName=log_stream_name,
@@ -18,6 +18,7 @@

  from __future__ import annotations

+ import requests
  from botocore.exceptions import ClientError

  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):

      Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`

+     If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
+     that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
+     make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
+     setting `generate_local_token` to True. Learn more here:
+     https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
      Additional arguments (such as ``aws_conn_id``) may be specified and
      are passed down to the underlying AwsBaseHook.

@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
          method: str,
          body: dict | None = None,
          query_params: dict | None = None,
+         generate_local_token: bool = False,
      ) -> dict:
          """
          Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):

          :param env_name: name of the MWAA environment
          :param path: Apache Airflow REST API endpoint path to be called
-         :param method: HTTP method used for making Airflow REST API calls
+         :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
          :param body: Request body for the Apache Airflow REST API call
          :param query_params: Query parameters to be included in the Apache Airflow REST API call
+         :param generate_local_token: If True, only the local web token method is used without trying boto's
+             `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
+             boto's `invoke_rest_api`
          """
-         body = body or {}
+         # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+         body = {k: v for k, v in body.items() if v is not None} if body else {}
+         query_params = query_params or {}
          api_kwargs = {
              "Name": env_name,
              "Path": path,
              "Method": method,
-             # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
-             "Body": {k: v for k, v in body.items() if v is not None},
-             "QueryParameters": query_params if query_params else {},
+             "Body": body,
+             "QueryParameters": query_params,
          }
+
+         if generate_local_token:
+             return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
          try:
-             result = self.conn.invoke_rest_api(**api_kwargs)
+             response = self.conn.invoke_rest_api(**api_kwargs)
              # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
              # in XComs and logs, or redundant given the data already included in the response
-             result.pop("ResponseMetadata", None)
-             return result
+             response.pop("ResponseMetadata", None)
+             return response
+
          except ClientError as e:
-             to_log = e.response
-             # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
-             # be useful in XComs and logs, or redundant given the data already included in the response
-             to_log.pop("ResponseMetadata", None)
-             to_log.pop("Error", None)
-             self.log.error(to_log)
-             raise e
+             if (
+                 e.response["Error"]["Code"] == "AccessDeniedException"
+                 and "Airflow role" in e.response["Error"]["Message"]
+             ):
+                 self.log.info(
+                     "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
+                 )
+                 return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+             else:
+                 to_log = e.response
+                 # ResponseMetadata is removed because it contains data that is either very unlikely to be
+                 # useful in XComs and logs, or redundant given the data already included in the response
+                 to_log.pop("ResponseMetadata", None)
+                 self.log.error(to_log)
+                 raise
+
+     def _invoke_rest_api_using_local_session_token(
+         self,
+         **api_kwargs,
+     ) -> dict:
+         try:
+             session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+             response = session.request(
+                 method=api_kwargs["Method"],
+                 url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                 params=api_kwargs["QueryParameters"],
+                 json=api_kwargs["Body"],
+                 timeout=10,
+             )
+             response.raise_for_status()
+
+         except requests.HTTPError as e:
+             self.log.error(e.response.json())
+             raise
+
+         return {
+             "RestApiStatusCode": response.status_code,
+             "RestApiResponse": response.json(),
+         }
+
+     # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+     def _get_session_conn(self, env_name: str) -> tuple:
+         create_token_response = self.conn.create_web_login_token(Name=env_name)
+         web_server_hostname = create_token_response["WebServerHostname"]
+         web_token = create_token_response["WebToken"]
+
+         login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+         login_payload = {"token": web_token}
+         session = requests.Session()
+         login_response = session.post(login_url, data=login_payload, timeout=10)
+         login_response.raise_for_status()
+
+         return session, web_server_hostname
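
For reference, a short sketch of how the new MWAA fallback surfaces to callers; the environment name, endpoint paths, and request body below are placeholders, and the returned keys match the hunk above:

    from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

    hook = MwaaHook(aws_conn_id="aws_default")

    # Tries the boto3 mwaa:InvokeRestApi call first and falls back to a
    # web-login-token session if the IAM policy denies that action.
    dags = hook.invoke_rest_api(env_name="my-mwaa-env", path="/dags", method="GET")

    # Skip the boto3 call and use the local-token session directly.
    run = hook.invoke_rest_api(
        env_name="my-mwaa-env",
        path="/dags/my_dag/dagRuns",
        method="POST",
        body={"note": None},  # None values are dropped before the request is sent
        generate_local_token=True,
    )
    print(run["RestApiStatusCode"], run["RestApiResponse"])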
@@ -93,7 +93,7 @@ class RedshiftHook(AwsBaseHook):
              return "cluster_not_found"

      async def cluster_status_async(self, cluster_identifier: str) -> str:
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
              return response["Clusters"][0]["ClusterStatus"] if response else None

@@ -275,7 +275,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):

          :param statement_id: the UUID of the statement
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              desc = await client.describe_statement(Id=statement_id)
              return desc["Status"] in RUNNING_STATES

@@ -288,6 +288,6 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):

          :param statement_id: the UUID of the statement
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              resp = await client.describe_statement(Id=statement_id)
              return self.parse_statement_response(resp)
@@ -1318,7 +1318,7 @@ class SageMakerHook(AwsBaseHook):

          :param job_name: the name of the training job
          """
-         async with self.async_conn as client:
+         async with await self.get_async_conn() as client:
              response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
              return response