apache-airflow-providers-amazon 9.5.0rc1__py3-none-any.whl → 9.5.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +2 -0
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +67 -18
- airflow/providers/amazon/aws/auth_manager/router/login.py +10 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +9 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +9 -9
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -2
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +2 -1
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +123 -86
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +59 -11
- airflow/providers/amazon/aws/sensors/s3.py +1 -1
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/base.py +10 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +11 -5
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/METADATA +9 -7
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/RECORD +45 -43
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.5.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc3.dist-info}/entry_points.txt +0 -0
@@ -21,18 +21,12 @@ from collections import defaultdict
|
|
21
21
|
from collections.abc import Sequence
|
22
22
|
from functools import cached_property
|
23
23
|
from typing import TYPE_CHECKING, Any, cast
|
24
|
+
from urllib.parse import urljoin
|
24
25
|
|
25
26
|
from fastapi import FastAPI
|
26
27
|
|
28
|
+
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
|
27
29
|
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
|
28
|
-
from airflow.api_fastapi.auth.managers.models.resource_details import (
|
29
|
-
AccessView,
|
30
|
-
ConnectionDetails,
|
31
|
-
DagAccessEntity,
|
32
|
-
DagDetails,
|
33
|
-
PoolDetails,
|
34
|
-
VariableDetails,
|
35
|
-
)
|
36
30
|
from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
|
37
31
|
from airflow.configuration import conf
|
38
32
|
from airflow.exceptions import AirflowOptionalProviderFeatureException
|
@@ -55,7 +49,19 @@ if TYPE_CHECKING:
|
|
55
49
|
IsAuthorizedPoolRequest,
|
56
50
|
IsAuthorizedVariableRequest,
|
57
51
|
)
|
58
|
-
from airflow.api_fastapi.auth.managers.models.resource_details import
|
52
|
+
from airflow.api_fastapi.auth.managers.models.resource_details import (
|
53
|
+
AccessView,
|
54
|
+
AssetAliasDetails,
|
55
|
+
AssetDetails,
|
56
|
+
BackfillDetails,
|
57
|
+
ConfigurationDetails,
|
58
|
+
ConnectionDetails,
|
59
|
+
DagAccessEntity,
|
60
|
+
DagDetails,
|
61
|
+
PoolDetails,
|
62
|
+
VariableDetails,
|
63
|
+
)
|
64
|
+
from airflow.api_fastapi.common.types import MenuItem
|
59
65
|
|
60
66
|
|
61
67
|
class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
@@ -84,11 +90,11 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
84
90
|
return conf.get("api", "base_url")
|
85
91
|
|
86
92
|
def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
|
87
|
-
return AwsAuthManagerUser(**token)
|
93
|
+
return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
|
88
94
|
|
89
95
|
def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
|
90
96
|
return {
|
91
|
-
"
|
97
|
+
"sub": user.get_id(),
|
92
98
|
"groups": user.get_groups(),
|
93
99
|
"username": user.username,
|
94
100
|
"email": user.email,
|
@@ -150,6 +156,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
150
156
|
context=context,
|
151
157
|
)
|
152
158
|
|
159
|
+
def is_authorized_backfill(
|
160
|
+
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
|
161
|
+
) -> bool:
|
162
|
+
backfill_id = details.id if details else None
|
163
|
+
return self.avp_facade.is_authorized(
|
164
|
+
method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
|
165
|
+
)
|
166
|
+
|
153
167
|
def is_authorized_asset(
|
154
168
|
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
|
155
169
|
) -> bool:
|
@@ -158,6 +172,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
158
172
|
method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
|
159
173
|
)
|
160
174
|
|
175
|
+
def is_authorized_asset_alias(
|
176
|
+
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetAliasDetails | None = None
|
177
|
+
) -> bool:
|
178
|
+
asset_alias_id = details.id if details else None
|
179
|
+
return self.avp_facade.is_authorized(
|
180
|
+
method=method, entity_type=AvpEntities.ASSET_ALIAS, user=user, entity_id=asset_alias_id
|
181
|
+
)
|
182
|
+
|
161
183
|
def is_authorized_pool(
|
162
184
|
self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: PoolDetails | None = None
|
163
185
|
) -> bool:
|
@@ -203,6 +225,25 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
203
225
|
entity_id=resource_name,
|
204
226
|
)
|
205
227
|
|
228
|
+
def filter_authorized_menu_items(
|
229
|
+
self, menu_items: list[MenuItem], *, user: AwsAuthManagerUser
|
230
|
+
) -> list[MenuItem]:
|
231
|
+
requests: dict[str, IsAuthorizedRequest] = {}
|
232
|
+
for menu_item in menu_items:
|
233
|
+
requests[menu_item.value] = self._get_menu_item_request(menu_item.value)
|
234
|
+
|
235
|
+
batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
|
236
|
+
requests=list(requests.values()), user=user
|
237
|
+
)
|
238
|
+
|
239
|
+
def _has_access_to_menu_item(request: IsAuthorizedRequest):
|
240
|
+
result = self.avp_facade.get_batch_is_authorized_single_result(
|
241
|
+
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
|
242
|
+
)
|
243
|
+
return result["decision"] == "ALLOW"
|
244
|
+
|
245
|
+
return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
|
246
|
+
|
206
247
|
def batch_is_authorized_connection(
|
207
248
|
self,
|
208
249
|
requests: Sequence[IsAuthorizedConnectionRequest],
|
@@ -213,7 +254,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
213
254
|
{
|
214
255
|
"method": request["method"],
|
215
256
|
"entity_type": AvpEntities.CONNECTION,
|
216
|
-
"entity_id": cast(ConnectionDetails, request["details"]).conn_id
|
257
|
+
"entity_id": cast("ConnectionDetails", request["details"]).conn_id
|
217
258
|
if request.get("details")
|
218
259
|
else None,
|
219
260
|
}
|
@@ -231,10 +272,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
231
272
|
{
|
232
273
|
"method": request["method"],
|
233
274
|
"entity_type": AvpEntities.DAG,
|
234
|
-
"entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None,
|
275
|
+
"entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
|
235
276
|
"context": {
|
236
277
|
"dag_entity": {
|
237
|
-
"string": cast(DagAccessEntity, request["access_entity"]).value,
|
278
|
+
"string": cast("DagAccessEntity", request["access_entity"]).value,
|
238
279
|
},
|
239
280
|
}
|
240
281
|
if request.get("access_entity")
|
@@ -254,7 +295,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
254
295
|
{
|
255
296
|
"method": request["method"],
|
256
297
|
"entity_type": AvpEntities.POOL,
|
257
|
-
"entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None,
|
298
|
+
"entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None,
|
258
299
|
}
|
259
300
|
for request in requests
|
260
301
|
]
|
@@ -270,7 +311,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
270
311
|
{
|
271
312
|
"method": request["method"],
|
272
313
|
"entity_type": AvpEntities.VARIABLE,
|
273
|
-
"entity_id": cast(VariableDetails, request["details"]).key
|
314
|
+
"entity_id": cast("VariableDetails", request["details"]).key
|
274
315
|
if request.get("details")
|
275
316
|
else None,
|
276
317
|
}
|
@@ -278,7 +319,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
278
319
|
]
|
279
320
|
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
|
280
321
|
|
281
|
-
def
|
322
|
+
def filter_authorized_dag_ids(
|
282
323
|
self,
|
283
324
|
*,
|
284
325
|
dag_ids: set[str],
|
@@ -309,7 +350,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
309
350
|
return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
|
310
351
|
|
311
352
|
def get_url_login(self, **kwargs) -> str:
|
312
|
-
return f"{
|
353
|
+
return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
|
313
354
|
|
314
355
|
@staticmethod
|
315
356
|
def get_cli_commands() -> list[CLICommand]:
|
@@ -337,6 +378,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
|
|
337
378
|
|
338
379
|
return app
|
339
380
|
|
381
|
+
@staticmethod
|
382
|
+
def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
|
383
|
+
return {
|
384
|
+
"method": "MENU",
|
385
|
+
"entity_type": AvpEntities.MENU,
|
386
|
+
"entity_id": menu_item_text,
|
387
|
+
}
|
388
|
+
|
340
389
|
def _check_avp_schema_version(self):
|
341
390
|
if not self.avp_facade.is_policy_store_schema_up_to_date():
|
342
391
|
self.log.warning(
|
@@ -25,7 +25,8 @@ from fastapi import HTTPException, Request
|
|
25
25
|
from starlette import status
|
26
26
|
from starlette.responses import RedirectResponse
|
27
27
|
|
28
|
-
from airflow.api_fastapi.app import get_auth_manager
|
28
|
+
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
|
29
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
|
29
30
|
from airflow.api_fastapi.common.router import AirflowRouter
|
30
31
|
from airflow.configuration import conf
|
31
32
|
from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
|
@@ -79,8 +80,13 @@ def login_callback(request: Request):
|
|
79
80
|
username=saml_auth.get_nameid(),
|
80
81
|
email=attributes["email"][0] if "email" in attributes else None,
|
81
82
|
)
|
82
|
-
url =
|
83
|
-
|
83
|
+
url = conf.get("api", "base_url")
|
84
|
+
token = get_auth_manager().generate_jwt(user)
|
85
|
+
response = RedirectResponse(url=url, status_code=303)
|
86
|
+
|
87
|
+
secure = conf.has_option("api", "ssl_cert")
|
88
|
+
response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
|
89
|
+
return response
|
84
90
|
|
85
91
|
|
86
92
|
def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
|
@@ -93,7 +99,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
|
|
93
99
|
"sp": {
|
94
100
|
"entityId": "aws-auth-manager-saml-client",
|
95
101
|
"assertionConsumerService": {
|
96
|
-
"url": f"{base_url}/
|
102
|
+
"url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
|
97
103
|
"binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
98
104
|
},
|
99
105
|
},
|
@@ -278,7 +278,7 @@ class AwsEcsExecutor(BaseExecutor):
|
|
278
278
|
if not has_exit_codes:
|
279
279
|
return ""
|
280
280
|
reasons = [
|
281
|
-
f
|
281
|
+
f"{container['container_arn']} - {container['reason']}"
|
282
282
|
for container in containers
|
283
283
|
if "reason" in container
|
284
284
|
]
|
@@ -16,15 +16,7 @@
|
|
16
16
|
# under the License.
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from
|
20
|
-
from typing import TYPE_CHECKING, cast
|
21
|
-
|
22
|
-
from mypy_boto3_appflow.type_defs import (
|
23
|
-
DestinationFlowConfigTypeDef,
|
24
|
-
SourceFlowConfigTypeDef,
|
25
|
-
TaskTypeDef,
|
26
|
-
TriggerConfigTypeDef,
|
27
|
-
)
|
19
|
+
from typing import TYPE_CHECKING
|
28
20
|
|
29
21
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
30
22
|
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
|
@@ -125,11 +117,9 @@ class AppflowHook(AwsGenericHook["AppflowClient"]):
|
|
125
117
|
|
126
118
|
self.conn.update_flow(
|
127
119
|
flowName=response["flowName"],
|
128
|
-
destinationFlowConfigList=
|
129
|
-
|
130
|
-
|
131
|
-
sourceFlowConfig=cast(SourceFlowConfigTypeDef, response["sourceFlowConfig"]),
|
132
|
-
triggerConfig=cast(TriggerConfigTypeDef, response["triggerConfig"]),
|
120
|
+
destinationFlowConfigList=response["destinationFlowConfigList"],
|
121
|
+
sourceFlowConfig=response["sourceFlowConfig"],
|
122
|
+
triggerConfig=response["triggerConfig"],
|
133
123
|
description=response.get("description", "Flow description."),
|
134
|
-
tasks=
|
124
|
+
tasks=tasks,
|
135
125
|
)
|
@@ -146,10 +146,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
|
|
146
146
|
creds = self.get_credentials(region_name=conn_params["region_name"])
|
147
147
|
|
148
148
|
return URL.create(
|
149
|
-
f
|
149
|
+
f"awsathena+{conn_params['driver']}",
|
150
150
|
username=creds.access_key,
|
151
151
|
password=creds.secret_key,
|
152
|
-
host=f
|
152
|
+
host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
|
153
153
|
port=443,
|
154
154
|
database=conn_params["schema_name"],
|
155
155
|
query={"aws_session_token": creds.token, **self.conn.extra_dejson},
|
@@ -943,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
943
943
|
self,
|
944
944
|
waiter_name: str,
|
945
945
|
parameters: dict[str, str] | None = None,
|
946
|
+
config_overrides: dict[str, Any] | None = None,
|
946
947
|
deferrable: bool = False,
|
947
948
|
client=None,
|
948
949
|
) -> Waiter:
|
@@ -962,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
962
963
|
:param parameters: will scan the waiter config for the keys of that dict,
|
963
964
|
and replace them with the corresponding value. If a custom waiter has
|
964
965
|
such keys to be expanded, they need to be provided here.
|
966
|
+
Note: cannot be used if parameters are included in config_overrides
|
967
|
+
:param config_overrides: will update values of provided keys in the waiter's
|
968
|
+
config. Only specified keys will be updated.
|
965
969
|
:param deferrable: If True, the waiter is going to be an async custom waiter.
|
966
970
|
An async client must be provided in that case.
|
967
971
|
:param client: The client to use for the waiter's operations
|
@@ -970,14 +974,18 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
970
974
|
|
971
975
|
if deferrable and not client:
|
972
976
|
raise ValueError("client must be provided for a deferrable waiter.")
|
977
|
+
if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
|
978
|
+
raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
|
973
979
|
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
|
974
980
|
client = client or self._client
|
975
981
|
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
|
976
982
|
# Technically if waiter_name is in custom_waiters then self.waiter_path must
|
977
983
|
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
|
978
984
|
with open(self.waiter_path) as config_file:
|
979
|
-
config = json.loads(config_file.read())
|
985
|
+
config: dict = json.loads(config_file.read())
|
980
986
|
|
987
|
+
if config_overrides is not None:
|
988
|
+
config["waiters"][waiter_name].update(config_overrides)
|
981
989
|
config = self._apply_parameters_value(config, waiter_name, parameters)
|
982
990
|
return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
|
983
991
|
waiter_name
|
@@ -416,8 +416,7 @@ class BatchClientHook(AwsBaseHook):
|
|
416
416
|
)
|
417
417
|
else:
|
418
418
|
raise AirflowException(
|
419
|
-
f"AWS Batch job ({job_id}) description error: exceeded status_retries "
|
420
|
-
f"({self.status_retries})"
|
419
|
+
f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
|
421
420
|
)
|
422
421
|
|
423
422
|
@staticmethod
|
@@ -30,7 +30,7 @@ import json
|
|
30
30
|
import sys
|
31
31
|
from copy import deepcopy
|
32
32
|
from pathlib import Path
|
33
|
-
from typing import TYPE_CHECKING, Callable
|
33
|
+
from typing import TYPE_CHECKING, Any, Callable
|
34
34
|
|
35
35
|
import botocore.client
|
36
36
|
import botocore.exceptions
|
@@ -144,7 +144,12 @@ class BatchWaitersHook(BatchClientHook):
|
|
144
144
|
return self._waiter_model
|
145
145
|
|
146
146
|
def get_waiter(
|
147
|
-
self,
|
147
|
+
self,
|
148
|
+
waiter_name: str,
|
149
|
+
parameters: dict[str, str] | None = None,
|
150
|
+
config_overrides: dict[str, Any] | None = None,
|
151
|
+
deferrable: bool = False,
|
152
|
+
client=None,
|
148
153
|
) -> botocore.waiter.Waiter:
|
149
154
|
"""
|
150
155
|
Get an AWS Batch service waiter, using the configured ``.waiter_model``.
|
@@ -175,7 +180,10 @@ class BatchWaitersHook(BatchClientHook):
|
|
175
180
|
the name (including the casing) of the key name in the waiter
|
176
181
|
model file (typically this is CamelCasing); see ``.list_waiters``.
|
177
182
|
|
178
|
-
:param
|
183
|
+
:param parameters: unused, just here to match the method signature in base_aws
|
184
|
+
:param config_overrides: unused, just here to match the method signature in base_aws
|
185
|
+
:param deferrable: unused, just here to match the method signature in base_aws
|
186
|
+
:param client: unused, just here to match the method signature in base_aws
|
179
187
|
|
180
188
|
:return: a waiter object for the named AWS Batch service
|
181
189
|
"""
|
@@ -292,7 +292,9 @@ class DmsHook(AwsBaseHook):
|
|
292
292
|
return arn
|
293
293
|
|
294
294
|
except ClientError as err:
|
295
|
-
err_str =
|
295
|
+
err_str = (
|
296
|
+
f"Error: {err.get('Error', '').get('Code', '')}: {err.get('Error', '').get('Message', '')}"
|
297
|
+
)
|
296
298
|
self.log.error("Error while creating replication config: %s", err_str)
|
297
299
|
raise err
|
298
300
|
|
@@ -35,7 +35,6 @@ from botocore.signers import RequestSigner
|
|
35
35
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
36
36
|
from airflow.providers.amazon.aws.hooks.sts import StsHook
|
37
37
|
from airflow.utils import yaml
|
38
|
-
from airflow.utils.json import AirflowJsonEncoder
|
39
38
|
|
40
39
|
DEFAULT_PAGINATION_TOKEN = ""
|
41
40
|
STS_TOKEN_EXPIRES_IN = 60
|
@@ -315,7 +314,7 @@ class EksHook(AwsBaseHook):
|
|
315
314
|
)
|
316
315
|
if verbose:
|
317
316
|
cluster_data = response.get("cluster")
|
318
|
-
self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data,
|
317
|
+
self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, default=repr))
|
319
318
|
return response
|
320
319
|
|
321
320
|
def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
|
@@ -343,7 +342,7 @@ class EksHook(AwsBaseHook):
|
|
343
342
|
nodegroup_data = response.get("nodegroup")
|
344
343
|
self.log.info(
|
345
344
|
"Amazon EKS managed node group details: %s",
|
346
|
-
json.dumps(nodegroup_data,
|
345
|
+
json.dumps(nodegroup_data, default=repr),
|
347
346
|
)
|
348
347
|
return response
|
349
348
|
|
@@ -374,9 +373,7 @@ class EksHook(AwsBaseHook):
|
|
374
373
|
)
|
375
374
|
if verbose:
|
376
375
|
fargate_profile_data = response.get("fargateProfile")
|
377
|
-
self.log.info(
|
378
|
-
"AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
|
379
|
-
)
|
376
|
+
self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, default=repr))
|
380
377
|
return response
|
381
378
|
|
382
379
|
def get_cluster_state(self, clusterName: str) -> ClusterStates:
|
@@ -67,7 +67,7 @@ class RedshiftHook(AwsBaseHook):
|
|
67
67
|
for the cluster that is being created.
|
68
68
|
:param params: Remaining AWS Create cluster API params.
|
69
69
|
"""
|
70
|
-
response = self.
|
70
|
+
response = self.conn.create_cluster(
|
71
71
|
ClusterIdentifier=cluster_identifier,
|
72
72
|
NodeType=node_type,
|
73
73
|
MasterUsername=master_username,
|
@@ -87,9 +87,9 @@ class RedshiftHook(AwsBaseHook):
|
|
87
87
|
:param cluster_identifier: unique identifier of a cluster
|
88
88
|
"""
|
89
89
|
try:
|
90
|
-
response = self.
|
90
|
+
response = self.conn.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
|
91
91
|
return response[0]["ClusterStatus"] if response else None
|
92
|
-
except self.
|
92
|
+
except self.conn.exceptions.ClusterNotFoundFault:
|
93
93
|
return "cluster_not_found"
|
94
94
|
|
95
95
|
async def cluster_status_async(self, cluster_identifier: str) -> str:
|
@@ -115,7 +115,7 @@ class RedshiftHook(AwsBaseHook):
|
|
115
115
|
"""
|
116
116
|
final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""
|
117
117
|
|
118
|
-
response = self.
|
118
|
+
response = self.conn.delete_cluster(
|
119
119
|
ClusterIdentifier=cluster_identifier,
|
120
120
|
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
|
121
121
|
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
|
@@ -131,7 +131,7 @@ class RedshiftHook(AwsBaseHook):
|
|
131
131
|
|
132
132
|
:param cluster_identifier: unique identifier of a cluster
|
133
133
|
"""
|
134
|
-
response = self.
|
134
|
+
response = self.conn.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
|
135
135
|
if "Snapshots" not in response:
|
136
136
|
return None
|
137
137
|
snapshots = response["Snapshots"]
|
@@ -149,7 +149,7 @@ class RedshiftHook(AwsBaseHook):
|
|
149
149
|
:param cluster_identifier: unique identifier of a cluster
|
150
150
|
:param snapshot_identifier: unique identifier for a snapshot of a cluster
|
151
151
|
"""
|
152
|
-
response = self.
|
152
|
+
response = self.conn.restore_from_cluster_snapshot(
|
153
153
|
ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
|
154
154
|
)
|
155
155
|
return response["Cluster"] if response["Cluster"] else None
|
@@ -175,7 +175,7 @@ class RedshiftHook(AwsBaseHook):
|
|
175
175
|
"""
|
176
176
|
if tags is None:
|
177
177
|
tags = []
|
178
|
-
response = self.
|
178
|
+
response = self.conn.create_cluster_snapshot(
|
179
179
|
SnapshotIdentifier=snapshot_identifier,
|
180
180
|
ClusterIdentifier=cluster_identifier,
|
181
181
|
ManualSnapshotRetentionPeriod=retention_period,
|
@@ -192,11 +192,11 @@ class RedshiftHook(AwsBaseHook):
|
|
192
192
|
:param snapshot_identifier: A unique identifier for the snapshot that you are requesting
|
193
193
|
"""
|
194
194
|
try:
|
195
|
-
response = self.
|
195
|
+
response = self.conn.describe_cluster_snapshots(
|
196
196
|
SnapshotIdentifier=snapshot_identifier,
|
197
197
|
)
|
198
198
|
snapshot = response.get("Snapshots")[0]
|
199
199
|
snapshot_status: str = snapshot.get("Status")
|
200
200
|
return snapshot_status
|
201
|
-
except self.
|
201
|
+
except self.conn.exceptions.ClusterSnapshotNotFoundFault:
|
202
202
|
return None
|
@@ -186,8 +186,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
|
|
186
186
|
RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
|
187
187
|
)
|
188
188
|
raise exception_cls(
|
189
|
-
f"Statement {resp['Id']} terminated with status {status}. "
|
190
|
-
f"Response details: {pformat(resp)}"
|
189
|
+
f"Statement {resp['Id']} terminated with status {status}. Response details: {pformat(resp)}"
|
191
190
|
)
|
192
191
|
|
193
192
|
self.log.info("Query status: %s", status)
|
@@ -1494,7 +1494,9 @@ class S3Hook(AwsBaseHook):
|
|
1494
1494
|
get_hook_lineage_collector().add_output_asset(
|
1495
1495
|
context=self,
|
1496
1496
|
scheme="file",
|
1497
|
-
asset_kwargs={
|
1497
|
+
asset_kwargs={
|
1498
|
+
"path": str(file_path) if file_path.is_absolute() else str(file_path.absolute())
|
1499
|
+
},
|
1498
1500
|
)
|
1499
1501
|
file = open(file_path, "wb")
|
1500
1502
|
else:
|
@@ -131,7 +131,7 @@ def secondary_training_status_message(
|
|
131
131
|
status_strs = []
|
132
132
|
for transition in transitions_to_print:
|
133
133
|
message = transition["StatusMessage"]
|
134
|
-
time_utc = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"]))
|
134
|
+
time_utc = timezone.convert_to_utc(cast("datetime", job_description["LastModifiedTime"]))
|
135
135
|
status_strs.append(f"{time_utc:%Y-%m-%d %H:%M:%S} {transition['Status']} - {message}")
|
136
136
|
|
137
137
|
return "\n".join(status_strs)
|
@@ -25,6 +25,5 @@ class AthenaQueryResultsLink(BaseAwsLink):
|
|
25
25
|
name = "Query Results"
|
26
26
|
key = "_athena_query_results"
|
27
27
|
format_str = (
|
28
|
-
BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}
|
29
|
-
"/query-editor/history/{query_execution_id}"
|
28
|
+
BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#/query-editor/history/{query_execution_id}"
|
30
29
|
)
|
@@ -19,7 +19,6 @@ from __future__ import annotations
|
|
19
19
|
|
20
20
|
from typing import TYPE_CHECKING, ClassVar
|
21
21
|
|
22
|
-
from airflow.models import XCom
|
23
22
|
from airflow.providers.amazon.aws.utils.suppress import return_on_error
|
24
23
|
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
|
25
24
|
|
@@ -30,7 +29,9 @@ if TYPE_CHECKING:
|
|
30
29
|
|
31
30
|
if AIRFLOW_V_3_0_PLUS:
|
32
31
|
from airflow.sdk import BaseOperatorLink
|
32
|
+
from airflow.sdk.execution_time.xcom import XCom
|
33
33
|
else:
|
34
|
+
from airflow.models import XCom # type: ignore[no-redef]
|
34
35
|
from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef]
|
35
36
|
|
36
37
|
|