apache-airflow-providers-amazon 9.5.0rc1__py3-none-any.whl → 9.5.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. airflow/providers/amazon/aws/auth_manager/avp/entities.py +2 -0
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +67 -18
  3. airflow/providers/amazon/aws/auth_manager/router/login.py +10 -4
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  5. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  7. airflow/providers/amazon/aws/hooks/base_aws.py +9 -1
  8. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  9. airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
  10. airflow/providers/amazon/aws/hooks/dms.py +3 -1
  11. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  12. airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
  13. airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
  14. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  15. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  16. airflow/providers/amazon/aws/links/athena.py +1 -2
  17. airflow/providers/amazon/aws/links/base_aws.py +2 -1
  18. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
  19. airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
  20. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  21. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  22. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  23. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  24. airflow/providers/amazon/aws/operators/eks.py +3 -3
  25. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  26. airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
  27. airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
  28. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  29. airflow/providers/amazon/aws/sensors/glacier.py +1 -1
  30. airflow/providers/amazon/aws/sensors/mwaa.py +59 -11
  31. airflow/providers/amazon/aws/sensors/s3.py +1 -1
  32. airflow/providers/amazon/aws/sensors/step_function.py +2 -1
  33. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
  34. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  35. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  36. airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
  37. airflow/providers/amazon/aws/triggers/base.py +10 -1
  38. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  39. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  40. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  41. airflow/providers/amazon/get_provider_info.py +11 -5
  42. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/METADATA +9 -7
  43. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/RECORD +45 -43
  44. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/WHEEL +0 -0
  45. {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/entry_points.txt +0 -0
@@ -34,6 +34,8 @@ class AvpEntities(Enum):
 
     # Resource types
     ASSET = "Asset"
+    ASSET_ALIAS = "AssetAlias"
+    BACKFILL = "Backfills"
     CONFIGURATION = "Configuration"
     CONNECTION = "Connection"
     CUSTOM = "Custom"
@@ -21,18 +21,12 @@ from collections import defaultdict
 from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, cast
+from urllib.parse import urljoin
 
 from fastapi import FastAPI
 
+from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
 from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
-from airflow.api_fastapi.auth.managers.models.resource_details import (
-    AccessView,
-    ConnectionDetails,
-    DagAccessEntity,
-    DagDetails,
-    PoolDetails,
-    VariableDetails,
-)
 from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
 from airflow.configuration import conf
 from airflow.exceptions import AirflowOptionalProviderFeatureException
@@ -55,7 +49,19 @@ if TYPE_CHECKING:
         IsAuthorizedPoolRequest,
         IsAuthorizedVariableRequest,
     )
-    from airflow.api_fastapi.auth.managers.models.resource_details import AssetDetails, ConfigurationDetails
+    from airflow.api_fastapi.auth.managers.models.resource_details import (
+        AccessView,
+        AssetAliasDetails,
+        AssetDetails,
+        BackfillDetails,
+        ConfigurationDetails,
+        ConnectionDetails,
+        DagAccessEntity,
+        DagDetails,
+        PoolDetails,
+        VariableDetails,
+    )
+    from airflow.api_fastapi.common.types import MenuItem
 
 
 class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
@@ -84,11 +90,11 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         return conf.get("api", "base_url")
 
     def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
-        return AwsAuthManagerUser(**token)
+        return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
 
     def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
         return {
-            "user_id": user.get_id(),
+            "sub": user.get_id(),
             "groups": user.get_groups(),
             "username": user.username,
             "email": user.email,
@@ -150,6 +156,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             context=context,
         )
 
+    def is_authorized_backfill(
+        self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
+    ) -> bool:
+        backfill_id = details.id if details else None
+        return self.avp_facade.is_authorized(
+            method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
+        )
+
     def is_authorized_asset(
         self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
     ) -> bool:
@@ -158,6 +172,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
         )
 
+    def is_authorized_asset_alias(
+        self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetAliasDetails | None = None
+    ) -> bool:
+        asset_alias_id = details.id if details else None
+        return self.avp_facade.is_authorized(
+            method=method, entity_type=AvpEntities.ASSET_ALIAS, user=user, entity_id=asset_alias_id
+        )
+
     def is_authorized_pool(
         self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: PoolDetails | None = None
     ) -> bool:
@@ -203,6 +225,25 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             entity_id=resource_name,
         )
 
+    def filter_authorized_menu_items(
+        self, menu_items: list[MenuItem], *, user: AwsAuthManagerUser
+    ) -> list[MenuItem]:
+        requests: dict[str, IsAuthorizedRequest] = {}
+        for menu_item in menu_items:
+            requests[menu_item.value] = self._get_menu_item_request(menu_item.value)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=list(requests.values()), user=user
+        )
+
+        def _has_access_to_menu_item(request: IsAuthorizedRequest):
+            result = self.avp_facade.get_batch_is_authorized_single_result(
+                batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
+            )
+            return result["decision"] == "ALLOW"
+
+        return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
+
     def batch_is_authorized_connection(
         self,
         requests: Sequence[IsAuthorizedConnectionRequest],
@@ -213,7 +254,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.CONNECTION,
-                "entity_id": cast(ConnectionDetails, request["details"]).conn_id
+                "entity_id": cast("ConnectionDetails", request["details"]).conn_id
                 if request.get("details")
                 else None,
             }
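Several hunks in this file, and the secondary_training_status_message hunk further down, make the same mechanical change: `cast(SomeType, ...)` becomes `cast("SomeType", ...)`. A small illustrative sketch of why the quoted form is needed once the type is only imported under `TYPE_CHECKING` (the helper function below is hypothetical):

```python
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Only imported for the type checker; the name does not exist at runtime.
    from airflow.api_fastapi.auth.managers.models.resource_details import ConnectionDetails


def extract_conn_id(details) -> str:
    # cast() simply returns its second argument at runtime; with a quoted type
    # name nothing has to be resolved, so no NameError is raised.
    return cast("ConnectionDetails", details).conn_id
```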
@@ -231,10 +272,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.DAG,
-                "entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None,
+                "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
                 "context": {
                     "dag_entity": {
-                        "string": cast(DagAccessEntity, request["access_entity"]).value,
+                        "string": cast("DagAccessEntity", request["access_entity"]).value,
                     },
                 }
                 if request.get("access_entity")
@@ -254,7 +295,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.POOL,
-                "entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None,
+                "entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None,
             }
             for request in requests
         ]
@@ -270,7 +311,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.VARIABLE,
-                "entity_id": cast(VariableDetails, request["details"]).key
+                "entity_id": cast("VariableDetails", request["details"]).key
                 if request.get("details")
                 else None,
             }
@@ -278,7 +319,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
 
-    def filter_permitted_dag_ids(
+    def filter_authorized_dag_ids(
         self,
         *,
         dag_ids: set[str],
@@ -309,7 +350,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
 
     def get_url_login(self, **kwargs) -> str:
-        return f"{self.apiserver_endpoint}/auth/login"
+        return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
 
     @staticmethod
     def get_cli_commands() -> list[CLICommand]:
@@ -337,6 +378,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
 
         return app
 
+    @staticmethod
+    def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
+        return {
+            "method": "MENU",
+            "entity_type": AvpEntities.MENU,
+            "entity_id": menu_item_text,
+        }
+
     def _check_avp_schema_version(self):
         if not self.avp_facade.is_policy_store_schema_up_to_date():
             self.log.warning(
@@ -25,7 +25,8 @@ from fastapi import HTTPException, Request
 from starlette import status
 from starlette.responses import RedirectResponse
 
-from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
+from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
@@ -79,8 +80,13 @@ def login_callback(request: Request):
         username=saml_auth.get_nameid(),
         email=attributes["email"][0] if "email" in attributes else None,
     )
-    url = f"{conf.get('api', 'base_url')}/?token={get_auth_manager().get_jwt_token(user)}"
-    return RedirectResponse(url=url, status_code=303)
+    url = conf.get("api", "base_url")
+    token = get_auth_manager().generate_jwt(user)
+    response = RedirectResponse(url=url, status_code=303)
+
+    secure = conf.has_option("api", "ssl_cert")
+    response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
+    return response
 
 
 def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
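The callback now hands the JWT to the browser in a cookie rather than a `?token=` query parameter, with the `Secure` flag derived from whether `[api] ssl_cert` is set. A rough Starlette illustration (the cookie name and token value are placeholders; the real name comes from `COOKIE_NAME_JWT_TOKEN`):

```python
from starlette.responses import RedirectResponse

response = RedirectResponse(url="http://localhost:8080/", status_code=303)
# secure=True marks the cookie HTTPS-only, mirroring the ssl_cert check above.
response.set_cookie("_token", "eyJhbGciOi...", secure=True)

print(response.headers["set-cookie"])
# e.g. _token=eyJhbGciOi...; Path=/; Secure; SameSite=lax
```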
@@ -93,7 +99,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
         "sp": {
             "entityId": "aws-auth-manager-saml-client",
             "assertionConsumerService": {
-                "url": f"{base_url}/auth/login_callback",
+                "url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
                 "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
             },
         },
@@ -278,7 +278,7 @@ class AwsEcsExecutor(BaseExecutor):
         if not has_exit_codes:
             return ""
         reasons = [
-            f'{container["container_arn"]} - {container["reason"]}'
+            f"{container['container_arn']} - {container['reason']}"
            for container in containers
            if "reason" in container
        ]
@@ -16,15 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, cast
-
-from mypy_boto3_appflow.type_defs import (
-    DestinationFlowConfigTypeDef,
-    SourceFlowConfigTypeDef,
-    TaskTypeDef,
-    TriggerConfigTypeDef,
-)
+from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -125,11 +117,9 @@ class AppflowHook(AwsGenericHook["AppflowClient"]):
 
         self.conn.update_flow(
             flowName=response["flowName"],
-            destinationFlowConfigList=cast(
-                Sequence[DestinationFlowConfigTypeDef], response["destinationFlowConfigList"]
-            ),
-            sourceFlowConfig=cast(SourceFlowConfigTypeDef, response["sourceFlowConfig"]),
-            triggerConfig=cast(TriggerConfigTypeDef, response["triggerConfig"]),
+            destinationFlowConfigList=response["destinationFlowConfigList"],
+            sourceFlowConfig=response["sourceFlowConfig"],
+            triggerConfig=response["triggerConfig"],
             description=response.get("description", "Flow description."),
-            tasks=cast(Sequence[TaskTypeDef], tasks),
+            tasks=tasks,
         )
@@ -146,10 +146,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
         creds = self.get_credentials(region_name=conn_params["region_name"])
 
         return URL.create(
-            f'awsathena+{conn_params["driver"]}',
+            f"awsathena+{conn_params['driver']}",
             username=creds.access_key,
             password=creds.secret_key,
-            host=f'athena.{conn_params["region_name"]}.{conn_params["aws_domain"]}',
+            host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
             port=443,
             database=conn_params["schema_name"],
             query={"aws_session_token": creds.token, **self.conn.extra_dejson},
@@ -943,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self,
         waiter_name: str,
         parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
         deferrable: bool = False,
         client=None,
     ) -> Waiter:
@@ -962,6 +963,9 @@
         :param parameters: will scan the waiter config for the keys of that dict,
             and replace them with the corresponding value. If a custom waiter has
             such keys to be expanded, they need to be provided here.
+            Note: cannot be used if parameters are included in config_overrides
+        :param config_overrides: will update values of provided keys in the waiter's
+            config. Only specified keys will be updated.
         :param deferrable: If True, the waiter is going to be an async custom waiter.
             An async client must be provided in that case.
         :param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@
 
         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable waiter.")
+        if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
+            raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
         # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
         client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could be None.
             with open(self.waiter_path) as config_file:
-                config = json.loads(config_file.read())
+                config: dict = json.loads(config_file.read())
 
+            if config_overrides is not None:
+                config["waiters"][waiter_name].update(config_overrides)
             config = self._apply_parameters_value(config, waiter_name, parameters)
             return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
                 waiter_name
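A hedged usage sketch of the new parameter (the hook class and waiter name below are examples only, not taken from this diff): `config_overrides` patches top-level keys of the named custom waiter's config, such as `delay` or `maxAttempts`, while `parameters` still performs placeholder substitution and must stay `None` whenever the overrides replace `acceptors`.

```python
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# "mwaa_dag_run_complete" is a hypothetical custom waiter name; only the listed
# keys of that waiter's JSON config are replaced, everything else stays as shipped.
waiter = hook.get_waiter(
    "mwaa_dag_run_complete",
    config_overrides={"delay": 5, "maxAttempts": 12},
)
# waiter.wait(**kwargs) would then poll with the overridden delay/attempt budget.
```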
@@ -416,8 +416,7 @@ class BatchClientHook(AwsBaseHook):
                 )
         else:
             raise AirflowException(
-                f"AWS Batch job ({job_id}) description error: exceeded status_retries "
-                f"({self.status_retries})"
+                f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
             )
 
     @staticmethod
@@ -30,7 +30,7 @@ import json
 import sys
 from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import botocore.client
 import botocore.exceptions
@@ -144,7 +144,12 @@ class BatchWaitersHook(BatchClientHook):
         return self._waiter_model
 
     def get_waiter(
-        self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None
+        self,
+        waiter_name: str,
+        parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
+        deferrable: bool = False,
+        client=None,
     ) -> botocore.waiter.Waiter:
         """
         Get an AWS Batch service waiter, using the configured ``.waiter_model``.
@@ -175,7 +180,10 @@
             the name (including the casing) of the key name in the waiter
             model file (typically this is CamelCasing); see ``.list_waiters``.
 
-        :param _: unused, just here to match the method signature in base_aws
+        :param parameters: unused, just here to match the method signature in base_aws
+        :param config_overrides: unused, just here to match the method signature in base_aws
+        :param deferrable: unused, just here to match the method signature in base_aws
+        :param client: unused, just here to match the method signature in base_aws
 
         :return: a waiter object for the named AWS Batch service
         """
@@ -292,7 +292,9 @@ class DmsHook(AwsBaseHook):
             return arn
 
         except ClientError as err:
-            err_str = f"Error: {err.get('Error','').get('Code','')}: {err.get('Error','').get('Message','')}"
+            err_str = (
+                f"Error: {err.get('Error', '').get('Code', '')}: {err.get('Error', '').get('Message', '')}"
+            )
             self.log.error("Error while creating replication config: %s", err_str)
             raise err
 
@@ -35,7 +35,6 @@ from botocore.signers import RequestSigner
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sts import StsHook
 from airflow.utils import yaml
-from airflow.utils.json import AirflowJsonEncoder
 
 DEFAULT_PAGINATION_TOKEN = ""
 STS_TOKEN_EXPIRES_IN = 60
@@ -315,7 +314,7 @@ class EksHook(AwsBaseHook):
         )
         if verbose:
             cluster_data = response.get("cluster")
-            self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder))
+            self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, default=repr))
         return response
 
     def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
@@ -343,7 +342,7 @@
             nodegroup_data = response.get("nodegroup")
             self.log.info(
                 "Amazon EKS managed node group details: %s",
-                json.dumps(nodegroup_data, cls=AirflowJsonEncoder),
+                json.dumps(nodegroup_data, default=repr),
             )
         return response
 
@@ -374,9 +373,7 @@
         )
         if verbose:
             fargate_profile_data = response.get("fargateProfile")
-            self.log.info(
-                "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
-            )
+            self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, default=repr))
         return response
 
     def get_cluster_state(self, clusterName: str) -> ClusterStates:
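The three EKS logging hunks above swap `AirflowJsonEncoder` for `json.dumps(..., default=repr)`; a quick illustration of what that fallback does with values the stock encoder can't handle, such as the datetimes boto3 returns:

```python
import json
from datetime import datetime, timezone

cluster_data = {"name": "demo", "createdAt": datetime(2025, 1, 1, tzinfo=timezone.utc)}

# Anything json can't encode natively is passed through repr() instead of raising TypeError.
print(json.dumps(cluster_data, default=repr))
# {"name": "demo", "createdAt": "datetime.datetime(2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)"}
```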
@@ -67,7 +67,7 @@ class RedshiftHook(AwsBaseHook):
             for the cluster that is being created.
         :param params: Remaining AWS Create cluster API params.
         """
-        response = self.get_conn().create_cluster(
+        response = self.conn.create_cluster(
             ClusterIdentifier=cluster_identifier,
             NodeType=node_type,
             MasterUsername=master_username,
@@ -87,9 +87,9 @@
         :param cluster_identifier: unique identifier of a cluster
         """
         try:
-            response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
+            response = self.conn.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
             return response[0]["ClusterStatus"] if response else None
-        except self.get_conn().exceptions.ClusterNotFoundFault:
+        except self.conn.exceptions.ClusterNotFoundFault:
             return "cluster_not_found"
 
     async def cluster_status_async(self, cluster_identifier: str) -> str:
@@ -115,7 +115,7 @@
         """
         final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""
 
-        response = self.get_conn().delete_cluster(
+        response = self.conn.delete_cluster(
             ClusterIdentifier=cluster_identifier,
             SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
             FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
@@ -131,7 +131,7 @@
 
         :param cluster_identifier: unique identifier of a cluster
         """
-        response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
+        response = self.conn.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
         if "Snapshots" not in response:
             return None
         snapshots = response["Snapshots"]
@@ -149,7 +149,7 @@
         :param cluster_identifier: unique identifier of a cluster
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         """
-        response = self.get_conn().restore_from_cluster_snapshot(
+        response = self.conn.restore_from_cluster_snapshot(
             ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
         )
         return response["Cluster"] if response["Cluster"] else None
@@ -175,7 +175,7 @@
         """
         if tags is None:
             tags = []
-        response = self.get_conn().create_cluster_snapshot(
+        response = self.conn.create_cluster_snapshot(
             SnapshotIdentifier=snapshot_identifier,
             ClusterIdentifier=cluster_identifier,
             ManualSnapshotRetentionPeriod=retention_period,
@@ -192,11 +192,11 @@
         :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
         """
         try:
-            response = self.get_conn().describe_cluster_snapshots(
+            response = self.conn.describe_cluster_snapshots(
                 SnapshotIdentifier=snapshot_identifier,
             )
             snapshot = response.get("Snapshots")[0]
             snapshot_status: str = snapshot.get("Status")
             return snapshot_status
-        except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
+        except self.conn.exceptions.ClusterSnapshotNotFoundFault:
             return None
@@ -186,8 +186,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
             )
             raise exception_cls(
-                f"Statement {resp['Id']} terminated with status {status}. "
-                f"Response details: {pformat(resp)}"
+                f"Statement {resp['Id']} terminated with status {status}. Response details: {pformat(resp)}"
             )
 
         self.log.info("Query status: %s", status)
@@ -1494,7 +1494,9 @@ class S3Hook(AwsBaseHook):
             get_hook_lineage_collector().add_output_asset(
                 context=self,
                 scheme="file",
-                asset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+                asset_kwargs={
+                    "path": str(file_path) if file_path.is_absolute() else str(file_path.absolute())
+                },
             )
             file = open(file_path, "wb")
         else:
@@ -131,7 +131,7 @@ def secondary_training_status_message(
     status_strs = []
     for transition in transitions_to_print:
         message = transition["StatusMessage"]
-        time_utc = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"]))
+        time_utc = timezone.convert_to_utc(cast("datetime", job_description["LastModifiedTime"]))
         status_strs.append(f"{time_utc:%Y-%m-%d %H:%M:%S} {transition['Status']} - {message}")
 
     return "\n".join(status_strs)
@@ -25,6 +25,5 @@ class AthenaQueryResultsLink(BaseAwsLink):
     name = "Query Results"
     key = "_athena_query_results"
     format_str = (
-        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#"
-        "/query-editor/history/{query_execution_id}"
+        BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#/query-editor/history/{query_execution_id}"
     )
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, ClassVar
 
-from airflow.models import XCom
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
@@ -30,7 +29,9 @@ if TYPE_CHECKING:
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk import BaseOperatorLink
+    from airflow.sdk.execution_time.xcom import XCom
 else:
+    from airflow.models import XCom  # type: ignore[no-redef]
     from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
 