apache-airflow-providers-amazon 9.4.0__py3-none-any.whl → 9.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
- airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
- airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +32 -43
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +46 -5
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +38 -31
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "9.4.0"
+__version__ = "9.5.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"
airflow/providers/amazon/aws/auth_manager/avp/entities.py
@@ -20,7 +20,7 @@ from enum import Enum
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
-    from airflow.auth.managers.base_auth_manager import ResourceMethod
+    from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod

 AVP_PREFIX_ENTITIES = "Airflow::"

@@ -34,6 +34,8 @@ class AvpEntities(Enum):

     # Resource types
     ASSET = "Asset"
+    ASSET_ALIAS = "AssetAlias"
+    BACKFILL = "Backfills"
     CONFIGURATION = "Configuration"
     CONNECTION = "Connection"
     CUSTOM = "Custom"
airflow/providers/amazon/aws/auth_manager/avp/facade.py
@@ -36,7 +36,7 @@ from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin

 if TYPE_CHECKING:
-    from airflow.auth.managers.base_auth_manager import ResourceMethod
+    from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
     from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -18,22 +18,15 @@ from __future__ import annotations

 import argparse
 from collections import defaultdict
-from collections.abc import …
+from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, cast
+from urllib.parse import urljoin

 from fastapi import FastAPI
-…
-
-from airflow.auth.managers.base_auth_manager import BaseAuthManager
-from airflow.auth.managers.models.resource_details import (
-    AccessView,
-    ConnectionDetails,
-    DagAccessEntity,
-    DagDetails,
-    PoolDetails,
-    VariableDetails,
-)
+
+from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
+from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
 from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
 from airflow.configuration import conf
 from airflow.exceptions import AirflowOptionalProviderFeatureException
@@ -49,16 +42,26 @@ from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
 from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS

 if TYPE_CHECKING:
-    from …
-
-    from airflow.auth.managers.base_auth_manager import ResourceMethod
-    from airflow.auth.managers.models.batch_apis import (
+    from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
+    from airflow.api_fastapi.auth.managers.models.batch_apis import (
         IsAuthorizedConnectionRequest,
         IsAuthorizedDagRequest,
         IsAuthorizedPoolRequest,
         IsAuthorizedVariableRequest,
     )
-    from airflow.auth.managers.models.resource_details import …
+    from airflow.api_fastapi.auth.managers.models.resource_details import (
+        AccessView,
+        AssetAliasDetails,
+        AssetDetails,
+        BackfillDetails,
+        ConfigurationDetails,
+        ConnectionDetails,
+        DagAccessEntity,
+        DagDetails,
+        PoolDetails,
+        VariableDetails,
+    )
+    from airflow.api_fastapi.common.types import MenuItem


 class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
@@ -83,21 +86,15 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         return AwsAuthManagerAmazonVerifiedPermissionsFacade()

     @cached_property
-    def …
-        return conf.get("…
-
-    def get_user(self) -> AwsAuthManagerUser | None:
-        return session["aws_user"] if self.is_logged_in() else None
-
-    def is_logged_in(self) -> bool:
-        return "aws_user" in session
+    def apiserver_endpoint(self) -> str:
+        return conf.get("api", "base_url")

     def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
-        return AwsAuthManagerUser(**token)
+        return AwsAuthManagerUser(user_id=token.pop("sub"), **token)

     def serialize_user(self, user: AwsAuthManagerUser) -> dict[str, Any]:
         return {
-            "…
+            "sub": user.get_id(),
             "groups": user.get_groups(),
             "username": user.username,
             "email": user.email,
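Note on the serialization change above: the user id now travels in a JWT-style "sub" claim, and deserialize_user() pops it back into user_id. A minimal sketch of the round trip, assuming AwsAuthManagerUser accepts user_id, groups, username and email as keyword arguments (as the deserialization above implies); all values are placeholders:

    from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

    user = AwsAuthManagerUser(user_id="user-1", groups=["Admins"], username="jdoe", email="jdoe@example.com")

    # Shape produced by AwsAuthManager.serialize_user() in 9.5.0
    token = {
        "sub": user.get_id(),
        "groups": user.get_groups(),
        "username": user.username,
        "email": user.email,
    }

    # Shape consumed by AwsAuthManager.deserialize_user(): "sub" maps back to user_id
    restored = AwsAuthManagerUser(user_id=token.pop("sub"), **token)
    assert restored.get_id() == "user-1"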
@@ -159,12 +156,28 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             context=context,
         )

+    def is_authorized_backfill(
+        self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: BackfillDetails | None = None
+    ) -> bool:
+        backfill_id = details.id if details else None
+        return self.avp_facade.is_authorized(
+            method=method, entity_type=AvpEntities.BACKFILL, user=user, entity_id=backfill_id
+        )
+
     def is_authorized_asset(
         self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetDetails | None = None
     ) -> bool:
-        …
+        asset_id = details.id if details else None
+        return self.avp_facade.is_authorized(
+            method=method, entity_type=AvpEntities.ASSET, user=user, entity_id=asset_id
+        )
+
+    def is_authorized_asset_alias(
+        self, *, method: ResourceMethod, user: AwsAuthManagerUser, details: AssetAliasDetails | None = None
+    ) -> bool:
+        asset_alias_id = details.id if details else None
         return self.avp_facade.is_authorized(
-            method=method, entity_type=AvpEntities.…
+            method=method, entity_type=AvpEntities.ASSET_ALIAS, user=user, entity_id=asset_alias_id
         )

     def is_authorized_pool(
@@ -204,7 +217,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):

     def is_authorized_custom_view(
         self, *, method: ResourceMethod | str, resource_name: str, user: AwsAuthManagerUser
-    ):
+    ) -> bool:
         return self.avp_facade.is_authorized(
             method=method,
             entity_type=AvpEntities.CUSTOM,
@@ -212,6 +225,25 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             entity_id=resource_name,
         )

+    def filter_authorized_menu_items(
+        self, menu_items: list[MenuItem], *, user: AwsAuthManagerUser
+    ) -> list[MenuItem]:
+        requests: dict[str, IsAuthorizedRequest] = {}
+        for menu_item in menu_items:
+            requests[menu_item.value] = self._get_menu_item_request(menu_item.value)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=list(requests.values()), user=user
+        )
+
+        def _has_access_to_menu_item(request: IsAuthorizedRequest):
+            result = self.avp_facade.get_batch_is_authorized_single_result(
+                batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
+            )
+            return result["decision"] == "ALLOW"
+
+        return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
+
     def batch_is_authorized_connection(
         self,
         requests: Sequence[IsAuthorizedConnectionRequest],
@@ -222,7 +254,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.CONNECTION,
-                "entity_id": cast(ConnectionDetails, request["details"]).conn_id
+                "entity_id": cast("ConnectionDetails", request["details"]).conn_id
                 if request.get("details")
                 else None,
             }
@@ -240,10 +272,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.DAG,
-                "entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None,
+                "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
                 "context": {
                     "dag_entity": {
-                        "string": cast(DagAccessEntity, request["access_entity"]).value,
+                        "string": cast("DagAccessEntity", request["access_entity"]).value,
                     },
                 }
                 if request.get("access_entity")
@@ -263,7 +295,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.POOL,
-                "entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None,
+                "entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None,
             }
             for request in requests
         ]
@@ -279,7 +311,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             {
                 "method": request["method"],
                 "entity_type": AvpEntities.VARIABLE,
-                "entity_id": cast(VariableDetails, request["details"]).key
+                "entity_id": cast("VariableDetails", request["details"]).key
                 if request.get("details")
                 else None,
             }
@@ -287,28 +319,23 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

-    def …
+    def filter_authorized_dag_ids(
         self,
         *,
         dag_ids: set[str],
         user: AwsAuthManagerUser,
-        …
+        method: ResourceMethod = "GET",
     ):
-        if not methods:
-            methods = ["PUT", "GET"]
-
         requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
         requests_list: list[IsAuthorizedRequest] = []
         for dag_id in dag_ids:
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            requests[dag_id][cast("ResourceMethod", method)] = request
-            requests_list.append(request)
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.DAG,
+                "entity_id": dag_id,
+            }
+            requests[dag_id][method] = request
+            requests_list.append(request)

         batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
             requests=requests_list, user=user
@@ -320,67 +347,10 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             )
             return result["decision"] == "ALLOW"

-        return {
-            dag_id
-            for dag_id in dag_ids
-            if (
-                "GET" in methods
-                and _has_access_to_dag(requests[dag_id]["GET"])
-                or "PUT" in methods
-                and _has_access_to_dag(requests[dag_id]["PUT"])
-            )
-        }
-
-    def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
-        """
-        Filter menu items based on user permissions.
-
-        :param menu_items: list of all menu items
-        """
-        user = self.get_user()
-        if not user:
-            return []
-
-        requests: dict[str, IsAuthorizedRequest] = {}
-        for menu_item in menu_items:
-            if menu_item.childs:
-                for child in menu_item.childs:
-                    requests[child.name] = self._get_menu_item_request(child.name)
-            else:
-                requests[menu_item.name] = self._get_menu_item_request(menu_item.name)
-
-        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
-            requests=list(requests.values()), user=user
-        )
-
-        def _has_access_to_menu_item(request: IsAuthorizedRequest):
-            result = self.avp_facade.get_batch_is_authorized_single_result(
-                batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
-            )
-            return result["decision"] == "ALLOW"
-
-        accessible_items = []
-        for menu_item in menu_items:
-            if menu_item.childs:
-                accessible_children = []
-                for child in menu_item.childs:
-                    if _has_access_to_menu_item(requests[child.name]):
-                        accessible_children.append(child)
-                menu_item.childs = accessible_children
-
-                # Display the menu if the user has access to at least one sub item
-                if len(accessible_children) > 0:
-                    accessible_items.append(menu_item)
-            elif _has_access_to_menu_item(requests[menu_item.name]):
-                accessible_items.append(menu_item)
-
-        return accessible_items
+        return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}

     def get_url_login(self, **kwargs) -> str:
-        return f"{…
-
-    def get_url_logout(self) -> str:
-        raise NotImplementedError()
+        return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")

     @staticmethod
     def get_cli_commands() -> list[CLICommand]:
@@ -409,11 +379,11 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         return app

     @staticmethod
-    def _get_menu_item_request(…
+    def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
         return {
             "method": "MENU",
             "entity_type": AvpEntities.MENU,
-            "entity_id": …
+            "entity_id": menu_item_text,
         }

     def _check_avp_schema_version(self):
airflow/providers/amazon/aws/auth_manager/router/login.py
@@ -25,7 +25,8 @@ from fastapi import HTTPException, Request
 from starlette import status
 from starlette.responses import RedirectResponse

-from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX, get_auth_manager
+from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
@@ -79,12 +80,18 @@ def login_callback(request: Request):
         username=saml_auth.get_nameid(),
         email=attributes["email"][0] if "email" in attributes else None,
     )
-    …
+    url = conf.get("api", "base_url")
+    token = get_auth_manager().generate_jwt(user)
+    response = RedirectResponse(url=url, status_code=303)
+
+    secure = conf.has_option("api", "ssl_cert")
+    response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
+    return response


 def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
     request_data = _prepare_request(request)
-    base_url = conf.get(section="…
+    base_url = conf.get(section="api", key="base_url")
     settings = {
         # We want to keep this flag on in case of errors.
         # It provides an error reasons, if turned off, it does not
@@ -92,7 +99,7 @@ def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
         "sp": {
             "entityId": "aws-auth-manager-saml-client",
             "assertionConsumerService": {
-                "url": f"{base_url}/…
+                "url": f"{base_url}{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login_callback",
                 "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
             },
         },
airflow/providers/amazon/aws/auth_manager/user.py
@@ -19,11 +19,14 @@ from __future__ import annotations
 from airflow.exceptions import AirflowOptionalProviderFeatureException

 try:
-    from airflow.auth.managers.models.base_user import BaseUser
+    from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
 except ImportError:
-    …
-    …
-    …
+    try:
+        from airflow.auth.managers.models.base_user import BaseUser  # type: ignore[no-redef]
+    except ImportError:
+        raise AirflowOptionalProviderFeatureException(
+            "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
+        ) from None


 class AwsAuthManagerUser(BaseUser):
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -278,7 +278,7 @@ class AwsEcsExecutor(BaseExecutor):
         if not has_exit_codes:
             return ""
         reasons = [
-            f…
+            f"{container['container_arn']} - {container['reason']}"
             for container in containers
             if "reason" in container
         ]
airflow/providers/amazon/aws/hooks/appflow.py
@@ -16,15 +16,7 @@
 # under the License.
 from __future__ import annotations

-from …
-from typing import TYPE_CHECKING, cast
-
-from mypy_boto3_appflow.type_defs import (
-    DestinationFlowConfigTypeDef,
-    SourceFlowConfigTypeDef,
-    TaskTypeDef,
-    TriggerConfigTypeDef,
-)
+from typing import TYPE_CHECKING

 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -125,11 +117,9 @@ class AppflowHook(AwsGenericHook["AppflowClient"]):

         self.conn.update_flow(
             flowName=response["flowName"],
-            destinationFlowConfigList=…
-            …
-            …
-            sourceFlowConfig=cast(SourceFlowConfigTypeDef, response["sourceFlowConfig"]),
-            triggerConfig=cast(TriggerConfigTypeDef, response["triggerConfig"]),
+            destinationFlowConfigList=response["destinationFlowConfigList"],
+            sourceFlowConfig=response["sourceFlowConfig"],
+            triggerConfig=response["triggerConfig"],
             description=response.get("description", "Flow description."),
-            tasks=…
+            tasks=tasks,
         )
airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -146,10 +146,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
         creds = self.get_credentials(region_name=conn_params["region_name"])

         return URL.create(
-            f…
+            f"awsathena+{conn_params['driver']}",
             username=creds.access_key,
             password=creds.secret_key,
-            host=f…
+            host=f"athena.{conn_params['region_name']}.{conn_params['aws_domain']}",
             port=443,
             database=conn_params["schema_name"],
             query={"aws_session_token": creds.token, **self.conn.extra_dejson},
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -30,6 +30,7 @@ import inspect
 import json
 import logging
 import os
+import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
 from pathlib import Path
@@ -41,6 +42,7 @@ import botocore.session
 import jinja2
 import requests
 import tenacity
+from asgiref.sync import sync_to_async
 from botocore.config import Config
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
@@ -50,6 +52,7 @@ from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
     AirflowNotFoundException,
+    AirflowProviderDeprecationWarning,
 )
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -747,7 +750,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):

     @property
     def async_conn(self):
+        """
+        [DEPRECATED] Get an aiobotocore client to use for async operations.
+
+        This property is deprecated. Accessing it in an async context will cause the event loop to block.
+        Use the async method `get_async_conn` instead.
+        """
+        warnings.warn(
+            "The property `async_conn` is deprecated. Accessing it in an async context will cause the event loop to block. "
+            "Use the async method `get_async_conn` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+
+        return self._get_async_conn()
+
+    async def get_async_conn(self):
         """Get an aiobotocore client to use for async operations."""
+        # We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
+        # because one of it's arguments `self.region_name` is a `@property` decorated function
+        # calling the cached property `self.conn_config` at the end.
+        return await sync_to_async(self._get_async_conn)()
+
+    def _get_async_conn(self):
         if not self.client_type:
             raise ValueError("client_type must be specified.")

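Note on the async_conn change above: the property still works but now emits AirflowProviderDeprecationWarning, and async code should await get_async_conn() so the event loop is not blocked. A minimal usage sketch, assuming an AWS connection named "aws_default" and an arbitrary client_type of "ec2" (both chosen only for illustration):

    import asyncio

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


    async def main():
        hook = AwsBaseHook(aws_conn_id="aws_default", client_type="ec2")
        # Before 9.5.0: `async with hook.async_conn as client:` (now deprecated)
        async with await hook.get_async_conn() as client:
            regions = await client.describe_regions()
            print([r["RegionName"] for r in regions["Regions"]])


    asyncio.run(main())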
@@ -918,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self,
         waiter_name: str,
         parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
         deferrable: bool = False,
         client=None,
     ) -> Waiter:
@@ -937,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         :param parameters: will scan the waiter config for the keys of that dict,
             and replace them with the corresponding value. If a custom waiter has
             such keys to be expanded, they need to be provided here.
+            Note: cannot be used if parameters are included in config_overrides
+        :param config_overrides: will update values of provided keys in the waiter's
+            config. Only specified keys will be updated.
         :param deferrable: If True, the waiter is going to be an async custom waiter.
             An async client must be provided in that case.
         :param client: The client to use for the waiter's operations
@@ -945,14 +974,18 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):

         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable waiter.")
+        if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
+            raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
         # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
         client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could be None.
             with open(self.waiter_path) as config_file:
-                config = json.loads(config_file.read())
+                config: dict = json.loads(config_file.read())

+            if config_overrides is not None:
+                config["waiters"][waiter_name].update(config_overrides)
             config = self._apply_parameters_value(config, waiter_name, parameters)
             return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
                 waiter_name
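Note on the get_waiter() change above: config_overrides patches individual keys of a custom waiter definition (for example delay or maxAttempts) without templating the waiter JSON through parameters, and combining parameters with an "acceptors" override raises ValueError. A hedged sketch, assuming "my_custom_waiter" (a hypothetical name) is one of the custom waiters bundled with the hook's waiter file:

    from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook

    hook = BatchClientHook(aws_conn_id="aws_default")

    # Only the listed keys are overridden; the rest of the waiter definition
    # still comes from the provider's bundled waiter JSON.
    waiter = hook.get_waiter(
        "my_custom_waiter",  # hypothetical custom waiter name
        config_overrides={"delay": 15, "maxAttempts": 80},
    )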
airflow/providers/amazon/aws/hooks/batch_client.py
@@ -416,8 +416,7 @@ class BatchClientHook(AwsBaseHook):
             )
         else:
             raise AirflowException(
-                f"AWS Batch job ({job_id}) description error: exceeded status_retries "
-                f"({self.status_retries})"
+                f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
             )

     @staticmethod

airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -30,7 +30,7 @@ import json
 import sys
 from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Any, Callable

 import botocore.client
 import botocore.exceptions
@@ -144,7 +144,12 @@ class BatchWaitersHook(BatchClientHook):
         return self._waiter_model

     def get_waiter(
-        self, …
+        self,
+        waiter_name: str,
+        parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
+        deferrable: bool = False,
+        client=None,
     ) -> botocore.waiter.Waiter:
         """
         Get an AWS Batch service waiter, using the configured ``.waiter_model``.
@@ -175,7 +180,10 @@
             the name (including the casing) of the key name in the waiter
             model file (typically this is CamelCasing); see ``.list_waiters``.

-        :param …
+        :param parameters: unused, just here to match the method signature in base_aws
+        :param config_overrides: unused, just here to match the method signature in base_aws
+        :param deferrable: unused, just here to match the method signature in base_aws
+        :param client: unused, just here to match the method signature in base_aws

         :return: a waiter object for the named AWS Batch service
         """
airflow/providers/amazon/aws/hooks/dms.py
@@ -292,7 +292,9 @@ class DmsHook(AwsBaseHook):
             return arn

         except ClientError as err:
-            err_str = …
+            err_str = (
+                f"Error: {err.get('Error', '').get('Code', '')}: {err.get('Error', '').get('Message', '')}"
+            )
             self.log.error("Error while creating replication config: %s", err_str)
             raise err

airflow/providers/amazon/aws/hooks/ec2.py
@@ -173,7 +173,7 @@ class EC2Hook(AwsBaseHook):
         return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]

     async def get_instance_state_async(self, instance_id: str) -> str:
-        async with self.…
+        async with await self.get_async_conn() as client:
             response = await client.describe_instances(InstanceIds=[instance_id])
             return response["Reservations"][0]["Instances"][0]["State"]["Name"]

airflow/providers/amazon/aws/hooks/eks.py
@@ -35,7 +35,6 @@ from botocore.signers import RequestSigner
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sts import StsHook
 from airflow.utils import yaml
-from airflow.utils.json import AirflowJsonEncoder

 DEFAULT_PAGINATION_TOKEN = ""
 STS_TOKEN_EXPIRES_IN = 60
@@ -315,7 +314,7 @@ class EksHook(AwsBaseHook):
         )
         if verbose:
             cluster_data = response.get("cluster")
-            self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, …
+            self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, default=repr))
         return response

     def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
@@ -343,7 +342,7 @@
             nodegroup_data = response.get("nodegroup")
             self.log.info(
                 "Amazon EKS managed node group details: %s",
-                json.dumps(nodegroup_data, …
+                json.dumps(nodegroup_data, default=repr),
             )
         return response

@@ -374,9 +373,7 @@
         )
         if verbose:
             fargate_profile_data = response.get("fargateProfile")
-            self.log.info(
-                "AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
-            )
+            self.log.info("AWS Fargate profile details: %s", json.dumps(fargate_profile_data, default=repr))
         return response

     def get_cluster_state(self, clusterName: str) -> ClusterStates:
|