apache-airflow-providers-amazon 9.4.0__py3-none-any.whl → 9.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (69)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  8. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  9. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  10. airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
  11. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  12. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  13. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  14. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  15. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  16. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  17. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  18. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  19. airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
  20. airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
  21. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  22. airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
  23. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  24. airflow/providers/amazon/aws/links/athena.py +1 -2
  25. airflow/providers/amazon/aws/links/base_aws.py +8 -1
  26. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  27. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  28. airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
  29. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  30. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  31. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  32. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  33. airflow/providers/amazon/aws/operators/eks.py +3 -3
  34. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  35. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  36. airflow/providers/amazon/aws/operators/s3.py +147 -157
  37. airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
  38. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  39. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  40. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  41. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  42. airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
  43. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  44. airflow/providers/amazon/aws/sensors/s3.py +32 -43
  45. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  46. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  47. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  48. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  49. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  50. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  51. airflow/providers/amazon/aws/triggers/README.md +4 -4
  52. airflow/providers/amazon/aws/triggers/base.py +11 -2
  53. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  54. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  55. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  56. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  57. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  58. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  59. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  60. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  61. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  62. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  63. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  64. airflow/providers/amazon/get_provider_info.py +46 -5
  65. {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +38 -31
  66. {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
  67. {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
  68. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  69. {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
29
29
 
30
30
  __all__ = ["__version__"]
31
31
 
32
- __version__ = "9.4.0"
32
+ __version__ = "9.5.0"
33
33
 
34
34
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
35
35
  "2.9.0"
@@ -20,7 +20,7 @@ from enum import Enum
20
20
  from typing import TYPE_CHECKING
21
21
 
22
22
  if TYPE_CHECKING:
23
- from airflow.auth.managers.base_auth_manager import ResourceMethod
23
+ from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
24
24
 
25
25
  AVP_PREFIX_ENTITIES = "Airflow::"
26
26
 
@@ -34,6 +34,8 @@ class AvpEntities(Enum):
34
34
 
35
35
  # Resource types
36
36
  ASSET = "Asset"
37
+ ASSET_ALIAS = "AssetAlias"
38
+ BACKFILL = "Backfills"
37
39
  CONFIGURATION = "Configuration"
38
40
  CONNECTION = "Connection"
39
41
  CUSTOM = "Custom"
@@ -36,7 +36,7 @@ from airflow.utils.helpers import prune_dict
36
36
  from airflow.utils.log.logging_mixin import LoggingMixin
37
37
 
38
38
  if TYPE_CHECKING:
39
- from airflow.auth.managers.base_auth_manager import ResourceMethod
39
+ from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
40
40
  from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
41
41
 
42
42
 
@@ -18,22 +18,15 @@ from __future__ import annotations
18
18
 
19
19
  import argparse
20
20
  from collections import defaultdict
21
- from collections.abc import Container, Sequence
21
+ from collections.abc import Sequence
22
22
  from functools import cached_property
23
23
  from typing import TYPE_CHECKING, Any, cast
24
+ from urllib.parse import urljoin
24
25
 
25
26
  from fastapi import FastAPI
26
- from flask import session
27
-
28
- from airflow.auth.managers.base_auth_manager import BaseAuthManager
29
- from airflow.auth.managers.models.resource_details import (
30
- AccessView,
31
- ConnectionDetails,
32
- DagAccessEntity,
33
- DagDetails,
34
- PoolDetails,
35
- VariableDetails,
36
- )
27
+
28
+ from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
29
+ from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
37
30
  from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
38
31
  from airflow.configuration import conf
39
32
  from airflow.exceptions import AirflowOptionalProviderFeatureException
@@ -49,16 +42,26 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
49
42
  from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
50
43
 
51
44
  if TYPE_CHECKING:
52
- from flask_appbuilder.menu import MenuItem
53
-
54
- from airflow.auth.managers.base_auth_manager import ResourceMethod
55
- from airflow.auth.managers.models.batch_apis import (
45
+ from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
46
+ from airflow.api_fastapi.auth.managers.models.batch_apis import (
56
47
  IsAuthorizedConnectionRequest,
57
48
  IsAuthorizedDagRequest,
58
49
  IsAuthorizedPoolRequest,
59
50
  IsAuthorizedVariableRequest,
60
51
  )
61
- from airflow.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
52
+ from airflow.api_fastapi.auth.managers.models.resource_details import (
53
+ AccessView,
54
+ AssetAliasDetails,
55
+ AssetDetails,
56
+ BackfillDetails,
57
+ ConfigurationDetails,
58
+ ConnectionDetails,
59
+ DagAccessEntity,
60
+ DagDetails,
61
+ PoolDetails,
62
+ VariableDetails,
63
+ )
64
+ from airflow.api_fastapi.common.types import MenuItem
62
65
 
63
66
 
64
67
  class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
@@ -83,21 +86,15 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
83
86
  return AwsAuthManagerAmazonVerifiedPermissionsFacade()
84
87
 
85
88
  @cached_property
86
- def fastapi_endpoint(self) -> str:
87
- return conf.get("fastapi", "base_url")
88
-
89
- def get_user(self) -> AwsAuthManagerUser | None:
90
- return session["aws_user"] if self.is_logged_in() else None
91
-
92
- def is_logged_in(self) -> bool:
93
- return "aws_user" in session
89
+ def apiserver_endpoint(self) -> str:
90
+ return conf.get("api", "base_url")
94
91
 
95
92
  def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
96
- return AwsAuthManagerUser(**token)
93
+ return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
97
94
 
98
95
  def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
99
96
  return {
100
- "user_id": user.get_id(),
97
+ "sub": user.get_id(),
101
98
  "groups": user.get_groups(),
102
99
  "username": user.username,
103
100
  "email": user.email,
@@ -159,12 +156,28 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
159
156
  context=context,
160
157
  )
161
158
 
159
+ def is_authorized_backfill(
160
+ self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
161
+ ) -> bool:
162
+ backfill_id = details.id if details else None
163
+ return self.avp_facade.is_authorized(
164
+ method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
165
+ )
166
+
162
167
  def is_authorized_asset(
163
168
  self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
164
169
  ) -> bool:
165
- asset_uri = details.uri if details else None
170
+ asset_id = details.id if details else None
171
+ return self.avp_facade.is_authorized(
172
+ method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
173
+ )
174
+
175
+ def is_authorized_asset_alias(
176
+ self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetAliasDetails | None = None
177
+ ) -> bool:
178
+ asset_alias_id = details.id if details else None
166
179
  return self.avp_facade.is_authorized(
167
- method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_uri
180
+ method=method, entity_type=AvpEntities.ASSET_ALIAS, user=user, entity_id=asset_alias_id
168
181
  )
169
182
 
170
183
  def is_authorized_pool(
@@ -204,7 +217,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
204
217
 
205
218
  def is_authorized_custom_view(
206
219
  self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser
207
- ):
220
+ ) -> bool:
208
221
  return self.avp_facade.is_authorized(
209
222
  method=method,
210
223
  entity_type=AvpEntities.CUSTOM,
@@ -212,6 +225,25 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
212
225
  entity_id=resource_name,
213
226
  )
214
227
 
228
+ def filter_authorized_menu_items(
229
+ self, menu_items: list[MenuItem], *, user: AwsAuthManagerUser
230
+ ) -> list[MenuItem]:
231
+ requests: dict[str, IsAuthorizedRequest] = {}
232
+ for menu_item in menu_items:
233
+ requests[menu_item.value] = self._get_menu_item_request(menu_item.value)
234
+
235
+ batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
236
+ requests=list(requests.values()), user=user
237
+ )
238
+
239
+ def _has_access_to_menu_item(request: IsAuthorizedRequest):
240
+ result = self.avp_facade.get_batch_is_authorized_single_result(
241
+ batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
242
+ )
243
+ return result["decision"] == "ALLOW"
244
+
245
+ return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
246
+
215
247
  def batch_is_authorized_connection(
216
248
  self,
217
249
  requests: Sequence[IsAuthorizedConnectionRequest],
@@ -222,7 +254,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
222
254
  {
223
255
  "method": request["method"],
224
256
  "entity_type": AvpEntities.CONNECTION,
225
- "entity_id": cast(ConnectionDetails, request["details"]).conn_id
257
+ "entity_id": cast("ConnectionDetails", request["details"]).conn_id
226
258
  if request.get("details")
227
259
  else None,
228
260
  }
@@ -240,10 +272,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
240
272
  {
241
273
  "method": request["method"],
242
274
  "entity_type": AvpEntities.DAG,
243
- "entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None,
275
+ "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
244
276
  "context": {
245
277
  "dag_entity": {
246
- "string": cast(DagAccessEntity, request["access_entity"]).value,
278
+ "string": cast("DagAccessEntity", request["access_entity"]).value,
247
279
  },
248
280
  }
249
281
  if request.get("access_entity")
@@ -263,7 +295,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
263
295
  {
264
296
  "method": request["method"],
265
297
  "entity_type": AvpEntities.POOL,
266
- "entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None,
298
+ "entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None,
267
299
  }
268
300
  for request in requests
269
301
  ]
@@ -279,7 +311,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
279
311
  {
280
312
  "method": request["method"],
281
313
  "entity_type": AvpEntities.VARIABLE,
282
- "entity_id": cast(VariableDetails, request["details"]).key
314
+ "entity_id": cast("VariableDetails", request["details"]).key
283
315
  if request.get("details")
284
316
  else None,
285
317
  }
@@ -287,28 +319,23 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
287
319
  ]
288
320
  return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
289
321
 
290
- def filter_permitted_dag_ids(
322
+ def filter_authorized_dag_ids(
291
323
  self,
292
324
  *,
293
325
  dag_ids: set[str],
294
326
  user: AwsAuthManagerUser,
295
- methods: Container[ResourceMethod] | None = None,
327
+ method: ResourceMethod = "GET",
296
328
  ):
297
- if not methods:
298
- methods = ["PUT", "GET"]
299
-
300
329
  requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
301
330
  requests_list: list[IsAuthorizedRequest] = []
302
331
  for dag_id in dag_ids:
303
- for method in ["GET", "PUT"]:
304
- if method in methods:
305
- request: IsAuthorizedRequest = {
306
- "method": cast("ResourceMethod", method),
307
- "entity_type": AvpEntities.DAG,
308
- "entity_id": dag_id,
309
- }
310
- requests[dag_id][cast("ResourceMethod", method)] = request
311
- requests_list.append(request)
332
+ request: IsAuthorizedRequest = {
333
+ "method": method,
334
+ "entity_type": AvpEntities.DAG,
335
+ "entity_id": dag_id,
336
+ }
337
+ requests[dag_id][method] = request
338
+ requests_list.append(request)
312
339
 
313
340
  batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
314
341
  requests=requests_list, user=user
@@ -320,67 +347,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
320
347
  )
321
348
  return result["decision"] == "ALLOW"
322
349
 
323
- return {
324
- dag_id
325
- for dag_id in dag_ids
326
- if (
327
- "GET" in methods
328
- and _has_access_to_dag(requests[dag_id]["GET"])
329
- or "PUT" in methods
330
- and _has_access_to_dag(requests[dag_id]["PUT"])
331
- )
332
- }
333
-
334
- def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
335
- """
336
- Filter menu items based on user permissions.
337
-
338
- :param menu_items: list of all menu items
339
- """
340
- user = self.get_user()
341
- if not user:
342
- return []
343
-
344
- requests: dict[str, IsAuthorizedRequest] = {}
345
- for menu_item in menu_items:
346
- if menu_item.childs:
347
- for child in menu_item.childs:
348
- requests[child.name] = self._get_menu_item_request(child.name)
349
- else:
350
- requests[menu_item.name] = self._get_menu_item_request(menu_item.name)
351
-
352
- batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
353
- requests=list(requests.values()), user=user
354
- )
355
-
356
- def _has_access_to_menu_item(request: IsAuthorizedRequest):
357
- result = self.avp_facade.get_batch_is_authorized_single_result(
358
- batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
359
- )
360
- return result["decision"] == "ALLOW"
361
-
362
- accessible_items = []
363
- for menu_item in menu_items:
364
- if menu_item.childs:
365
- accessible_children = []
366
- for child in menu_item.childs:
367
- if _has_access_to_menu_item(requests[child.name]):
368
- accessible_children.append(child)
369
- menu_item.childs = accessible_children
370
-
371
- # Display the menu if the user has access to at least one sub item
372
- if len(accessible_children) > 0:
373
- accessible_items.append(menu_item)
374
- elif _has_access_to_menu_item(requests[menu_item.name]):
375
- accessible_items.append(menu_item)
376
-
377
- return accessible_items
350
+ return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
378
351
 
379
352
  def get_url_login(self, **kwargs) -> str:
380
- return f"{self.fastapi_endpoint}/auth/login"
381
-
382
- def get_url_logout(self) -> str:
383
- raise NotImplementedError()
353
+ return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
384
354
 
385
355
  @staticmethod
386
356
  def get_cli_commands() -> list[CLICommand]:
@@ -409,11 +379,11 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
409
379
  return app
410
380
 
411
381
  @staticmethod
412
- def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
382
+ def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
413
383
  return {
414
384
  "method": "MENU",
415
385
  "entity_type": AvpEntities.MENU,
416
- "entity_id": resource_name,
386
+ "entity_id": menu_item_text,
417
387
  }
418
388
 
419
389
  def _check_avp_schema_version(self):
@@ -25,7 +25,8 @@ from fastapi import HTTPException, Request
25
25
  from starlette import status
26
26
  from starlette.responses import RedirectResponse
27
27
 
28
- from airflow.api_fastapi.app import get_auth_manager
28
+ from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
29
+ from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
29
30
  from airflow.api_fastapi.common.router import AirflowRouter
30
31
  from airflow.configuration import conf
31
32
  from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
@@ -79,12 +80,18 @@ def login_callback(request: Request):
79
80
  username=saml_auth.get_nameid(),
80
81
  email=attributes["email"][0] if "email" in attributes else None,
81
82
  )
82
- return RedirectResponse(url=f"/webapp?token={get_auth_manager().get_jwt_token(user)}", status_code=303)
83
+ url = conf.get("api", "base_url")
84
+ token = get_auth_manager().generate_jwt(user)
85
+ response = RedirectResponse(url=url, status_code=303)
86
+
87
+ secure = conf.has_option("api", "ssl_cert")
88
+ response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
89
+ return response
83
90
 
84
91
 
85
92
  def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
86
93
  request_data = _prepare_request(request)
87
- base_url = conf.get(section="fastapi", key="base_url")
94
+ base_url = conf.get(section="api", key="base_url")
88
95
  settings = {
89
96
  # We want to keep this flag on in case of errors.
90
97
  # It provides an error reasons, if turned off, it does not
@@ -92,7 +99,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
92
99
  "sp": {
93
100
  "entityId": "aws-auth-manager-saml-client",
94
101
  "assertionConsumerService": {
95
- "url": f"{base_url}/auth/login_callback",
102
+ "url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
96
103
  "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
97
104
  },
98
105
  },
@@ -19,11 +19,14 @@ from __future__ import annotations
19
19
  from airflow.exceptions import AirflowOptionalProviderFeatureException
20
20
 
21
21
  try:
22
- from airflow.auth.managers.models.base_user import BaseUser
22
+ from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
23
23
  except ImportError:
24
- raise AirflowOptionalProviderFeatureException(
25
- "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
26
- )
24
+ try:
25
+ from airflow.auth.managers.models.base_user import BaseUser # type: ignore[no-redef]
26
+ except ImportError:
27
+ raise AirflowOptionalProviderFeatureException(
28
+ "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
29
+ ) from None
27
30
 
28
31
 
29
32
  class AwsAuthManagerUser(BaseUser):
@@ -278,7 +278,7 @@ class AwsEcsExecutor(BaseExecutor):
278
278
  if not has_exit_codes:
279
279
  return ""
280
280
  reasons = [
281
- f'{container["container_arn"]} - {container["reason"]}'
281
+ f"{container['container_arn']} - {container['reason']}"
282
282
  for container in containers
283
283
  if "reason" in container
284
284
  ]
@@ -16,15 +16,7 @@
16
16
  # under the License.
17
17
  from __future__ import annotations
18
18
 
19
- from collections.abc import Sequence
20
- from typing import TYPE_CHECKING, cast
21
-
22
- from mypy_boto3_appflow.type_defs import (
23
- DestinationFlowConfigTypeDef,
24
- SourceFlowConfigTypeDef,
25
- TaskTypeDef,
26
- TriggerConfigTypeDef,
27
- )
19
+ from typing import TYPE_CHECKING
28
20
 
29
21
  from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
30
22
  from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -125,11 +117,9 @@ class AppflowHook(AwsGenericHook["AppflowClient"]):
125
117
 
126
118
  self.conn.update_flow(
127
119
  flowName=response["flowName"],
128
- destinationFlowConfigList=cast(
129
- Sequence[DestinationFlowConfigTypeDef], response["destinationFlowConfigList"]
130
- ),
131
- sourceFlowConfig=cast(SourceFlowConfigTypeDef, response["sourceFlowConfig"]),
132
- triggerConfig=cast(TriggerConfigTypeDef, response["triggerConfig"]),
120
+ destinationFlowConfigList=response["destinationFlowConfigList"],
121
+ sourceFlowConfig=response["sourceFlowConfig"],
122
+ triggerConfig=response["triggerConfig"],
133
123
  description=response.get("description", "Flow description."),
134
- tasks=cast(Sequence[TaskTypeDef], tasks),
124
+ tasks=tasks,
135
125
  )
@@ -146,10 +146,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
146
146
  creds = self.get_credentials(region_name=conn_params["region_name"])
147
147
 
148
148
  return URL.create(
149
- f'awsathena+{conn_params["driver"]}',
149
+ f"awsathena+{conn_params['driver']}",
150
150
  username=creds.access_key,
151
151
  password=creds.secret_key,
152
- host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
152
+ host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
153
153
  port=443,
154
154
  database=conn_params["schema_name"],
155
155
  query={"aws_session_token": creds.token, **self.conn.extra_dejson},
@@ -30,6 +30,7 @@ import inspect
30
30
  import json
31
31
  import logging
32
32
  import os
33
+ import warnings
33
34
  from copy import deepcopy
34
35
  from functools import cached_property, wraps
35
36
  from pathlib import Path
@@ -41,6 +42,7 @@ import botocore.session
41
42
  import jinja2
42
43
  import requests
43
44
  import tenacity
45
+ from asgiref.sync import sync_to_async
44
46
  from botocore.config import Config
45
47
  from botocore.waiter import Waiter, WaiterModel
46
48
  from dateutil.tz import tzlocal
@@ -50,6 +52,7 @@ from airflow.configuration import conf
50
52
  from airflow.exceptions import (
51
53
  AirflowException,
52
54
  AirflowNotFoundException,
55
+ AirflowProviderDeprecationWarning,
53
56
  )
54
57
  from airflow.hooks.base import BaseHook
55
58
  from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
747
750
 
748
751
  @property
749
752
  def async_conn(self):
753
+ """
754
+ [DEPRECATED] Get an aiobotocore client to use for async operations.
755
+
756
+ This property is deprecated. Accessing it in an async context will cause the event loop to block.
757
+ Use the async method `get_async_conn` instead.
758
+ """
759
+ warnings.warn(
760
+ "The property `async_conn` is deprecated. Accessing it in an async context will cause the event loop to block. "
761
+ "Use the async method `get_async_conn` instead.",
762
+ AirflowProviderDeprecationWarning,
763
+ stacklevel=2,
764
+ )
765
+
766
+ return self._get_async_conn()
767
+
768
+ async def get_async_conn(self):
750
769
  """Get an aiobotocore client to use for async operations."""
770
+ # We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
771
+ # because one of it's arguments `self.region_name` is a `@property` decorated function
772
+ # calling the cached property `self.conn_config` at the end.
773
+ return await sync_to_async(self._get_async_conn)()
774
+
775
+ def _get_async_conn(self):
751
776
  if not self.client_type:
752
777
  raise ValueError("client_type must be specified.")
753
778
 
@@ -918,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
918
943
  self,
919
944
  waiter_name: str,
920
945
  parameters: dict[str, str] | None = None,
946
+ config_overrides: dict[str, Any] | None = None,
921
947
  deferrable: bool = False,
922
948
  client=None,
923
949
  ) -> Waiter:
@@ -937,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
937
963
  :param parameters: will scan the waiter config for the keys of that dict,
938
964
  and replace them with the corresponding value. If a custom waiter has
939
965
  such keys to be expanded, they need to be provided here.
966
+ Note: cannot be used if parameters are included in config_overrides
967
+ :param config_overrides: will update values of provided keys in the waiter's
968
+ config. Only specified keys will be updated.
940
969
  :param deferrable: If True, the waiter is going to be an async custom waiter.
941
970
  An async client must be provided in that case.
942
971
  :param client: The client to use for the waiter's operations
@@ -945,14 +974,18 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
945
974
 
946
975
  if deferrable and not client:
947
976
  raise ValueError("client must be provided for a deferrable waiter.")
977
+ if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
978
+ raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
948
979
  # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
949
980
  client = client or self._client
950
981
  if self.waiter_path and (waiter_name in self._list_custom_waiters()):
951
982
  # Technically if waiter_name is in custom_waiters then self.waiter_path must
952
983
  # exist but MyPy doesn't like the fact that self.waiter_path could be None.
953
984
  with open(self.waiter_path) as config_file:
954
- config = json.loads(config_file.read())
985
+ config: dict = json.loads(config_file.read())
955
986
 
987
+ if config_overrides is not None:
988
+ config["waiters"][waiter_name].update(config_overrides)
956
989
  config = self._apply_parameters_value(config, waiter_name, parameters)
957
990
  return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
958
991
  waiter_name
@@ -416,8 +416,7 @@ class BatchClientHook(AwsBaseHook):
416
416
  )
417
417
  else:
418
418
  raise AirflowException(
419
- f"AWS Batch job ({job_id}) description error: exceeded status_retries "
420
- f"({self.status_retries})"
419
+ f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
421
420
  )
422
421
 
423
422
  @staticmethod
@@ -30,7 +30,7 @@ import json
30
30
  import sys
31
31
  from copy import deepcopy
32
32
  from pathlib import Path
33
- from typing import TYPE_CHECKING, Callable
33
+ from typing import TYPE_CHECKING, Any, Callable
34
34
 
35
35
  import botocore.client
36
36
  import botocore.exceptions
@@ -144,7 +144,12 @@ class BatchWaitersHook(BatchClientHook):
144
144
  return self._waiter_model
145
145
 
146
146
  def get_waiter(
147
- self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None
147
+ self,
148
+ waiter_name: str,
149
+ parameters: dict[str, str] | None = None,
150
+ config_overrides: dict[str, Any] | None = None,
151
+ deferrable: bool = False,
152
+ client=None,
148
153
  ) -> botocore.waiter.Waiter:
149
154
  """
150
155
  Get an AWS Batch service waiter, using the configured ``.waiter_model``.
@@ -175,7 +180,10 @@ class BatchWaitersHook(BatchClientHook):
175
180
  the name (including the casing) of the key name in the waiter
176
181
  model file (typically this is CamelCasing); see ``.list_waiters``.
177
182
 
178
- :param _: unused, just here to match the method signature in base_aws
183
+ :param parameters: unused, just here to match the method signature in base_aws
184
+ :param config_overrides: unused, just here to match the method signature in base_aws
185
+ :param deferrable: unused, just here to match the method signature in base_aws
186
+ :param client: unused, just here to match the method signature in base_aws
179
187
 
180
188
  :return: a waiter object for the named AWS Batch service
181
189
  """
@@ -292,7 +292,9 @@ class DmsHook(AwsBaseHook):
292
292
  return arn
293
293
 
294
294
  except ClientError as err:
295
- err_str = f"Error: {err.get('Error','').get('Code','')}: {err.get('Error','').get('Message','')}"
295
+ err_str = (
296
+ f"Error: {err.get('Error', '').get('Code', '')}: {err.get('Error', '').get('Message', '')}"
297
+ )
296
298
  self.log.error("Error while creating replication config: %s", err_str)
297
299
  raise err
298
300
 
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
173
173
  return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]
174
174
 
175
175
  async def get_instance_state_async(self, instance_id: str) -> str:
176
- async with self.async_conn as client:
176
+ async with await self.get_async_conn() as client:
177
177
  response = await client.describe_instances(InstanceIds=[instance_id])
178
178
  return response["Reservations"][0]["Instances"][0]["State"]["Name"]
179
179
 
@@ -35,7 +35,6 @@ from botocore.signers import RequestSigner
35
35
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
36
36
  from airflow.providers.amazon.aws.hooks.sts import StsHook
37
37
  from airflow.utils import yaml
38
- from airflow.utils.json import AirflowJsonEncoder
39
38
 
40
39
  DEFAULT_PAGINATION_TOKEN = ""
41
40
  STS_TOKEN_EXPIRES_IN = 60
@@ -315,7 +314,7 @@ class EksHook(AwsBaseHook):
315
314
  )
316
315
  if verbose:
317
316
  cluster_data = response.get("cluster")
318
- self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder))
317
+ self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, default=repr))
319
318
  return response
320
319
 
321
320
  def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
@@ -343,7 +342,7 @@ class EksHook(AwsBaseHook):
343
342
  nodegroup_data = response.get("nodegroup")
344
343
  self.log.info(
345
344
  "Amazon EKS managed node group details: %s",
346
- json.dumps(nodegroup_data, cls=AirflowJsonEncoder),
345
+ json.dumps(nodegroup_data, default=repr),
347
346
  )
348
347
  return response
349
348
 
@@ -374,9 +373,7 @@ class EksHook(AwsBaseHook):
374
373
  )
375
374
  if verbose:
376
375
  fargate_profile_data = response.get("fargateProfile")
377
- self.log.info(
378
- "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
379
- )
376
+ self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, default=repr))
380
377
  return response
381
378
 
382
379
  def get_cluster_state(self, clusterName: str) -> ClusterStates: