apache-airflow-providers-amazon 8.6.0__py3-none-any.whl → 8.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +2 -5
- airflow/providers/amazon/aws/hooks/athena.py +4 -3
- airflow/providers/amazon/aws/hooks/base_aws.py +28 -41
- airflow/providers/amazon/aws/hooks/batch_client.py +8 -6
- airflow/providers/amazon/aws/hooks/batch_waiters.py +4 -2
- airflow/providers/amazon/aws/hooks/chime.py +13 -8
- airflow/providers/amazon/aws/hooks/cloud_formation.py +5 -1
- airflow/providers/amazon/aws/hooks/datasync.py +9 -16
- airflow/providers/amazon/aws/hooks/ecr.py +4 -1
- airflow/providers/amazon/aws/hooks/ecs.py +4 -1
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -12
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -1
- airflow/providers/amazon/aws/hooks/s3.py +4 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +7 -8
- airflow/providers/amazon/aws/hooks/sns.py +0 -1
- airflow/providers/amazon/aws/links/emr.py +4 -3
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +4 -1
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -1
- airflow/providers/amazon/aws/notifications/chime.py +4 -1
- airflow/providers/amazon/aws/notifications/sns.py +94 -0
- airflow/providers/amazon/aws/notifications/sqs.py +100 -0
- airflow/providers/amazon/aws/operators/ecs.py +5 -5
- airflow/providers/amazon/aws/operators/glue.py +1 -1
- airflow/providers/amazon/aws/operators/rds.py +2 -2
- airflow/providers/amazon/aws/sensors/batch.py +7 -2
- airflow/providers/amazon/aws/sensors/dynamodb.py +1 -1
- airflow/providers/amazon/aws/sensors/ecs.py +2 -2
- airflow/providers/amazon/aws/sensors/s3.py +2 -2
- airflow/providers/amazon/aws/sensors/sqs.py +7 -6
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +3 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +4 -2
- airflow/providers/amazon/aws/triggers/athena.py +5 -1
- airflow/providers/amazon/aws/triggers/base.py +4 -2
- airflow/providers/amazon/aws/triggers/batch.py +10 -11
- airflow/providers/amazon/aws/triggers/ecs.py +9 -6
- airflow/providers/amazon/aws/triggers/eks.py +4 -2
- airflow/providers/amazon/aws/triggers/emr.py +6 -4
- airflow/providers/amazon/aws/triggers/glue_crawler.py +4 -1
- airflow/providers/amazon/aws/triggers/lambda_function.py +5 -1
- airflow/providers/amazon/aws/triggers/rds.py +4 -2
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +4 -1
- airflow/providers/amazon/aws/triggers/s3.py +4 -2
- airflow/providers/amazon/aws/triggers/sqs.py +6 -2
- airflow/providers/amazon/aws/triggers/step_function.py +5 -1
- airflow/providers/amazon/aws/utils/__init__.py +4 -2
- airflow/providers/amazon/aws/utils/redshift.py +3 -1
- airflow/providers/amazon/aws/utils/sqs.py +7 -12
- airflow/providers/amazon/aws/utils/suppress.py +74 -0
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +4 -2
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -2
- airflow/providers/amazon/aws/waiters/base_waiter.py +5 -1
- airflow/providers/amazon/get_provider_info.py +20 -5
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/METADATA +6 -14
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/RECORD +63 -60
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-8.6.0.dist-info → apache_airflow_providers_amazon-8.7.0.dist-info}/top_level.txt +0 -0
```diff
--- a/airflow/providers/amazon/aws/hooks/appflow.py
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -24,7 +24,6 @@ from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 if TYPE_CHECKING:
     from mypy_boto3_appflow.client import AppflowClient
-    from mypy_boto3_appflow.type_defs import TaskOutputTypeDef, TaskTypeDef
 
 
 class AppflowHook(AwsBaseHook):
@@ -93,9 +92,7 @@ class AppflowHook(AwsBaseHook):
         exec_details = last_execs[execution_id]
         self.log.info("Run complete, execution details: %s", exec_details)
 
-    def update_flow_filter(
-        self, flow_name: str, filter_tasks: list[TaskTypeDef], set_trigger_ondemand: bool = False
-    ) -> None:
+    def update_flow_filter(self, flow_name: str, filter_tasks, set_trigger_ondemand: bool = False) -> None:
         """
         Update the flow task filter; all filters will be removed if an empty array is passed to filter_tasks.
 
@@ -106,7 +103,7 @@ class AppflowHook(AwsBaseHook):
         """
         response = self.conn.describe_flow(flowName=flow_name)
         connector_type = response["sourceFlowConfig"]["connectorType"]
-        tasks: list[TaskTypeDef] = []
+        tasks = []
 
         # cleanup old filter tasks
         for task in response["tasks"]:
```
```diff
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -25,14 +25,15 @@ This module contains AWS Athena hook.
 from __future__ import annotations
 
 import warnings
-from typing import Any
-
-from botocore.paginate import PageIterator
+from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
+if TYPE_CHECKING:
+    from botocore.paginate import PageIterator
+
 
 class AthenaHook(AwsBaseHook):
     """Interact with Amazon Athena.
```
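This hunk is the template for a pattern that repeats across nearly every module in this release: imports needed only for type annotations move under an `if TYPE_CHECKING:` guard, so they are resolved by type checkers but never executed at runtime. Because each touched module already has `from __future__ import annotations`, the annotations themselves stay unevaluated strings. A minimal standalone sketch of the idiom (not provider code):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy and IDEs only; skipped entirely at runtime.
    from botocore.paginate import PageIterator


def first_page(pages: PageIterator) -> dict:
    # With postponed evaluation, "PageIterator" above is just a string at
    # runtime, so the guarded import is never actually needed to run this.
    return next(iter(pages))
```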
```diff
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -29,11 +29,9 @@ import inspect
 import json
 import logging
 import os
-import uuid
 import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
-from os import PathLike
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
 
@@ -43,9 +41,7 @@ import botocore.session
 import jinja2
 import requests
 import tenacity
-from botocore.client import ClientMeta
 from botocore.config import Config
-from botocore.credentials import ReadOnlyCredentials
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
 from slugify import slugify
@@ -58,6 +54,8 @@ from airflow.exceptions import (
 )
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
+from airflow.providers.amazon.aws.utils.suppress import return_on_error
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -66,6 +64,9 @@ from airflow.utils.log.secrets_masker import mask_secret
 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
 
 if TYPE_CHECKING:
+    from botocore.client import ClientMeta
+    from botocore.credentials import ReadOnlyCredentials
+
     from airflow.models.connection import Connection  # Avoid circular imports.
 
 
@@ -470,21 +471,17 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self._verify = verify
 
     @classmethod
+    @return_on_error("Unknown")
     def _get_provider_version(cls) -> str:
         """Check the Providers Manager for the package version."""
-        try:
-            manager = ProvidersManager()
-            hook = manager.hooks[cls.conn_type]
-            if not hook:
-                # This gets caught immediately, but without it MyPy complains
-                # Item "None" of "Optional[HookInfo]" has no attribute "package_name"
-                # on the following line and static checks fail.
-                raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
-            provider = manager.providers[hook.package_name]
-            return provider.version
-        except Exception:
-            # Under no condition should an error here ever cause an issue for the user.
-            return "Unknown"
+        manager = ProvidersManager()
+        hook = manager.hooks[cls.conn_type]
+        if not hook:
+            # This gets caught immediately, but without it MyPy complains
+            # Item "None" of "Optional[HookInfo]" has no attribute "package_name"
+            # on the following line and static checks fail.
+            raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
+        return manager.providers[hook.package_name].version
 
     @staticmethod
     def _find_class_name(target_function_name: str) -> str:
```
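The `@return_on_error` decorator comes from `airflow/providers/amazon/aws/utils/suppress.py`, a file new in this release (+74 lines, per the list above) whose body is not shown in this diff. Judging from how it replaces the removed `try`/`except Exception: return <fallback>` blocks here, it swallows any exception raised by the wrapped callable and returns a fixed fallback value instead. A plausible sketch of such a decorator, not the actual implementation:

```python
import functools
import logging
from typing import Any, Callable, TypeVar

log = logging.getLogger(__name__)
T = TypeVar("T")


def return_on_error(return_value: T) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Return ``return_value`` instead of propagating any exception from the wrapped callable."""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            try:
                return func(*args, **kwargs)
            except Exception:
                # Mirrors the removed inline comment: under no condition should
                # an error here ever cause an issue for the user.
                log.debug("Suppressed error in %s; returning %r", func.__name__, return_value)
                return return_value

        return wrapper

    return decorator
```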
```diff
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -504,19 +501,17 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         # Return the name of the class object.
         return frame_class_object.__name__
 
+    @return_on_error("Unknown")
     def _get_caller(self, target_function_name: str = "execute") -> str:
         """Given a function name, walk the stack and return the name of the class which called it last."""
-        try:
-            caller = self._find_class_name(target_function_name)
-            if caller == "BaseSensorOperator":
-                # If the result is a BaseSensorOperator, then look for whatever last called "poke".
-                return self._get_caller("poke")
-            return caller
-        except Exception:
-            # Under no condition should an error here ever cause an issue for the user.
-            return "Unknown"
+        caller = self._find_class_name(target_function_name)
+        if caller == "BaseSensorOperator":
+            # If the result is a BaseSensorOperator, then look for whatever last called "poke".
+            return self._get_caller("poke")
+        return caller
 
     @staticmethod
+    @return_on_error("00000000-0000-0000-0000-000000000000")
     def _generate_dag_key() -> str:
         """Generate a DAG key.
 
@@ -525,25 +520,17 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         can not (reasonably) be reversed. No personal data can be inferred or
         extracted from the resulting UUID.
         """
-        try:
-            dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
-            return str(uuid.uuid5(uuid.NAMESPACE_OID, dag_id))
-        except Exception:
-            # Under no condition should an error here ever cause an issue for the user.
-            return "00000000-0000-0000-0000-000000000000"
+        return generate_uuid(os.environ.get("AIRFLOW_CTX_DAG_ID"))
 
     @staticmethod
+    @return_on_error("Unknown")
     def _get_airflow_version() -> str:
         """Fetch and return the current Airflow version."""
-        try:
-            # This can be a circular import under specific configurations.
-            # Importing locally to either avoid or catch it if it does happen.
-            from airflow import __version__ as airflow_version
+        # This can be a circular import under specific configurations.
+        # Importing locally to either avoid or catch it if it does happen.
+        from airflow import __version__ as airflow_version
 
-            return airflow_version
-        except Exception:
-            # Under no condition should an error here ever cause an issue for the user.
-            return "Unknown"
+        return airflow_version
 
     def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
         user_agent_extra_values = [
@@ -831,7 +818,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
             return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")
 
     @cached_property
-    def waiter_path(self) -> PathLike[str] | None:
+    def waiter_path(self) -> os.PathLike[str] | None:
         filename = self.client_type if self.client_type else self.resource_type
         path = Path(__file__).parents[1].joinpath(f"waiters/{filename}.json").resolve()
         return path if path.exists() else None
```
```diff
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -27,9 +27,9 @@ A client for AWS Batch services.
 from __future__ import annotations
 
 import itertools
-from random import uniform
+import random
 from time import sleep
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import botocore.client
 import botocore.exceptions
@@ -37,9 +37,11 @@ import botocore.waiter
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.typing_compat import Protocol, runtime_checkable
 
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
+
 
 @runtime_checkable
 class BatchProtocol(Protocol):
@@ -527,7 +529,7 @@ class BatchClientHook(AwsBaseHook):
         minima = abs(minima)
         lower = max(minima, delay - width)
         upper = delay + width
-        return uniform(lower, upper)
+        return random.uniform(lower, upper)
 
     @staticmethod
     def delay(delay: int | float | None = None) -> None:
@@ -544,7 +546,7 @@ class BatchClientHook(AwsBaseHook):
         when many concurrent tasks request job-descriptions.
         """
         if delay is None:
-            delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
+            delay = random.uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
         else:
             delay = BatchClientHook.add_jitter(delay)
         sleep(delay)
@@ -592,4 +594,4 @@ class BatchClientHook(AwsBaseHook):
         max_interval = 600.0  # results in 3 to 10 minute delay
         delay = 1 + pow(tries * 0.6, 2)
         delay = min(max_interval, delay)
-        return uniform(delay / 3, delay)
+        return random.uniform(delay / 3, delay)
```
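The `uniform` → `random.uniform` changes are purely stylistic (module-qualified calls); the backoff math is untouched. For reference, the jittered exponential backoff shown in the last hunk evaluates as follows:

```python
import random


def exponential_delay(tries: int) -> float:
    # Same formula as BatchClientHook.exponential_delay above.
    max_interval = 600.0  # results in 3 to 10 minute delay
    delay = min(max_interval, 1 + pow(tries * 0.6, 2))
    return random.uniform(delay / 3, delay)


for tries in (1, 10, 50):
    delay = min(600.0, 1 + pow(tries * 0.6, 2))
    print(f"tries={tries}: sleep drawn from [{delay / 3:.1f}s, {delay:.1f}s]")
# tries=1: [0.5s, 1.4s]; tries=10: [12.3s, 37.0s]; tries=50: [200.0s, 600.0s]
```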
```diff
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -29,7 +29,7 @@ import json
 import sys
 from copy import deepcopy
 from pathlib import Path
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import botocore.client
 import botocore.exceptions
@@ -37,7 +37,9 @@ import botocore.waiter
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 
 
 class BatchWaitersHook(BatchClientHook):
```
```diff
--- a/airflow/providers/amazon/aws/hooks/chime.py
+++ b/airflow/providers/amazon/aws/hooks/chime.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 
 import json
 import re
+from functools import cached_property
 from typing import Any
 
 from airflow.exceptions import AirflowException
@@ -28,19 +29,19 @@ from airflow.providers.http.hooks.http import HttpHook
 
 
 class ChimeWebhookHook(HttpHook):
-    """Interact with Chime Webhooks to create notifications.
+    """Interact with Amazon Chime Webhooks to create notifications.
 
     .. warning:: This hook is only designed to work with web hooks and not chat bots.
 
-    :param chime_conn_id: Chime connection ID
-        with Endpoint as `https://hooks.chime.aws` and the webhook token
-        in the form of ``{webhook.id}?token{webhook.token}``
+    :param chime_conn_id: :ref:`Amazon Chime Connection ID <howto/connection:chime>`
+        with Endpoint as `https://hooks.chime.aws` and the webhook token
+        in the form of ``{webhook.id}?token{webhook.token}``
     """
 
     conn_name_attr = "chime_conn_id"
     default_conn_name = "chime_default"
     conn_type = "chime"
-    hook_name = "Chime Webhook"
+    hook_name = "Amazon Chime Webhook"
 
     def __init__(
         self,
@@ -49,7 +50,11 @@ class ChimeWebhookHook(HttpHook):
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
-        self.webhook_endpoint = self._get_webhook_endpoint(chime_conn_id)
+        self._chime_conn_id = chime_conn_id
+
+    @cached_property
+    def webhook_endpoint(self):
+        return self._get_webhook_endpoint(self._chime_conn_id)
 
     def _get_webhook_endpoint(self, conn_id: str) -> str:
         """
```
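With the endpoint lookup moved behind `functools.cached_property`, constructing the hook no longer touches the connections backend at all; the webhook endpoint is resolved once, on first access of `webhook_endpoint`, and cached thereafter. A standalone sketch of the idiom (hypothetical `Webhook` class, not provider code):

```python
from functools import cached_property


class Webhook:
    def __init__(self, conn_id: str) -> None:
        # Construction stays cheap: no connection lookup here.
        self._conn_id = conn_id

    @cached_property
    def endpoint(self) -> str:
        print(f"resolving {self._conn_id}")  # happens exactly once
        return f"https://hooks.example.invalid/{self._conn_id}"


hook = Webhook("chime_default")
hook.endpoint  # prints "resolving chime_default"
hook.endpoint  # served from the cache; no second lookup
```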
```diff
--- a/airflow/providers/amazon/aws/hooks/chime.py
+++ b/airflow/providers/amazon/aws/hooks/chime.py
@@ -65,7 +70,7 @@ class ChimeWebhookHook(HttpHook):
         url = conn.schema + "://" + conn.host
         endpoint = url + token
         # Check to make sure the endpoint matches what Chime expects
-        if not re.match(r"^[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+$", token):
+        if not re.fullmatch(r"[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+", token):
             raise AirflowException(
                 "Expected Chime webhook token in the form of '{webhook.id}?token={webhook.token}'."
             )
```
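`re.fullmatch` requires the entire token to match the pattern, so no explicit `^`/`$` anchors are needed. What the new check accepts and rejects:

```python
import re

TOKEN_PATTERN = r"[a-zA-Z0-9_-]+\?token=[a-zA-Z0-9_-]+"

for token in ("abc-123?token=XYZ_789", "abc-123?token=XYZ 789", "abc-123"):
    verdict = "accepted" if re.fullmatch(TOKEN_PATTERN, token) else "rejected"
    print(f"{token!r}: {verdict}")  # only the first token is accepted
```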
```diff
--- a/airflow/providers/amazon/aws/hooks/chime.py
+++ b/airflow/providers/amazon/aws/hooks/chime.py
@@ -104,7 +109,7 @@ class ChimeWebhookHook(HttpHook):
             "hidden_fields": ["login", "port", "extra"],
             "relabeling": {
                 "host": "Chime Webhook Endpoint",
-                "password": "Webhook token",
+                "password": "Chime Webhook token",
             },
             "placeholders": {
                 "schema": "https",
```
```diff
--- a/airflow/providers/amazon/aws/hooks/cloud_formation.py
+++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py
@@ -18,11 +18,15 @@
 """This module contains AWS CloudFormation Hook."""
 from __future__ import annotations
 
-from boto3 import client, resource
+from typing import TYPE_CHECKING
+
 from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
+if TYPE_CHECKING:
+    from boto3 import client, resource
+
 
 class CloudFormationHook(AwsBaseHook):
     """
```
```diff
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -301,25 +301,18 @@ class DataSyncHook(AwsBaseHook):
         if not task_execution_arn:
             raise AirflowBadRequest("task_execution_arn not specified")
 
-        status = None
-        iterations = max_iterations
-        while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+        for _ in range(max_iterations):
             task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
             status = task_execution["Status"]
             self.log.info("status=%s", status)
-            iterations -= 1
-            if status in self.TASK_EXECUTION_FAILURE_STATES:
-                break
             if status in self.TASK_EXECUTION_SUCCESS_STATES:
-                break
-            if iterations <= 0:
-                break
+                return True
+            elif status in self.TASK_EXECUTION_FAILURE_STATES:
+                return False
+            elif status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+                time.sleep(self.wait_interval_seconds)
+            else:
+                raise AirflowException(f"Unknown status: {status}")  # Should never happen
             time.sleep(self.wait_interval_seconds)
-
-        if status in self.TASK_EXECUTION_SUCCESS_STATES:
-            return True
-        if status in self.TASK_EXECUTION_FAILURE_STATES:
-            return False
-        if iterations <= 0:
+        else:
             raise AirflowTaskTimeout("Max iterations exceeded!")
-        raise AirflowException(f"Unknown status: {status}")  # Should never happen
```
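The rewrite replaces manual counter bookkeeping with a bounded `for` loop and Python's `for`/`else`: the `else` block runs only when the loop exhausts `max_iterations` without returning. A self-contained sketch of the same shape (hypothetical statuses and helpers, not the DataSync API):

```python
import time


def wait_for(statuses, max_iterations: int, interval: float = 0.01) -> bool:
    stream = iter(statuses)
    for _ in range(max_iterations):
        status = next(stream, None)  # stand-in for describe_task_execution()
        if status == "SUCCESS":
            return True
        if status == "ERROR":
            return False
        time.sleep(interval)  # still pending; poll again
    else:
        # Reached only when the loop ran out of iterations without returning.
        raise TimeoutError("Max iterations exceeded!")


print(wait_for(["LAUNCHING", "TRANSFERRING", "SUCCESS"], max_iterations=10))  # True
```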
```diff
--- a/airflow/providers/amazon/aws/hooks/ecr.py
+++ b/airflow/providers/amazon/aws/hooks/ecr.py
@@ -20,11 +20,14 @@ from __future__ import annotations
 import base64
 import logging
 from dataclasses import dataclass
-from datetime import datetime
+from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.utils.log.secrets_masker import mask_secret
 
+if TYPE_CHECKING:
+    from datetime import datetime
+
 logger = logging.getLogger(__name__)
 
 
```
```diff
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -17,13 +17,16 @@
 # under the License.
 from __future__ import annotations
 
-from botocore.waiter import Waiter
+from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import _StringCompareEnum
 from airflow.typing_compat import Protocol, runtime_checkable
 
+if TYPE_CHECKING:
+    from botocore.waiter import Waiter
+
 
 def should_retry(exception: Exception):
     """Check if exception is related to ECS resource quota (CPU, MEM)."""
```
```diff
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -21,7 +21,6 @@ import warnings
 from typing import Any, Sequence
 
 import botocore.exceptions
-from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
@@ -70,17 +69,14 @@ class RedshiftHook(AwsBaseHook):
             for the cluster that is being created.
         :param params: Remaining AWS Create cluster API params.
         """
-        try:
-            response = self.get_conn().create_cluster(
-                ClusterIdentifier=cluster_identifier,
-                NodeType=node_type,
-                MasterUsername=master_username,
-                MasterUserPassword=master_user_password,
-                **params,
-            )
-            return response
-        except ClientError as e:
-            raise e
+        response = self.get_conn().create_cluster(
+            ClusterIdentifier=cluster_identifier,
+            NodeType=node_type,
+            MasterUsername=master_username,
+            MasterUserPassword=master_user_password,
+            **params,
+        )
+        return response
 
     # TODO: Wrap create_cluster_snapshot
     def cluster_status(self, cluster_identifier: str) -> str:
```
```diff
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -188,7 +188,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         pk_columns = []
         token = ""
         while True:
-            kwargs = dict(Id=stmt_id)
+            kwargs = {"Id": stmt_id}
             if token:
                 kwargs["NextToken"] = token
             response = self.conn.get_statement_result(**kwargs)
```
```diff
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -41,10 +41,8 @@ from urllib.parse import urlsplit
 from uuid import uuid4
 
 if TYPE_CHECKING:
-    try:
+    with suppress(ImportError):
         from aiobotocore.client import AioBaseClient
-    except ImportError:
-        pass
 
 
 from asgiref.sync import sync_to_async
```
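`contextlib.suppress(ImportError)` is the compact equivalent of the removed `try`/`except ImportError: pass`, handy for optional dependencies that only matter to type checkers:

```python
from contextlib import suppress

# Equivalent to: try: from x import y / except ImportError: pass
with suppress(ImportError):
    from aiobotocore.client import AioBaseClient  # optional extra; skipped if absent
```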
```diff
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -470,7 +468,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_name: the name of the bucket
         :param key: the path to the key
         """
-        prefix = re.split(r"[\[\*\?]", key, 1)[0]
+        prefix = re.split(r"[\[*?]", key, 1)[0]
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
@@ -572,7 +570,7 @@ class S3Hook(AwsBaseHook):
         for key in bucket_keys:
             prefix = key
             if wildcard_match:
-                prefix = re.split(r"[\[\*\?]", key, 1)[0]
+                prefix = re.split(r"[\[*?]", key, 1)[0]
 
             paginator = client.get_paginator("list_objects_v2")
             response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
@@ -1017,7 +1015,7 @@ class S3Hook(AwsBaseHook):
         :param delimiter: the delimiter marks key hierarchy
         :return: the key object from the bucket or None if none has been found.
         """
-        prefix = re.split(r"[\[\*\?]", wildcard_key, 1)[0]
+        prefix = re.split(r"[\[*?]", wildcard_key, 1)[0]
         key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
         key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)]
         if key_matches:
```
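All three `re.split` call sites compute the same thing: the longest literal prefix of a wildcard key, i.e. everything before the first glob metacharacter (`[`, `*`, or `?`), which S3 can then use as a list prefix before `fnmatch` filtering:

```python
import re


def wildcard_prefix(key: str) -> str:
    # Split once at the first '[', '*' or '?' and keep the literal part.
    return re.split(r"[\[*?]", key, 1)[0]


print(wildcard_prefix("logs/2023/*/task.log"))  # logs/2023/
print(wildcard_prefix("data/file-?.csv"))       # data/file-
print(wildcard_prefix("plain/key.txt"))         # plain/key.txt (no wildcard)
```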
```diff
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -252,12 +252,12 @@ class SageMakerHook(AwsBaseHook):
         ]
         events: list[Any | None] = []
         for event_stream in event_iters:
-            if not event_stream:
-                events.append(None)
-                continue
-            try:
-                events.append(next(event_stream))
-            except StopIteration:
+            if event_stream:
+                try:
+                    events.append(next(event_stream))
+                except StopIteration:
+                    events.append(None)
+            else:
                 events.append(None)
 
         while any(events):
```
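Each pass of this loop advances every log-stream iterator by one event, padding with `None` for absent or exhausted streams. As an aside, `next()` with a default argument collapses the `try`/`except StopIteration` pair; shown here only as an equivalent, not what the hook uses:

```python
streams = [iter([1, 2]), iter([]), None]
events = [next(s, None) if s else None for s in streams]
print(events)  # [1, None, None]
```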
```diff
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -979,8 +979,7 @@ class SageMakerHook(AwsBaseHook):
         found_name: str,
         job_name_suffix: str | None = None,
     ) -> bool:
-        pattern = re.compile(f"{processing_job_name}({job_name_suffix})?")
-        return pattern.fullmatch(found_name) is not None
+        return re.fullmatch(f"{processing_job_name}({job_name_suffix})?", found_name) is not None
 
     def count_processing_jobs_by_name(
         self,
```
```diff
--- a/airflow/providers/amazon/aws/hooks/sns.py
+++ b/airflow/providers/amazon/aws/hooks/sns.py
@@ -68,7 +68,6 @@ class SnsHook(AwsBaseHook):
 
         :param target_arn: either a TopicArn or an EndpointArn
         :param message: the default message you want to send
-        :param message: str
         :param subject: subject of message
         :param message_attributes: additional attributes to publish for message filtering. This should be
             a flat dict; the DataType to be sent depends on the type of the value:
```
```diff
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -16,15 +16,16 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any
-
-import boto3
+from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
 from airflow.utils.helpers import exactly_one
 
+if TYPE_CHECKING:
+    import boto3
+
 
 class EmrClusterLink(BaseAwsLink):
     """Helper class for constructing AWS EMR Cluster Link."""
```
```diff
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -19,16 +19,19 @@ from __future__ import annotations
 
 from datetime import datetime, timedelta
 from functools import cached_property
+from typing import TYPE_CHECKING
 
 import watchtower
 
 from airflow.configuration import conf
-from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+
 
 class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
     """
```
|
@@ -201,7 +201,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
|
|
201
201
|
try:
|
202
202
|
if append and self.s3_log_exists(remote_log_location):
|
203
203
|
old_log = self.s3_read(remote_log_location)
|
204
|
-
log = "\n"
|
204
|
+
log = f"{old_log}\n{log}" if old_log else log
|
205
205
|
except Exception:
|
206
206
|
self.log.exception("Could not verify previous log to append")
|
207
207
|
return False
|
```diff
--- a/airflow/providers/amazon/aws/notifications/chime.py
+++ b/airflow/providers/amazon/aws/notifications/chime.py
@@ -18,10 +18,13 @@
 from __future__ import annotations
 
 from functools import cached_property
+from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
-from airflow.utils.context import Context
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
 
 try:
     from airflow.notifications.basenotifier import BaseNotifier
```