apache-airflow-providers-amazon 9.2.0rc2__py3-none-any.whl → 9.4.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/LICENSE +0 -52
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -4
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +90 -106
- airflow/providers/amazon/aws/auth_manager/router/login.py +124 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +2 -2
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/utils.py +2 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +6 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/ecr.py +7 -1
- airflow/providers/amazon/aws/hooks/ecs.py +1 -2
- airflow/providers/amazon/aws/hooks/eks.py +10 -3
- airflow/providers/amazon/aws/hooks/emr.py +20 -0
- airflow/providers/amazon/aws/hooks/mwaa.py +85 -0
- airflow/providers/amazon/aws/hooks/sqs.py +4 -0
- airflow/providers/amazon/aws/hooks/ssm.py +10 -1
- airflow/providers/amazon/aws/links/comprehend.py +41 -0
- airflow/providers/amazon/aws/links/datasync.py +37 -0
- airflow/providers/amazon/aws/links/ec2.py +46 -0
- airflow/providers/amazon/aws/links/sagemaker.py +27 -0
- airflow/providers/amazon/aws/operators/athena.py +7 -5
- airflow/providers/amazon/aws/operators/batch.py +16 -8
- airflow/providers/amazon/aws/operators/bedrock.py +20 -18
- airflow/providers/amazon/aws/operators/comprehend.py +52 -11
- airflow/providers/amazon/aws/operators/datasync.py +40 -2
- airflow/providers/amazon/aws/operators/dms.py +0 -4
- airflow/providers/amazon/aws/operators/ec2.py +50 -0
- airflow/providers/amazon/aws/operators/ecs.py +11 -7
- airflow/providers/amazon/aws/operators/eks.py +17 -17
- airflow/providers/amazon/aws/operators/emr.py +27 -27
- airflow/providers/amazon/aws/operators/glue.py +16 -14
- airflow/providers/amazon/aws/operators/glue_crawler.py +3 -3
- airflow/providers/amazon/aws/operators/glue_databrew.py +5 -5
- airflow/providers/amazon/aws/operators/kinesis_analytics.py +9 -9
- airflow/providers/amazon/aws/operators/lambda_function.py +4 -4
- airflow/providers/amazon/aws/operators/mwaa.py +109 -0
- airflow/providers/amazon/aws/operators/rds.py +16 -16
- airflow/providers/amazon/aws/operators/redshift_cluster.py +15 -15
- airflow/providers/amazon/aws/operators/redshift_data.py +4 -4
- airflow/providers/amazon/aws/operators/sagemaker.py +52 -29
- airflow/providers/amazon/aws/operators/sqs.py +6 -0
- airflow/providers/amazon/aws/operators/step_function.py +4 -4
- airflow/providers/amazon/aws/sensors/ec2.py +3 -3
- airflow/providers/amazon/aws/sensors/emr.py +9 -9
- airflow/providers/amazon/aws/sensors/glue.py +7 -7
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +3 -3
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +3 -3
- airflow/providers/amazon/aws/sensors/sqs.py +6 -5
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +8 -3
- airflow/providers/amazon/aws/triggers/README.md +1 -1
- airflow/providers/amazon/aws/triggers/opensearch_serverless.py +2 -1
- airflow/providers/amazon/aws/triggers/sqs.py +2 -1
- airflow/providers/amazon/aws/utils/sqs.py +6 -4
- airflow/providers/amazon/aws/waiters/dms.json +12 -0
- airflow/providers/amazon/get_provider_info.py +106 -87
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/METADATA +16 -34
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/RECORD +61 -55
- airflow/providers/amazon/aws/auth_manager/views/auth.py +0 -151
- /airflow/providers/amazon/aws/auth_manager/{views → router}/__init__.py +0 -0
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/base_aws.py:

@@ -58,7 +58,6 @@ from airflow.providers.amazon.aws.utils.suppress import return_on_error
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.log.secrets_masker import mask_secret

 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])

@@ -68,6 +67,12 @@ if TYPE_CHECKING:
     from botocore.credentials import ReadOnlyCredentials

     from airflow.models.connection import Connection
+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret

 _loader = botocore.loaders.Loader()
 """
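This guarded import recurs in several files in this diff (ecr.py and ssm.py below use the same shape): type checkers always resolve the new Task SDK location, while the old `airflow.utils.log` path remains a runtime fallback for installations without the Task SDK. A minimal standalone sketch of the pattern:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type checkers always resolve the new Task SDK location.
    from airflow.sdk.execution_time.secrets_masker import mask_secret
else:
    try:
        # Preferred at runtime on Airflow installations that ship the Task SDK.
        from airflow.sdk.execution_time.secrets_masker import mask_secret
    except ImportError:
        # Fallback for older Airflow versions.
        from airflow.utils.log.secrets_masker import mask_secret

# Example use: register a sensitive value so it is masked in task logs.
mask_secret("my-super-secret-value")
```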
airflow/providers/amazon/aws/hooks/batch_client.py:

@@ -30,7 +30,7 @@ from __future__ import annotations
 import itertools
 import random
 import time
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Protocol, runtime_checkable

 import botocore.client
 import botocore.exceptions
@@ -38,7 +38,6 @@ import botocore.waiter

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.typing_compat import Protocol, runtime_checkable

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
airflow/providers/amazon/aws/hooks/ecr.py:

@@ -23,11 +23,17 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.utils.log.secrets_masker import mask_secret

 if TYPE_CHECKING:
     from datetime import datetime

+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret
+
 logger = logging.getLogger(__name__)

airflow/providers/amazon/aws/hooks/ecs.py:

@@ -17,12 +17,11 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, runtime_checkable

 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import _StringCompareEnum
-from airflow.typing_compat import Protocol, runtime_checkable

 if TYPE_CHECKING:
     from botocore.waiter import Waiter
airflow/providers/amazon/aws/hooks/eks.py:

@@ -20,6 +20,7 @@ from __future__ import annotations

 import base64
 import json
+import os
 import sys
 import tempfile
 from collections.abc import Generator
@@ -32,6 +33,7 @@ from botocore.exceptions import ClientError
 from botocore.signers import RequestSigner

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.sts import StsHook
 from airflow.utils import yaml
 from airflow.utils.json import AirflowJsonEncoder

@@ -612,9 +614,14 @@ class EksHook(AwsBaseHook):
     def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
         session = self.get_session()
         service_id = self.conn.meta.service_model.service_id
-        sts_url = (
-            f"https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15"
-        )
+        # This env variable is required so that we get a regionalized endpoint for STS in regions that
+        # otherwise default to global endpoints. The mechanism below to generate the token is very picky that
+        # the endpoint is regional.
+        os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
+        try:
+            sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
+        finally:
+            del os.environ["AWS_STS_REGIONAL_ENDPOINTS"]

         signer = RequestSigner(
             service_id=service_id,
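A minimal usage sketch of the method touched above; the connection id, region, and cluster name are placeholders:

```python
from airflow.providers.amazon.aws.hooks.eks import EksHook

# Hypothetical connection id, region, and cluster name, for illustration only.
hook = EksHook(aws_conn_id="aws_default", region_name="eu-west-1")

# Returns a short-lived bearer token derived from a presigned STS request,
# suitable for authenticating Kubernetes API calls against the EKS cluster.
token = hook.fetch_access_token_for_cluster(eks_cluster_name="my-eks-cluster")
```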
airflow/providers/amazon/aws/hooks/emr.py:

@@ -22,7 +22,9 @@ import time
 import warnings
 from typing import Any

+import tenacity
 from botocore.exceptions import ClientError
+from tenacity import retry_if_exception, stop_after_attempt, wait_fixed

 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -311,6 +313,15 @@ class EmrServerlessHook(AwsBaseHook):
         return count


+def is_connection_being_updated_exception(exception: BaseException) -> bool:
+    return (
+        isinstance(exception, ClientError)
+        and exception.response["Error"]["Code"] == "ValidationException"
+        and "is not reachable as its connection is currently being updated"
+        in exception.response["Error"]["Message"]
+    )
+
+
 class EmrContainerHook(AwsBaseHook):
     """
     Interact with Amazon EMR Containers (Amazon EMR on EKS).
@@ -348,6 +359,15 @@ class EmrContainerHook(AwsBaseHook):
         super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
         self.virtual_cluster_id = virtual_cluster_id

+    # Retry this method when the ``create_virtual_cluster`` raises
+    # "Cluster XXX is not reachable as its connection is currently being updated".
+    # Even though the EKS cluster status is ``ACTIVE``, ``create_virtual_cluster`` can raise this error.
+    # Retrying is the only option. Retry up to 3 minutes
+    @tenacity.retry(
+        retry=retry_if_exception(is_connection_being_updated_exception),
+        stop=stop_after_attempt(12),
+        wait=wait_fixed(15),
+    )
     def create_emr_on_eks_cluster(
         self,
         virtual_cluster_name: str,
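A small sketch of how the new predicate and the tenacity decorator interact, using a hand-built `ClientError` with an illustrative message (the retry limits are shortened here; the hook itself uses 12 attempts, 15 seconds apart):

```python
import tenacity
from botocore.exceptions import ClientError
from tenacity import retry_if_exception, stop_after_attempt, wait_fixed


def is_connection_being_updated_exception(exception: BaseException) -> bool:
    # Same predicate as in the hunk above: match only the transient
    # "connection is currently being updated" validation error.
    return (
        isinstance(exception, ClientError)
        and exception.response["Error"]["Code"] == "ValidationException"
        and "is not reachable as its connection is currently being updated"
        in exception.response["Error"]["Message"]
    )


# Illustrative error payload shaped like the one EMR on EKS returns.
err = ClientError(
    {
        "Error": {
            "Code": "ValidationException",
            "Message": "Cluster my-eks is not reachable as its connection is currently being updated",
        }
    },
    "CreateVirtualCluster",
)
assert is_connection_being_updated_exception(err)

attempts = []


@tenacity.retry(
    retry=retry_if_exception(is_connection_being_updated_exception),
    stop=stop_after_attempt(3),
    wait=wait_fixed(1),
)
def create_virtual_cluster_with_retry() -> str:
    # A real call would go through EmrContainerHook.create_emr_on_eks_cluster();
    # here the first attempt raises the transient error and the second succeeds.
    attempts.append(1)
    if len(attempts) == 1:
        raise err
    return "virtual-cluster-id"


print(create_virtual_cluster_with_retry())  # retried once, then "virtual-cluster-id"
```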
airflow/providers/amazon/aws/hooks/mwaa.py (new file):

@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS MWAA hook."""
+
+from __future__ import annotations
+
+from botocore.exceptions import ClientError
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class MwaaHook(AwsBaseHook):
+    """
+    Interact with AWS Manager Workflows for Apache Airflow.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "mwaa"
+        super().__init__(*args, **kwargs)
+
+    def invoke_rest_api(
+        self,
+        env_name: str,
+        path: str,
+        method: str,
+        body: dict | None = None,
+        query_params: dict | None = None,
+    ) -> dict:
+        """
+        Invoke the REST API on the Airflow webserver with the specified inputs.
+
+        .. seealso::
+            - :external+boto3:py:meth:`MWAA.Client.invoke_rest_api`
+
+        :param env_name: name of the MWAA environment
+        :param path: Apache Airflow REST API endpoint path to be called
+        :param method: HTTP method used for making Airflow REST API calls
+        :param body: Request body for the Apache Airflow REST API call
+        :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        """
+        body = body or {}
+        api_kwargs = {
+            "Name": env_name,
+            "Path": path,
+            "Method": method,
+            # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+            "Body": {k: v for k, v in body.items() if v is not None},
+            "QueryParameters": query_params if query_params else {},
+        }
+        try:
+            result = self.conn.invoke_rest_api(**api_kwargs)
+            # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
+            # in XComs and logs, or redundant given the data already included in the response
+            result.pop("ResponseMetadata", None)
+            return result
+        except ClientError as e:
+            to_log = e.response
+            # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
+            # be useful in XComs and logs, or redundant given the data already included in the response
+            to_log.pop("ResponseMetadata", None)
+            to_log.pop("Error", None)
+            self.log.error(to_log)
+            raise e
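A hedged usage sketch of the new hook; the environment name is a placeholder and the caller needs IAM permission to use the MWAA InvokeRestApi action:

```python
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# List DAGs in a (hypothetical) environment through the embedded Airflow REST API.
response = hook.invoke_rest_api(
    env_name="my-mwaa-environment",
    path="/dags",
    method="GET",
    query_params={"limit": 10},
)
print(response)
```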
airflow/providers/amazon/aws/hooks/sqs.py:

@@ -59,6 +59,7 @@ class SqsHook(AwsBaseHook):
         delay_seconds: int = 0,
         message_attributes: dict | None = None,
         message_group_id: str | None = None,
+        message_deduplication_id: str | None = None,
     ) -> dict:
         """
         Send message to the queue.
@@ -71,6 +72,7 @@ class SqsHook(AwsBaseHook):
         :param delay_seconds: seconds to delay the message
         :param message_attributes: additional attributes for the message (default: None)
         :param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None)
+        :param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
         :return: dict with the information about the message sent
         """
         params = {
@@ -81,5 +83,7 @@ class SqsHook(AwsBaseHook):
         }
         if message_group_id:
             params["MessageGroupId"] = message_group_id
+        if message_deduplication_id:
+            params["MessageDeduplicationId"] = message_deduplication_id

         return self.get_conn().send_message(**params)
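A short usage sketch of the new argument on a FIFO queue; the queue URL and ids are placeholders:

```python
from airflow.providers.amazon.aws.hooks.sqs import SqsHook

hook = SqsHook(aws_conn_id="aws_default")

# For FIFO queues, MessageGroupId is required and, without content-based
# deduplication enabled, MessageDeduplicationId is too; the new argument
# exposes the latter through the hook.
hook.send_message(
    queue_url="https://sqs.us-east-1.amazonaws.com/123456789012/example-queue.fifo",
    message_body='{"run_id": "2025-01-01"}',
    message_group_id="reporting",
    message_deduplication_id="reporting-2025-01-01",
)
```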
airflow/providers/amazon/aws/hooks/ssm.py:

@@ -17,10 +17,19 @@

 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.types import NOTSET, ArgNotSet

+if TYPE_CHECKING:
+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret
+

 class SsmHook(AwsBaseHook):
     """
airflow/providers/amazon/aws/links/comprehend.py (new file):

@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class ComprehendPiiEntitiesDetectionLink(BaseAwsLink):
+    """Helper class for constructing Amazon Comprehend PII Detection console link."""
+
+    name = "PII Detection Job"
+    key = "comprehend_pii_detection"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK
+        + "/comprehend/home?region={region_name}#"
+        + "/analysis-job-details/pii/{job_id}"
+    )
+
+
+class ComprehendDocumentClassifierLink(BaseAwsLink):
+    """Helper class for constructing Amazon Comprehend Document Classifier console link."""
+
+    name = "Document Classifier"
+    key = "comprehend_document_classifier"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/comprehend/home?region={region_name}#" + "classifier-version-details/{arn}"
+    )
airflow/providers/amazon/aws/links/datasync.py (new file):

@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class DataSyncTaskLink(BaseAwsLink):
+    """Helper class for constructing AWS DataSync Task console link."""
+
+    name = "DataSync Task"
+    key = "datasync_task"
+    format_str = BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#" + "/tasks/{task_id}"
+
+
+class DataSyncTaskExecutionLink(BaseAwsLink):
+    """Helper class for constructing AWS DataSync TaskExecution console link."""
+
+    name = "DataSync Task Execution"
+    key = "datasync_task_execution"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#/history/{task_id}/{task_execution_id}"
+    )
airflow/providers/amazon/aws/links/ec2.py (new file):

@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class EC2InstanceLink(BaseAwsLink):
+    """Helper class for constructing Amazon EC2 instance links."""
+
+    name = "Instance"
+    key = "_instance_id"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#InstanceDetails:instanceId={instance_id}"
+    )
+
+
+class EC2InstanceDashboardLink(BaseAwsLink):
+    """
+    Helper class for constructing Amazon EC2 console links.
+
+    This is useful for displaying the list of EC2 instances, rather
+    than a single instance.
+    """
+
+    name = "EC2 Instances"
+    key = "_instance_dashboard"
+    format_str = BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#Instances:instanceId=:{instance_ids}"
+
+    @staticmethod
+    def format_instance_id_filter(instance_ids: list[str]) -> str:
+        return ",:".join(instance_ids)
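The dashboard link expects a pre-joined id list; a quick sketch of the helper shown above, with placeholder instance ids:

```python
from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink

# The ",:"-joined ids slot into "...#Instances:instanceId=:{instance_ids}" in
# format_str, so the console shows one filter entry per instance.
ids = EC2InstanceDashboardLink.format_instance_id_filter(
    ["i-0123456789abcdef0", "i-0fedcba9876543210"]
)
print(ids)  # i-0123456789abcdef0,:i-0fedcba9876543210
```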
airflow/providers/amazon/aws/links/sagemaker.py (new file):

@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class SageMakerTransformJobLink(BaseAwsLink):
+    """Helper class for constructing AWS Transform Run Details Link."""
+
+    name = "Amazon SageMaker Transform Job Details"
+    key = "sagemaker_transform_job_details"
+    format_str = BASE_AWS_CONSOLE_LINK + "/sagemaker/home?region={region_name}#/transform-jobs/{job_name}"
airflow/providers/amazon/aws/operators/athena.py:

@@ -177,14 +177,16 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         return self.query_execution_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(
+                f"Error while waiting for operation on cluster to complete: {validated_event}"
+            )

         # Save query_execution_id to be later used by listeners
-        self.query_execution_id = event["value"]
-        return event["value"]
+        self.query_execution_id = validated_event["value"]
+        return validated_event["value"]

     def on_kill(self) -> None:
         """Cancel the submitted Amazon Athena query."""
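Every `execute_complete` hunk in this diff funnels the trigger payload through `validate_execute_complete_event` before indexing into it. Roughly, the helper guards against a missing event; a simplified sketch of the idea, not the provider's exact implementation:

```python
from typing import Any

from airflow.exceptions import AirflowException


def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
    # Simplified stand-in for the helper in airflow.providers.amazon.aws.utils:
    # reject a missing payload so execute_complete can index into it safely.
    if event is None:
        raise AirflowException("Trigger error: event is None")
    return event


# Typical trigger payload handled by the execute_complete methods in this diff.
event = {"status": "success", "value": "query-execution-id"}
validated_event = validate_execute_complete_event(event)
assert validated_event["status"] == "success"
```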
airflow/providers/amazon/aws/operators/batch.py:

@@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
         If it is an array job, only the logs of the first task will be printed.
     :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec.
     :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling.
+    :param submit_job_timeout: Execution timeout in seconds for submitted batch job.

     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -184,6 +185,7 @@ class BatchOperator(BaseOperator):
         poll_interval: int = 30,
         awslogs_enabled: bool = False,
         awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+        submit_job_timeout: int | None = None,
         **kwargs,
     ) -> None:
         BaseOperator.__init__(self, **kwargs)
@@ -208,6 +210,7 @@ class BatchOperator(BaseOperator):
         self.poll_interval = poll_interval
         self.awslogs_enabled = awslogs_enabled
         self.awslogs_fetch_interval = awslogs_fetch_interval
+        self.submit_job_timeout = submit_job_timeout

         # params for hook
         self.max_retries = max_retries
@@ -264,13 +267,13 @@ class BatchOperator(BaseOperator):
         return self.job_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Job completed.")
-        return event["job_id"]
+        return validated_event["job_id"]

     def on_kill(self):
         response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
@@ -315,6 +318,9 @@ class BatchOperator(BaseOperator):
             "schedulingPriorityOverride": self.scheduling_priority_override,
         }

+        if self.submit_job_timeout:
+            args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout}
+
         try:
             response = self.hook.client.submit_job(**trim_none_values(args))
         except Exception as e:
@@ -534,8 +540,10 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         return arn

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
-        return event["value"]
+        if validated_event["status"] != "success":
+            raise AirflowException(
+                f"Error while waiting for the compute environment to be ready: {validated_event}"
+            )
+        return validated_event["value"]
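A hedged sketch of the new `submit_job_timeout` knob on `BatchOperator`; the job name, queue, and definition are placeholders and the task would normally be declared inside a DAG:

```python
from airflow.providers.amazon.aws.operators.batch import BatchOperator

# submit_job_timeout is forwarded to SubmitJob as
# timeout={"attemptDurationSeconds": ...}, so AWS Batch terminates attempts
# that run longer than one hour.
submit_batch_job = BatchOperator(
    task_id="submit_batch_job",
    job_name="example-job",
    job_queue="example-queue",
    job_definition="example-job-definition",
    submit_job_timeout=3600,
)
```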
airflow/providers/amazon/aws/operators/bedrock.py:

@@ -198,13 +198,13 @@ class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]):
         self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Bedrock model customization job `%s` complete.", self.job_name)
-        return self.hook.conn.get_model_customization_job(jobIdentifier=event["job_name"])["jobArn"]
+        return self.hook.conn.get_model_customization_job(jobIdentifier=validated_event["job_name"])["jobArn"]

     def execute(self, context: Context) -> dict:
         response = {}
@@ -353,13 +353,15 @@ class BedrockCreateProvisionedModelThroughputOperator(AwsBaseOperator[BedrockHook]):
         return provisioned_model_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

-        self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"])
-        return event["provisioned_model_id"]
+        self.log.info(
+            "Bedrock provisioned throughput job `%s` complete.", validated_event["provisioned_model_id"]
+        )
+        return validated_event["provisioned_model_id"]


 class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
@@ -449,13 +451,13 @@ class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
         self.deferrable = deferrable

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Bedrock knowledge base creation job `%s` complete.", self.name)
-        return event["knowledge_base_id"]
+        return validated_event["knowledge_base_id"]

     def execute(self, context: Context) -> str:
         def _create_kb():
@@ -634,14 +636,14 @@ class BedrockIngestDataOperator(AwsBaseOperator[BedrockAgentHook]):
         self.deferrable = deferrable

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running ingestion job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running ingestion job: {validated_event}")

-        self.log.info("Bedrock ingestion job `%s` complete.", event["ingestion_job_id"])
+        self.log.info("Bedrock ingestion job `%s` complete.", validated_event["ingestion_job_id"])

-        return event["ingestion_job_id"]
+        return validated_event["ingestion_job_id"]

     def execute(self, context: Context) -> str:
         ingestion_job_id = self.hook.conn.start_ingestion_job(