apache-airflow-providers-amazon 9.14.0__py3-none-any.whl → 9.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +3 -3
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +106 -5
- airflow/providers/amazon/aws/auth_manager/routes/login.py +7 -1
- airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py +5 -1
- airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +6 -2
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +2 -2
- airflow/providers/amazon/aws/hooks/batch_client.py +4 -6
- airflow/providers/amazon/aws/hooks/batch_waiters.py +0 -1
- airflow/providers/amazon/aws/hooks/chime.py +1 -1
- airflow/providers/amazon/aws/hooks/datasync.py +3 -3
- airflow/providers/amazon/aws/hooks/firehose.py +56 -0
- airflow/providers/amazon/aws/hooks/glue.py +7 -1
- airflow/providers/amazon/aws/hooks/kinesis.py +31 -13
- airflow/providers/amazon/aws/hooks/mwaa.py +38 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +20 -6
- airflow/providers/amazon/aws/hooks/s3.py +41 -11
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/hooks/ses.py +76 -10
- airflow/providers/amazon/aws/hooks/sns.py +74 -18
- airflow/providers/amazon/aws/hooks/sqs.py +64 -11
- airflow/providers/amazon/aws/hooks/ssm.py +34 -6
- airflow/providers/amazon/aws/hooks/step_function.py +1 -1
- airflow/providers/amazon/aws/links/base_aws.py +1 -1
- airflow/providers/amazon/aws/notifications/ses.py +139 -0
- airflow/providers/amazon/aws/notifications/sns.py +16 -1
- airflow/providers/amazon/aws/notifications/sqs.py +17 -1
- airflow/providers/amazon/aws/operators/base_aws.py +2 -2
- airflow/providers/amazon/aws/operators/bedrock.py +2 -0
- airflow/providers/amazon/aws/operators/cloud_formation.py +2 -2
- airflow/providers/amazon/aws/operators/datasync.py +2 -1
- airflow/providers/amazon/aws/operators/emr.py +44 -33
- airflow/providers/amazon/aws/operators/mwaa.py +12 -3
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/operators/ssm.py +122 -17
- airflow/providers/amazon/aws/secrets/secrets_manager.py +3 -4
- airflow/providers/amazon/aws/sensors/base_aws.py +2 -2
- airflow/providers/amazon/aws/sensors/mwaa.py +14 -1
- airflow/providers/amazon/aws/sensors/s3.py +27 -13
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/sensors/ssm.py +33 -17
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +3 -3
- airflow/providers/amazon/aws/transfers/base.py +5 -5
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +4 -4
- airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +48 -5
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +2 -5
- airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/local_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +6 -6
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_ftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
- airflow/providers/amazon/aws/transfers/s3_to_sftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_sql.py +1 -1
- airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +4 -5
- airflow/providers/amazon/aws/triggers/bedrock.py +1 -1
- airflow/providers/amazon/aws/triggers/s3.py +29 -2
- airflow/providers/amazon/aws/triggers/ssm.py +17 -1
- airflow/providers/amazon/aws/utils/connection_wrapper.py +2 -5
- airflow/providers/amazon/aws/utils/mixins.py +1 -1
- airflow/providers/amazon/aws/utils/waiter.py +2 -2
- airflow/providers/amazon/aws/waiters/emr.json +6 -6
- airflow/providers/amazon/get_provider_info.py +19 -1
- airflow/providers/amazon/version_compat.py +19 -16
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/METADATA +25 -19
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/RECORD +79 -76
- apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses/NOTICE +5 -0
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/entry_points.txt +0 -0
- {airflow/providers/amazon → apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses}/LICENSE +0 -0
airflow/providers/amazon/__init__.py

@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.14.0"
+__version__ = "9.18.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.
+    "2.11.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.
+        f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.11.0+"
     )
airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py

@@ -338,6 +338,37 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
 
+    def filter_authorized_connections(
+        self,
+        *,
+        conn_ids: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for conn_id in conn_ids:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.CONNECTION,
+                "entity_id": conn_id,
+            }
+            requests[conn_id][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            conn_id
+            for conn_id in conn_ids
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[conn_id][method], user
+            )
+        }
+
     def filter_authorized_dag_ids(
         self,
         *,

@@ -361,13 +392,75 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             requests=requests_list, user=user
         )
 
-
-
-
+        return {
+            dag_id
+            for dag_id in dag_ids
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[dag_id][method], user
             )
-
+        }
 
-
+    def filter_authorized_pools(
+        self,
+        *,
+        pool_names: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for pool_name in pool_names:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.POOL,
+                "entity_id": pool_name,
+            }
+            requests[pool_name][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            pool_name
+            for pool_name in pool_names
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[pool_name][method], user
+            )
+        }
+
+    def filter_authorized_variables(
+        self,
+        *,
+        variable_keys: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for variable_key in variable_keys:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.VARIABLE,
+                "entity_id": variable_key,
+            }
+            requests[variable_key][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            variable_key
+            for variable_key in variable_keys
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[variable_key][method], user
+            )
+        }
 
     def get_url_login(self, **kwargs) -> str:
         return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")

@@ -406,6 +499,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             "entity_id": menu_item_text,
         }
 
+    def _is_authorized_from_batch_response(
+        self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser
+    ):
+        result = self.avp_facade.get_batch_is_authorized_single_result(
+            batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
+        )
+        return result["decision"] == "ALLOW"
+
     def _check_avp_schema_version(self):
         if not self.avp_facade.is_policy_store_schema_up_to_date():
             self.log.warning(
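The three new filter_* methods follow the same pattern as the existing filter_authorized_dag_ids: build one IsAuthorizedRequest per entity, resolve them all in a single Amazon Verified Permissions batch call, and keep only the IDs whose decision is ALLOW. A minimal usage sketch follows; how the AwsAuthManager and AwsAuthManagerUser instances are obtained is an assumption here and is normally handled by Airflow itself:

    from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager

    auth_manager = AwsAuthManager()  # assumption: normally provided by Airflow via get_auth_manager()
    user = ...  # an AwsAuthManagerUser for the current request; construction omitted in this sketch

    # Each call returns the subset of IDs the user may access, resolved in one AVP batch request.
    readable_conns = auth_manager.filter_authorized_connections(
        conn_ids={"aws_default", "redshift_default"}, user=user, method="GET"
    )
    readable_pools = auth_manager.filter_authorized_pools(
        pool_names={"default_pool"}, user=user, method="GET"
    )
    readable_vars = auth_manager.filter_authorized_variables(
        variable_keys={"api_key"}, user=user, method="GET"
    )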
airflow/providers/amazon/aws/auth_manager/routes/login.py

@@ -35,6 +35,7 @@ from airflow.configuration import conf
 from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
 from airflow.providers.amazon.aws.auth_manager.datamodels.login import LoginResponse
 from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_1_PLUS
 
 try:
     from onelogin.saml2.auth import OneLogin_Saml2_Auth

@@ -101,7 +102,12 @@ def login_callback(request: Request):
     if relay_state == "login-redirect":
         response = RedirectResponse(url=url, status_code=303)
         secure = bool(conf.get("api", "ssl_cert", fallback=""))
-
+        # In Airflow 3.1.1 authentication changes, front-end no longer handle the token
+        # See https://github.com/apache/airflow/pull/55506
+        if AIRFLOW_V_3_1_1_PLUS:
+            response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure, httponly=True)
+        else:
+            response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
         return response
     if relay_state == "login-token":
         return LoginResponse(access_token=token)
airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py

@@ -66,7 +66,11 @@ def run_and_report(command, task_key):
     try:
         log.info("Starting execution for task: %s", task_key)
         result = subprocess.run(
-            command,
+            command,
+            check=False,
+            shell=isinstance(command, str),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
         )
         return_code = result.returncode
         log.info("Execution completed for task %s with return code %s", task_key, return_code)
airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py

@@ -49,7 +49,7 @@ try:
 except ImportError:
     from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 
-from
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
airflow/providers/amazon/aws/hooks/athena.py

@@ -40,11 +40,15 @@ MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"
 
 def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
     result = ""
-    for key,
+    for key, original_value in params.items():
+        value: str | Collection[str]
         if key == "QueryString":
             value = (
-                MULTI_LINE_QUERY_LOG_PREFIX
+                MULTI_LINE_QUERY_LOG_PREFIX
+                + str(original_value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
             )
+        else:
+            value = original_value
         result += f"\t{key}: {value}\n"
     return result.rstrip()
 
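As the reworked helper shows, only the QueryString value is reflowed onto indented lines; all other parameters are logged untouched. A quick illustrative call, with the expected output inferred from the code above:

    from airflow.providers.amazon.aws.hooks.athena import query_params_to_string

    params = {"QueryString": "SELECT 1\nFROM t", "WorkGroup": "primary"}
    print(query_params_to_string(params))
    # Prints (leading whitespace is literal tab characters):
    #   QueryString:
    #       SELECT 1
    #       FROM t
    #   WorkGroup: primary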
airflow/providers/amazon/aws/hooks/athena_sql.py

@@ -56,7 +56,7 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
 
     .. note::
-        get_uri() depends on SQLAlchemy and PyAthena
+        get_uri() depends on SQLAlchemy and PyAthena
     """
 
     conn_name_attr = "athena_conn_id"

@@ -163,7 +163,7 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
             port=443,
             database=conn_params["schema_name"],
             query={"aws_session_token": creds.token, **self.conn.extra_dejson},
-        )
+        ).render_as_string(hide_password=False)
 
     def get_conn(self) -> AthenaConnection:
         """Get a ``pyathena.Connection`` object."""
airflow/providers/amazon/aws/hooks/base_aws.py

@@ -60,7 +60,7 @@ from airflow.exceptions import (
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
-from airflow.providers.
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin

@@ -790,7 +790,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
     async def get_async_conn(self):
         """Get an aiobotocore client to use for async operations."""
         # We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
-        # because one of
+        # because one of its arguments `self.region_name` is a `@property` decorated function
         # calling the cached property `self.conn_config` at the end.
         return await sync_to_async(self._get_async_conn)()
 
airflow/providers/amazon/aws/hooks/batch_client.py

@@ -386,8 +386,7 @@ class BatchClientHook(AwsBaseHook):
             )
             if job_status in match_status:
                 return True
-
-        raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")
+        raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")
 
     def get_job_description(self, job_id: str) -> dict:
         """

@@ -426,10 +425,9 @@ class BatchClientHook(AwsBaseHook):
                 "check Amazon Provider AWS Connection documentation for more details.",
                 str(err),
             )
-
-
-
-            )
+            raise AirflowException(
+                f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
+            )
 
     @staticmethod
     def parse_job_description(job_id: str, response: dict) -> dict:
airflow/providers/amazon/aws/hooks/chime.py

@@ -33,7 +33,7 @@ class ChimeWebhookHook(HttpHook):
    """
    Interact with Amazon Chime Webhooks to create notifications.
 
-    .. warning:: This hook is only designed to work with web hooks and not
+    .. warning:: This hook is only designed to work with web hooks and not chatbots.
 
    :param chime_conn_id: :ref:`Amazon Chime Connection ID <howto/connection:chime>`
        with Endpoint as `https://hooks.chime.aws` and the webhook token
airflow/providers/amazon/aws/hooks/datasync.py

@@ -21,8 +21,9 @@ from __future__ import annotations
 import time
 from urllib.parse import urlsplit
 
-from airflow.exceptions import AirflowBadRequest, AirflowException
+from airflow.exceptions import AirflowBadRequest, AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.common.compat.sdk import AirflowTaskTimeout
 
 
 class DataSyncHook(AwsBaseHook):

@@ -319,5 +320,4 @@ class DataSyncHook(AwsBaseHook):
             else:
                 raise AirflowException(f"Unknown status: {status}")  # Should never happen
             time.sleep(self.wait_interval_seconds)
-
-        raise AirflowTaskTimeout("Max iterations exceeded!")
+        raise AirflowTaskTimeout("Max iterations exceeded!")
airflow/providers/amazon/aws/hooks/firehose.py (new file)

@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS Firehose hook."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class FirehoseHook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis Firehose.
+
+    Provide thick wrapper around :external+boto3:py:class:`boto3.client("firehose") <Firehose.Client>`.
+
+    :param delivery_stream: Name of the delivery stream
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
+        self.delivery_stream = delivery_stream
+        kwargs["client_type"] = "firehose"
+        super().__init__(*args, **kwargs)
+
+    def put_records(self, records: Iterable) -> dict:
+        """
+        Write batch records to Kinesis Firehose.
+
+        .. seealso::
+            - :external+boto3:py:meth:`Firehose.Client.put_record_batch`
+
+        :param records: list of records
+        """
+        return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
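A minimal usage sketch for the new standalone hook; the stream name, connection ID, and record payload below are placeholders:

    from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook

    hook = FirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")
    # put_records forwards to Firehose.Client.put_record_batch and returns its response dict.
    response = hook.put_records(records=[{"Data": b"hello\n"}])
    failed = response.get("FailedPutCount", 0)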
airflow/providers/amazon/aws/hooks/glue.py

@@ -565,7 +565,13 @@ class GlueDataQualityHook(AwsBaseHook):
         Rule_3 ColumnLength "marketplace" between 1 and 2 FAIL {'Column.marketplace.MaximumLength': 9.0, 'Column.marketplace.MinimumLength': 3.0} Value: 9.0 does not meet the constraint requirement!
 
         """
-
+        try:
+            import pandas as pd
+        except ImportError:
+            self.log.warning(
+                "Pandas is not installed. Please install pandas to see the detailed Data Quality results."
+            )
+            return
 
         pd.set_option("display.max_rows", None)
         pd.set_option("display.max_columns", None)
airflow/providers/amazon/aws/hooks/kinesis.py

@@ -19,12 +19,14 @@
 
 from __future__ import annotations
 
-
+import warnings
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook as _FirehoseHook
 
 
-class FirehoseHook(
+class FirehoseHook(_FirehoseHook):
     """
     Interact with Amazon Kinesis Firehose.
 

@@ -37,20 +39,36 @@ class FirehoseHook(AwsBaseHook):
 
     .. seealso::
         - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    .. deprecated::
+        This hook was moved. Import from
+        :class:`airflow.providers.amazon.aws.hooks.firehose.FirehoseHook`
+        instead of kinesis.py
     """
 
-    def __init__(self,
-
-
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "Importing FirehoseHook from kinesis.py is deprecated "
+            "and will be removed in a future release. "
+            "Please import it from firehose.py instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(*args, **kwargs)
 
-    def put_records(self, records: Iterable):
-        """
-        Write batch records to Kinesis Firehose.
 
-
-
+class KinesisHook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("kinesis") <Kinesis.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
 
-
-
-
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "kinesis"
+        super().__init__(*args, **kwargs)
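With this change, importing FirehoseHook from kinesis.py still works but raises AirflowProviderDeprecationWarning, and a separate KinesisHook now exposes a plain boto3 Kinesis client. A short sketch of the updated import style; the stream name, data, and partition key are placeholders:

    # Preferred import location for the Firehose hook going forward:
    from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook
    # New thin Kinesis wrapper added in this release:
    from airflow.providers.amazon.aws.hooks.kinesis import KinesisHook

    kinesis = KinesisHook(aws_conn_id="aws_default")
    # get_conn() returns the underlying boto3 "kinesis" client, so any client method is available.
    kinesis.get_conn().put_record(
        StreamName="example-stream", Data=b"hello", PartitionKey="pk-1"
    )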
airflow/providers/amazon/aws/hooks/mwaa.py

@@ -18,6 +18,9 @@
 
 from __future__ import annotations
 
+import warnings
+from typing import Literal
+
 import requests
 from botocore.exceptions import ClientError
 

@@ -55,6 +58,7 @@ class MwaaHook(AwsBaseHook):
         body: dict | None = None,
         query_params: dict | None = None,
         generate_local_token: bool = False,
+        airflow_version: Literal[2, 3] | None = None,
     ) -> dict:
         """
         Invoke the REST API on the Airflow webserver with the specified inputs.

@@ -70,6 +74,8 @@ class MwaaHook(AwsBaseHook):
         :param generate_local_token: If True, only the local web token method is used without trying boto's
             `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
             boto's `invoke_rest_api`
+        :param airflow_version: The Airflow major version the MWAA environment runs.
+            This parameter is only used if the local web token method is used to call Airflow API.
         """
         # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
         body = {k: v for k, v in body.items() if v is not None} if body else {}

@@ -83,7 +89,7 @@ class MwaaHook(AwsBaseHook):
         }
 
         if generate_local_token:
-            return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+            return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)
 
         try:
             response = self.conn.invoke_rest_api(**api_kwargs)

@@ -100,7 +106,7 @@ class MwaaHook(AwsBaseHook):
                 self.log.info(
                     "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
                 )
-                return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+                return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)
             to_log = e.response
             # ResponseMetadata is removed because it contains data that is either very unlikely to be
             # useful in XComs and logs, or redundant given the data already included in the response

@@ -110,14 +116,35 @@ class MwaaHook(AwsBaseHook):
 
     def _invoke_rest_api_using_local_session_token(
         self,
+        airflow_version: Literal[2, 3] | None = None,
         **api_kwargs,
     ) -> dict:
+        if not airflow_version:
+            warnings.warn(
+                "The parameter ``airflow_version`` in ``MwaaHook.invoke_rest_api`` is not "
+                "specified and the local web token method is being used. "
+                "The default Airflow version being used is 2 but this value will change in the future. "
+                "To avoid any unexpected behavior, please explicitly specify the Airflow version.",
+                FutureWarning,
+                stacklevel=3,
+            )
+            airflow_version = 2
+
         try:
-            session, hostname = self._get_session_conn(api_kwargs["Name"])
+            session, hostname, login_response = self._get_session_conn(api_kwargs["Name"], airflow_version)
+
+            headers = {}
+            if airflow_version == 3:
+                headers = {
+                    "Authorization": f"Bearer {login_response.cookies['_token']}",
+                    "Content-Type": "application/json",
+                }
 
+            api_version = "v1" if airflow_version == 2 else "v2"
             response = session.request(
                 method=api_kwargs["Method"],
-                url=f"https://{hostname}/api/
+                url=f"https://{hostname}/api/{api_version}{api_kwargs['Path']}",
+                headers=headers,
                 params=api_kwargs["QueryParameters"],
                 json=api_kwargs["Body"],
                 timeout=10,

@@ -134,15 +161,19 @@ class MwaaHook(AwsBaseHook):
         }
 
     # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
-    def _get_session_conn(self, env_name: str) -> tuple:
+    def _get_session_conn(self, env_name: str, airflow_version: Literal[2, 3]) -> tuple:
         create_token_response = self.conn.create_web_login_token(Name=env_name)
         web_server_hostname = create_token_response["WebServerHostname"]
         web_token = create_token_response["WebToken"]
 
-        login_url =
+        login_url = (
+            f"https://{web_server_hostname}/aws_mwaa/login"
+            if airflow_version == 2
+            else f"https://{web_server_hostname}/pluginsv2/aws_mwaa/login"
+        )
         login_payload = {"token": web_token}
         session = requests.Session()
         login_response = session.post(login_url, data=login_payload, timeout=10)
         login_response.raise_for_status()
 
-        return session, web_server_hostname
+        return session, web_server_hostname, login_response
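Callers that rely on the local web token fallback should now pass the environment's Airflow major version explicitly; otherwise the hook emits a FutureWarning and assumes version 2. A hedged sketch follows: the environment name, connection ID, and path are placeholders, and the names of the environment/path/method parameters come from the pre-existing hook signature, which this diff does not show, so treat them as assumptions:

    from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

    hook = MwaaHook(aws_conn_id="aws_default")
    result = hook.invoke_rest_api(
        env_name="example-mwaa-env",  # assumed parameter name, not changed by this diff
        path="/dags",
        method="GET",
        # Selects the /api/v2 path and bearer-token headers when the local token fallback is used.
        airflow_version=3,
    )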
airflow/providers/amazon/aws/hooks/redshift_sql.py

@@ -51,7 +51,7 @@ class RedshiftSQLHook(DbApiHook):
         :ref:`Amazon Redshift connection id<howto/connection:redshift>`
 
     .. note::
-        get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift
+        get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift.
     """
 
     conn_name_attr = "redshift_conn_id"

@@ -155,10 +155,21 @@ class RedshiftSQLHook(DbApiHook):
         if "user" in conn_params:
             conn_params["username"] = conn_params.pop("user")
 
-        #
-
-
-
+        # Use URL.create for SQLAlchemy 2 compatibility
+        username = conn_params.get("username")
+        password = conn_params.get("password")
+        host = conn_params.get("host")
+        port = conn_params.get("port")
+        database = conn_params.get("database")
+
+        return URL.create(
+            drivername="postgresql",
+            username=str(username) if username is not None else None,
+            password=str(password) if password is not None else None,
+            host=str(host) if host is not None else None,
+            port=int(port) if port is not None else None,
+            database=str(database) if database is not None else None,
+        ).render_as_string(hide_password=False)
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
         """Overridden to pass Redshift-specific arguments."""

@@ -237,7 +248,10 @@
         region_name = AwsBaseHook(aws_conn_id=self.aws_conn_id).region_name
         identifier = f"{cluster_identifier}.{region_name}"
         if not cluster_identifier:
-
+            if connection.host:
+                identifier = self._get_identifier_from_hostname(connection.host)
+            else:
+                raise AirflowException("Host is required when cluster_identifier is not provided.")
         return f"{identifier}:{port}"
 
     def _get_identifier_from_hostname(self, hostname: str) -> str:
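Both the Athena and Redshift hooks now return the connection URI as a rendered string rather than a SQLAlchemy URL object, since calling str() on a URL object hides the password. A standalone sketch of the pattern used in both hooks; all connection values below are placeholders:

    from sqlalchemy.engine.url import URL

    uri = URL.create(
        drivername="postgresql",
        username="awsuser",
        password="example-password",
        host="examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com",
        port=5439,
        database="dev",
    ).render_as_string(hide_password=False)  # keeps the real password in the rendered URI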