apache-airflow-providers-amazon 9.14.0__py3-none-any.whl → 9.18.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. airflow/providers/amazon/__init__.py +3 -3
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +106 -5
  3. airflow/providers/amazon/aws/auth_manager/routes/login.py +7 -1
  4. airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py +5 -1
  5. airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +1 -1
  6. airflow/providers/amazon/aws/hooks/athena.py +6 -2
  7. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
  8. airflow/providers/amazon/aws/hooks/base_aws.py +2 -2
  9. airflow/providers/amazon/aws/hooks/batch_client.py +4 -6
  10. airflow/providers/amazon/aws/hooks/batch_waiters.py +0 -1
  11. airflow/providers/amazon/aws/hooks/chime.py +1 -1
  12. airflow/providers/amazon/aws/hooks/datasync.py +3 -3
  13. airflow/providers/amazon/aws/hooks/firehose.py +56 -0
  14. airflow/providers/amazon/aws/hooks/glue.py +7 -1
  15. airflow/providers/amazon/aws/hooks/kinesis.py +31 -13
  16. airflow/providers/amazon/aws/hooks/mwaa.py +38 -7
  17. airflow/providers/amazon/aws/hooks/redshift_sql.py +20 -6
  18. airflow/providers/amazon/aws/hooks/s3.py +41 -11
  19. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +1 -1
  20. airflow/providers/amazon/aws/hooks/ses.py +76 -10
  21. airflow/providers/amazon/aws/hooks/sns.py +74 -18
  22. airflow/providers/amazon/aws/hooks/sqs.py +64 -11
  23. airflow/providers/amazon/aws/hooks/ssm.py +34 -6
  24. airflow/providers/amazon/aws/hooks/step_function.py +1 -1
  25. airflow/providers/amazon/aws/links/base_aws.py +1 -1
  26. airflow/providers/amazon/aws/notifications/ses.py +139 -0
  27. airflow/providers/amazon/aws/notifications/sns.py +16 -1
  28. airflow/providers/amazon/aws/notifications/sqs.py +17 -1
  29. airflow/providers/amazon/aws/operators/base_aws.py +2 -2
  30. airflow/providers/amazon/aws/operators/bedrock.py +2 -0
  31. airflow/providers/amazon/aws/operators/cloud_formation.py +2 -2
  32. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  33. airflow/providers/amazon/aws/operators/emr.py +44 -33
  34. airflow/providers/amazon/aws/operators/mwaa.py +12 -3
  35. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +1 -1
  36. airflow/providers/amazon/aws/operators/ssm.py +122 -17
  37. airflow/providers/amazon/aws/secrets/secrets_manager.py +3 -4
  38. airflow/providers/amazon/aws/sensors/base_aws.py +2 -2
  39. airflow/providers/amazon/aws/sensors/mwaa.py +14 -1
  40. airflow/providers/amazon/aws/sensors/s3.py +27 -13
  41. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +1 -1
  42. airflow/providers/amazon/aws/sensors/ssm.py +33 -17
  43. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +3 -3
  44. airflow/providers/amazon/aws/transfers/base.py +5 -5
  45. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +4 -4
  46. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -1
  47. airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
  48. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +48 -5
  49. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
  50. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +2 -5
  51. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -1
  52. airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -1
  53. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -1
  54. airflow/providers/amazon/aws/transfers/local_to_s3.py +1 -1
  55. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +1 -1
  56. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +6 -6
  57. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
  58. airflow/providers/amazon/aws/transfers/s3_to_ftp.py +1 -1
  59. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
  60. airflow/providers/amazon/aws/transfers/s3_to_sftp.py +1 -1
  61. airflow/providers/amazon/aws/transfers/s3_to_sql.py +1 -1
  62. airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +1 -1
  63. airflow/providers/amazon/aws/transfers/sftp_to_s3.py +1 -1
  64. airflow/providers/amazon/aws/transfers/sql_to_s3.py +4 -5
  65. airflow/providers/amazon/aws/triggers/bedrock.py +1 -1
  66. airflow/providers/amazon/aws/triggers/s3.py +29 -2
  67. airflow/providers/amazon/aws/triggers/ssm.py +17 -1
  68. airflow/providers/amazon/aws/utils/connection_wrapper.py +2 -5
  69. airflow/providers/amazon/aws/utils/mixins.py +1 -1
  70. airflow/providers/amazon/aws/utils/waiter.py +2 -2
  71. airflow/providers/amazon/aws/waiters/emr.json +6 -6
  72. airflow/providers/amazon/get_provider_info.py +19 -1
  73. airflow/providers/amazon/version_compat.py +19 -16
  74. {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/METADATA +25 -19
  75. {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/RECORD +79 -76
  76. apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses/NOTICE +5 -0
  77. {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/WHEEL +0 -0
  78. {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/entry_points.txt +0 -0
  79. {airflow/providers/amazon → apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses}/LICENSE +0 -0
airflow/providers/amazon/__init__.py
@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.14.0"
+__version__ = "9.18.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.10.0"
+    "2.11.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.10.0+"
+        f"The package `apache-airflow-providers-amazon:{__version__}` needs Apache Airflow 2.11.0+"
     )
airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -338,6 +338,37 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
 
+    def filter_authorized_connections(
+        self,
+        *,
+        conn_ids: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for conn_id in conn_ids:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.CONNECTION,
+                "entity_id": conn_id,
+            }
+            requests[conn_id][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            conn_id
+            for conn_id in conn_ids
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[conn_id][method], user
+            )
+        }
+
     def filter_authorized_dag_ids(
         self,
         *,
@@ -361,13 +392,75 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             requests=requests_list, user=user
         )
 
-        def _has_access_to_dag(request: IsAuthorizedRequest):
-            result = self.avp_facade.get_batch_is_authorized_single_result(
-                batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
+        return {
+            dag_id
+            for dag_id in dag_ids
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[dag_id][method], user
             )
-            return result["decision"] == "ALLOW"
+        }
 
-        return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
+    def filter_authorized_pools(
+        self,
+        *,
+        pool_names: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for pool_name in pool_names:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.POOL,
+                "entity_id": pool_name,
+            }
+            requests[pool_name][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            pool_name
+            for pool_name in pool_names
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[pool_name][method], user
+            )
+        }
+
+    def filter_authorized_variables(
+        self,
+        *,
+        variable_keys: set[str],
+        user: AwsAuthManagerUser,
+        method: ResourceMethod = "GET",
+        team_name: str | None = None,
+    ) -> set[str]:
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for variable_key in variable_keys:
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.VARIABLE,
+                "entity_id": variable_key,
+            }
+            requests[variable_key][method] = request
+            requests_list.append(request)
+
+        batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        return {
+            variable_key
+            for variable_key in variable_keys
+            if self._is_authorized_from_batch_response(
+                batch_is_authorized_results, requests[variable_key][method], user
+            )
+        }
 
     def get_url_login(self, **kwargs) -> str:
         return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
@@ -406,6 +499,14 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             "entity_id": menu_item_text,
         }
 
+    def _is_authorized_from_batch_response(
+        self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser
+    ):
+        result = self.avp_facade.get_batch_is_authorized_single_result(
+            batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
+        )
+        return result["decision"] == "ALLOW"
+
     def _check_avp_schema_version(self):
         if not self.avp_facade.is_policy_store_schema_up_to_date():
             self.log.warning(
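The three new filter_authorized_* methods follow the same batch pattern: build one IsAuthorizedRequest per entity, send them to Amazon Verified Permissions in a single batch call, and keep only the IDs whose decision is ALLOW (via the shared _is_authorized_from_batch_response helper). A minimal usage sketch, assuming an AwsAuthManager instance and an AwsAuthManagerUser from a running auth flow; the connection, pool, and variable names below are placeholders:

from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import AwsAuthManager
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser

def visible_resources(auth_manager: AwsAuthManager, user: AwsAuthManagerUser):
    # Each call issues one batched authorization request and returns only the
    # entities Verified Permissions answered ALLOW for.
    conns = auth_manager.filter_authorized_connections(
        conn_ids={"aws_default", "redshift_default"}, user=user
    )
    pools = auth_manager.filter_authorized_pools(pool_names={"default_pool"}, user=user)
    variables = auth_manager.filter_authorized_variables(variable_keys={"api_key"}, user=user)
    return conns, pools, variables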
airflow/providers/amazon/aws/auth_manager/routes/login.py
@@ -35,6 +35,7 @@ from airflow.configuration import conf
 from airflow.providers.amazon.aws.auth_manager.constants import CONF_SAML_METADATA_URL_KEY, CONF_SECTION_NAME
 from airflow.providers.amazon.aws.auth_manager.datamodels.login import LoginResponse
 from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_1_PLUS
 
 try:
     from onelogin.saml2.auth import OneLogin_Saml2_Auth
@@ -101,7 +102,12 @@ def login_callback(request: Request):
     if relay_state == "login-redirect":
         response = RedirectResponse(url=url, status_code=303)
         secure = bool(conf.get("api", "ssl_cert", fallback=""))
-        response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
+        # In Airflow 3.1.1 authentication changes, front-end no longer handle the token
+        # See https://github.com/apache/airflow/pull/55506
+        if AIRFLOW_V_3_1_1_PLUS:
+            response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure, httponly=True)
+        else:
+            response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure)
         return response
     if relay_state == "login-token":
         return LoginResponse(access_token=token)
airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py
@@ -66,7 +66,11 @@ def run_and_report(command, task_key):
     try:
         log.info("Starting execution for task: %s", task_key)
         result = subprocess.run(
-            command, shell=isinstance(command, str), stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+            command,
+            check=False,
+            shell=isinstance(command, str),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
         )
         return_code = result.returncode
         log.info("Execution completed for task %s with return code %s", task_key, return_code)
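The only subtlety in this hunk is the explicit check=False: run_and_report inspects result.returncode itself, so subprocess.run must not raise on a non-zero exit. A standalone illustration of the difference, unrelated to the Lambda executor code and assuming a POSIX `false` binary is available:

import subprocess

# check=False (the executor's choice): a failing command still returns a
# CompletedProcess, and the caller reads .returncode itself.
result = subprocess.run(["false"], check=False)
print(result.returncode)  # 1

# check=True: the same failing command raises CalledProcessError instead.
try:
    subprocess.run(["false"], check=True)
except subprocess.CalledProcessError as err:
    print("raised, return code", err.returncode)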
airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -49,7 +49,7 @@ try:
 except ImportError:
     from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
airflow/providers/amazon/aws/hooks/athena.py
@@ -40,11 +40,15 @@ MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"
 
 def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
     result = ""
-    for key, value in params.items():
+    for key, original_value in params.items():
+        value: str | Collection[str]
         if key == "QueryString":
             value = (
-                MULTI_LINE_QUERY_LOG_PREFIX + str(value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
+                MULTI_LINE_QUERY_LOG_PREFIX
+                + str(original_value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
             )
+        else:
+            value = original_value
         result += f"\t{key}: {value}\n"
     return result.rstrip()
 
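This refactor is behavior-preserving: it only introduces original_value so the loop variable keeps a single declared type. A rough sketch of what the helper produces, with illustrative parameter values rather than anything taken from the provider's tests:

from airflow.providers.amazon.aws.hooks.athena import query_params_to_string

params = {"QueryString": "SELECT 1;\nSELECT 2;", "WorkGroup": "primary"}
# QueryString values are pushed onto their own tab-indented lines; every other
# parameter renders as "\t<key>: <value>". Output (roughly):
#     QueryString:
#         SELECT 1;
#         SELECT 2;
#     WorkGroup: primary
print(query_params_to_string(params))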
airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -56,7 +56,7 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
 
     .. note::
-        get_uri() depends on SQLAlchemy and PyAthena.
+        get_uri() depends on SQLAlchemy and PyAthena
     """
 
     conn_name_attr = "athena_conn_id"
@@ -163,7 +163,7 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
             port=443,
             database=conn_params["schema_name"],
             query={"aws_session_token": creds.token, **self.conn.extra_dejson},
-        )
+        ).render_as_string(hide_password=False)
 
     def get_conn(self) -> AthenaConnection:
         """Get a ``pyathena.Connection`` object."""
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -60,7 +60,7 @@ from airflow.exceptions import (
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
-from airflow.providers.amazon.version_compat import BaseHook
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -790,7 +790,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
     async def get_async_conn(self):
         """Get an aiobotocore client to use for async operations."""
         # We have to wrap the call `self.get_client_type` in another call `_get_async_conn`,
-        # because one of it's arguments `self.region_name` is a `@property` decorated function
+        # because one of its arguments `self.region_name` is a `@property` decorated function
         # calling the cached property `self.conn_config` at the end.
         return await sync_to_async(self._get_async_conn)()
 
airflow/providers/amazon/aws/hooks/batch_client.py
@@ -386,8 +386,7 @@ class BatchClientHook(AwsBaseHook):
             )
             if job_status in match_status:
                 return True
-        else:
-            raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")
+        raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")
 
     def get_job_description(self, job_id: str) -> dict:
         """
@@ -426,10 +425,9 @@ class BatchClientHook(AwsBaseHook):
                     "check Amazon Provider AWS Connection documentation for more details.",
                     str(err),
                 )
-        else:
-            raise AirflowException(
-                f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
-            )
+        raise AirflowException(
+            f"AWS Batch job ({job_id}) description error: exceeded status_retries ({self.status_retries})"
+        )
 
     @staticmethod
     def parse_job_description(job_id: str, response: dict) -> dict:
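Both batch_client.py hunks replace a `for ... else:` clause with a plain `raise` after the loop. Behavior is unchanged: the loop bodies only exit early via `return`, never `break`, so the `else:` branch always executed once retries were exhausted. A tiny standalone illustration of that equivalence (the succeeded() check is a stand-in, not provider code):

def poll_with_for_else(max_retries: int) -> bool:
    for attempt in range(max_retries):
        if succeeded(attempt):
            return True
    else:  # runs whenever the loop completes without `break`
        raise RuntimeError("status checks exceed max_retries")

def poll_without_else(max_retries: int) -> bool:
    for attempt in range(max_retries):
        if succeeded(attempt):
            return True
    # No `break` above, so reaching this line means the loop exhausted itself.
    raise RuntimeError("status checks exceed max_retries")

def succeeded(attempt: int) -> bool:
    # Stand-in for the real status check; always False here to show the raise.
    return False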
airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -33,7 +33,6 @@ from copy import deepcopy
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
-import botocore.client
 import botocore.exceptions
 import botocore.waiter
 
airflow/providers/amazon/aws/hooks/chime.py
@@ -33,7 +33,7 @@ class ChimeWebhookHook(HttpHook):
     """
     Interact with Amazon Chime Webhooks to create notifications.
 
-    .. warning:: This hook is only designed to work with web hooks and not chat bots.
+    .. warning:: This hook is only designed to work with web hooks and not chatbots.
 
     :param chime_conn_id: :ref:`Amazon Chime Connection ID <howto/connection:chime>`
         with Endpoint as `https://hooks.chime.aws` and the webhook token
airflow/providers/amazon/aws/hooks/datasync.py
@@ -21,8 +21,9 @@ from __future__ import annotations
 import time
 from urllib.parse import urlsplit
 
-from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout
+from airflow.exceptions import AirflowBadRequest, AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.common.compat.sdk import AirflowTaskTimeout
 
 
 class DataSyncHook(AwsBaseHook):
@@ -319,5 +320,4 @@ class DataSyncHook(AwsBaseHook):
             else:
                 raise AirflowException(f"Unknown status: {status}")  # Should never happen
             time.sleep(self.wait_interval_seconds)
-        else:
-            raise AirflowTaskTimeout("Max iterations exceeded!")
+        raise AirflowTaskTimeout("Max iterations exceeded!")
airflow/providers/amazon/aws/hooks/firehose.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS Firehose hook."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class FirehoseHook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis Firehose.
+
+    Provide thick wrapper around :external+boto3:py:class:`boto3.client("firehose") <Firehose.Client>`.
+
+    :param delivery_stream: Name of the delivery stream
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
+        self.delivery_stream = delivery_stream
+        kwargs["client_type"] = "firehose"
+        super().__init__(*args, **kwargs)
+
+    def put_records(self, records: Iterable) -> dict:
+        """
+        Write batch records to Kinesis Firehose.
+
+        .. seealso::
+            - :external+boto3:py:meth:`Firehose.Client.put_record_batch`
+
+        :param records: list of records
+        """
+        return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
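A hedged usage sketch of the relocated hook; the stream and connection names are placeholders, and the record payload follows the boto3 put_record_batch shape (a list of dicts with a Data blob), which is what the wrapped client expects:

from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook

# Placeholder names; put_records simply forwards to
# Firehose.Client.put_record_batch on the configured delivery stream.
hook = FirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")
response = hook.put_records([{"Data": b'{"event": "example"}\n'}])
print(response.get("FailedPutCount"))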
airflow/providers/amazon/aws/hooks/glue.py
@@ -565,7 +565,13 @@ class GlueDataQualityHook(AwsBaseHook):
        Rule_3 ColumnLength "marketplace" between 1 and 2 FAIL {'Column.marketplace.MaximumLength': 9.0, 'Column.marketplace.MinimumLength': 3.0} Value: 9.0 does not meet the constraint requirement!
 
        """
-        import pandas as pd
+        try:
+            import pandas as pd
+        except ImportError:
+            self.log.warning(
+                "Pandas is not installed. Please install pandas to see the detailed Data Quality results."
+            )
+            return
 
         pd.set_option("display.max_rows", None)
         pd.set_option("display.max_columns", None)
airflow/providers/amazon/aws/hooks/kinesis.py
@@ -19,12 +19,14 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable
+import warnings
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook as _FirehoseHook
 
 
-class FirehoseHook(AwsBaseHook):
+class FirehoseHook(_FirehoseHook):
     """
     Interact with Amazon Kinesis Firehose.
 
@@ -37,20 +39,36 @@ class FirehoseHook(AwsBaseHook):
 
     .. seealso::
         - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    .. deprecated::
+        This hook was moved. Import from
+        :class:`airflow.providers.amazon.aws.hooks.firehose.FirehoseHook`
+        instead of kinesis.py
     """
 
-    def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
-        self.delivery_stream = delivery_stream
-        kwargs["client_type"] = "firehose"
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "Importing FirehoseHook from kinesis.py is deprecated "
+            "and will be removed in a future release. "
+            "Please import it from firehose.py instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__(*args, **kwargs)
 
-    def put_records(self, records: Iterable):
-        """
-        Write batch records to Kinesis Firehose.
 
-        .. seealso::
-            - :external+boto3:py:meth:`Firehose.Client.put_record_batch`
+class KinesisHook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("kinesis") <Kinesis.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
 
-        :param records: list of records
-        """
-        return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "kinesis"
+        super().__init__(*args, **kwargs)
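Net effect of the kinesis.py change: the old import path keeps working but instantiation now emits AirflowProviderDeprecationWarning, and a thin KinesisHook (client_type "kinesis") is added alongside it. A short migration sketch; the warning capture is only there to make the deprecation visible, and the stream/connection names are placeholders:

import warnings

# Preferred import path going forward.
from airflow.providers.amazon.aws.hooks.firehose import FirehoseHook

# Legacy path still resolves but warns on instantiation.
from airflow.providers.amazon.aws.hooks.kinesis import FirehoseHook as LegacyFirehoseHook, KinesisHook

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LegacyFirehoseHook(delivery_stream="example-stream")
print([str(w.message) for w in caught])

kinesis_hook = KinesisHook(aws_conn_id="aws_default")  # plain boto3 "kinesis" client wrapper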
airflow/providers/amazon/aws/hooks/mwaa.py
@@ -18,6 +18,9 @@
 
 from __future__ import annotations
 
+import warnings
+from typing import Literal
+
 import requests
 from botocore.exceptions import ClientError
 
@@ -55,6 +58,7 @@ class MwaaHook(AwsBaseHook):
         body: dict | None = None,
         query_params: dict | None = None,
         generate_local_token: bool = False,
+        airflow_version: Literal[2, 3] | None = None,
     ) -> dict:
         """
         Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -70,6 +74,8 @@
         :param generate_local_token: If True, only the local web token method is used without trying boto's
             `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
            boto's `invoke_rest_api`
+        :param airflow_version: The Airflow major version the MWAA environment runs.
+            This parameter is only used if the local web token method is used to call Airflow API.
         """
         # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
         body = {k: v for k, v in body.items() if v is not None} if body else {}
@@ -83,7 +89,7 @@
         }
 
         if generate_local_token:
-            return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+            return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)
 
         try:
             response = self.conn.invoke_rest_api(**api_kwargs)
@@ -100,7 +106,7 @@
                 self.log.info(
                     "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
                 )
-                return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+                return self._invoke_rest_api_using_local_session_token(airflow_version, **api_kwargs)
             to_log = e.response
             # ResponseMetadata is removed because it contains data that is either very unlikely to be
             # useful in XComs and logs, or redundant given the data already included in the response
@@ -110,14 +116,35 @@
 
     def _invoke_rest_api_using_local_session_token(
         self,
+        airflow_version: Literal[2, 3] | None = None,
         **api_kwargs,
     ) -> dict:
+        if not airflow_version:
+            warnings.warn(
+                "The parameter ``airflow_version`` in ``MwaaHook.invoke_rest_api`` is not "
+                "specified and the local web token method is being used. "
+                "The default Airflow version being used is 2 but this value will change in the future. "
+                "To avoid any unexpected behavior, please explicitly specify the Airflow version.",
+                FutureWarning,
+                stacklevel=3,
+            )
+            airflow_version = 2
+
         try:
-            session, hostname = self._get_session_conn(api_kwargs["Name"])
+            session, hostname, login_response = self._get_session_conn(api_kwargs["Name"], airflow_version)
+
+            headers = {}
+            if airflow_version == 3:
+                headers = {
+                    "Authorization": f"Bearer {login_response.cookies['_token']}",
+                    "Content-Type": "application/json",
+                }
 
+            api_version = "v1" if airflow_version == 2 else "v2"
             response = session.request(
                 method=api_kwargs["Method"],
-                url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                url=f"https://{hostname}/api/{api_version}{api_kwargs['Path']}",
+                headers=headers,
                 params=api_kwargs["QueryParameters"],
                 json=api_kwargs["Body"],
                 timeout=10,
@@ -134,15 +161,19 @@
         }
 
     # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
-    def _get_session_conn(self, env_name: str) -> tuple:
+    def _get_session_conn(self, env_name: str, airflow_version: Literal[2, 3]) -> tuple:
         create_token_response = self.conn.create_web_login_token(Name=env_name)
         web_server_hostname = create_token_response["WebServerHostname"]
         web_token = create_token_response["WebToken"]
 
-        login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+        login_url = (
+            f"https://{web_server_hostname}/aws_mwaa/login"
+            if airflow_version == 2
+            else f"https://{web_server_hostname}/pluginsv2/aws_mwaa/login"
+        )
         login_payload = {"token": web_token}
         session = requests.Session()
         login_response = session.post(login_url, data=login_payload, timeout=10)
         login_response.raise_for_status()
 
-        return session, web_server_hostname
+        return session, web_server_hostname, login_response
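Taken together, the mwaa.py hunks teach the local-web-token fallback about Airflow 3 environments: it logs in via /pluginsv2/aws_mwaa/login, passes the session _token cookie as a Bearer header, and targets /api/v2 instead of /api/v1. A hedged usage sketch follows; the environment name and path are placeholders, and the leading arguments (env_name, path, method) are assumed from the existing hook signature rather than shown in this diff:

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# airflow_version only matters when the local web token path is taken, i.e.
# generate_local_token=True or the IAM role lacks airflow:InvokeRestApi.
response = hook.invoke_rest_api(
    env_name="example-mwaa-env",   # placeholder environment name
    path="/dags/example_dag",      # placeholder Airflow REST path
    method="GET",
    generate_local_token=True,
    airflow_version=3,             # routes to /api/v2 with a Bearer token
)
print(response)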
airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -51,7 +51,7 @@ class RedshiftSQLHook(DbApiHook):
         :ref:`Amazon Redshift connection id<howto/connection:redshift>`
 
     .. note::
-        get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift
+        get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift.
     """
 
     conn_name_attr = "redshift_conn_id"
@@ -155,10 +155,21 @@
         if "user" in conn_params:
             conn_params["username"] = conn_params.pop("user")
 
-        # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
-        # to replace calling the default URL constructor directly.
-        create_url = getattr(URL, "create", URL)
-        return str(create_url(drivername="postgresql", **conn_params))
+        # Use URL.create for SQLAlchemy 2 compatibility
+        username = conn_params.get("username")
+        password = conn_params.get("password")
+        host = conn_params.get("host")
+        port = conn_params.get("port")
+        database = conn_params.get("database")
+
+        return URL.create(
+            drivername="postgresql",
+            username=str(username) if username is not None else None,
+            password=str(password) if password is not None else None,
+            host=str(host) if host is not None else None,
+            port=int(port) if port is not None else None,
+            database=str(database) if database is not None else None,
+        ).render_as_string(hide_password=False)
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
         """Overridden to pass Redshift-specific arguments."""
@@ -237,7 +248,10 @@
             region_name = AwsBaseHook(aws_conn_id=self.aws_conn_id).region_name
             identifier = f"{cluster_identifier}.{region_name}"
         if not cluster_identifier:
-            identifier = self._get_identifier_from_hostname(connection.host)
+            if connection.host:
+                identifier = self._get_identifier_from_hostname(connection.host)
+            else:
+                raise AirflowException("Host is required when cluster_identifier is not provided.")
         return f"{identifier}:{port}"
 
     def _get_identifier_from_hostname(self, hostname: str) -> str:
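get_uri() now assembles the DSN with SQLAlchemy's URL.create factory and renders it with the password kept in place, a form that behaves the same on SQLAlchemy 1.4 and 2.x. A standalone sketch of the pattern with illustrative credentials only:

from sqlalchemy.engine import URL

# Illustrative values; mirrors the pattern now used by RedshiftSQLHook.get_uri().
uri = URL.create(
    drivername="postgresql",
    username="awsuser",
    password="secret",
    host="examplecluster.abc123.us-east-1.redshift.amazonaws.com",
    port=5439,
    database="dev",
).render_as_string(hide_password=False)

print(uri)
# postgresql://awsuser:secret@examplecluster.abc123.us-east-1.redshift.amazonaws.com:5439/dev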