apache-airflow-providers-amazon 9.2.0rc1__py3-none-any.whl → 9.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. airflow/providers/amazon/LICENSE +0 -52
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -4
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +90 -106
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +124 -0
  6. airflow/providers/amazon/aws/executors/batch/batch_executor.py +2 -2
  7. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  8. airflow/providers/amazon/aws/executors/ecs/utils.py +2 -1
  9. airflow/providers/amazon/aws/hooks/base_aws.py +6 -1
  10. airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
  11. airflow/providers/amazon/aws/hooks/ecr.py +7 -1
  12. airflow/providers/amazon/aws/hooks/ecs.py +1 -2
  13. airflow/providers/amazon/aws/hooks/eks.py +10 -3
  14. airflow/providers/amazon/aws/hooks/emr.py +20 -0
  15. airflow/providers/amazon/aws/hooks/mwaa.py +85 -0
  16. airflow/providers/amazon/aws/hooks/sqs.py +4 -0
  17. airflow/providers/amazon/aws/hooks/ssm.py +10 -1
  18. airflow/providers/amazon/aws/links/comprehend.py +41 -0
  19. airflow/providers/amazon/aws/links/datasync.py +37 -0
  20. airflow/providers/amazon/aws/links/ec2.py +46 -0
  21. airflow/providers/amazon/aws/links/sagemaker.py +27 -0
  22. airflow/providers/amazon/aws/operators/athena.py +7 -5
  23. airflow/providers/amazon/aws/operators/batch.py +16 -8
  24. airflow/providers/amazon/aws/operators/bedrock.py +20 -18
  25. airflow/providers/amazon/aws/operators/comprehend.py +52 -11
  26. airflow/providers/amazon/aws/operators/datasync.py +40 -2
  27. airflow/providers/amazon/aws/operators/dms.py +0 -4
  28. airflow/providers/amazon/aws/operators/ec2.py +50 -0
  29. airflow/providers/amazon/aws/operators/ecs.py +11 -7
  30. airflow/providers/amazon/aws/operators/eks.py +17 -17
  31. airflow/providers/amazon/aws/operators/emr.py +27 -27
  32. airflow/providers/amazon/aws/operators/glue.py +16 -14
  33. airflow/providers/amazon/aws/operators/glue_crawler.py +3 -3
  34. airflow/providers/amazon/aws/operators/glue_databrew.py +5 -5
  35. airflow/providers/amazon/aws/operators/kinesis_analytics.py +9 -9
  36. airflow/providers/amazon/aws/operators/lambda_function.py +4 -4
  37. airflow/providers/amazon/aws/operators/mwaa.py +109 -0
  38. airflow/providers/amazon/aws/operators/rds.py +16 -16
  39. airflow/providers/amazon/aws/operators/redshift_cluster.py +15 -15
  40. airflow/providers/amazon/aws/operators/redshift_data.py +4 -4
  41. airflow/providers/amazon/aws/operators/sagemaker.py +52 -29
  42. airflow/providers/amazon/aws/operators/sqs.py +6 -0
  43. airflow/providers/amazon/aws/operators/step_function.py +4 -4
  44. airflow/providers/amazon/aws/sensors/ec2.py +3 -3
  45. airflow/providers/amazon/aws/sensors/emr.py +9 -9
  46. airflow/providers/amazon/aws/sensors/glue.py +7 -7
  47. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +3 -3
  48. airflow/providers/amazon/aws/sensors/redshift_cluster.py +3 -3
  49. airflow/providers/amazon/aws/sensors/sqs.py +6 -5
  50. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +8 -3
  51. airflow/providers/amazon/aws/triggers/README.md +1 -1
  52. airflow/providers/amazon/aws/triggers/opensearch_serverless.py +2 -1
  53. airflow/providers/amazon/aws/triggers/sqs.py +2 -1
  54. airflow/providers/amazon/aws/utils/sqs.py +6 -4
  55. airflow/providers/amazon/aws/waiters/dms.json +12 -0
  56. airflow/providers/amazon/get_provider_info.py +106 -87
  57. {apache_airflow_providers_amazon-9.2.0rc1.dist-info → apache_airflow_providers_amazon-9.3.0.dist-info}/METADATA +18 -36
  58. {apache_airflow_providers_amazon-9.2.0rc1.dist-info → apache_airflow_providers_amazon-9.3.0.dist-info}/RECORD +61 -55
  59. airflow/providers/amazon/aws/auth_manager/views/auth.py +0 -151
  60. /airflow/providers/amazon/aws/auth_manager/{views → router}/__init__.py +0 -0
  61. {apache_airflow_providers_amazon-9.2.0rc1.dist-info → apache_airflow_providers_amazon-9.3.0.dist-info}/WHEEL +0 -0
  62. {apache_airflow_providers_amazon-9.2.0rc1.dist-info → apache_airflow_providers_amazon-9.3.0.dist-info}/entry_points.txt +0 -0
@@ -58,7 +58,6 @@ from airflow.providers.amazon.aws.utils.suppress import return_on_error
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.log.secrets_masker import mask_secret

 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])

@@ -68,6 +67,12 @@ if TYPE_CHECKING:
     from botocore.credentials import ReadOnlyCredentials

     from airflow.models.connection import Connection
+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret

 _loader = botocore.loaders.Loader()
 """
@@ -30,7 +30,7 @@ from __future__ import annotations
 import itertools
 import random
 import time
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Protocol, runtime_checkable

 import botocore.client
 import botocore.exceptions
@@ -38,7 +38,6 @@ import botocore.waiter

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.typing_compat import Protocol, runtime_checkable

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -23,11 +23,17 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.utils.log.secrets_masker import mask_secret

 if TYPE_CHECKING:
     from datetime import datetime

+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret
+
 logger = logging.getLogger(__name__)

@@ -17,12 +17,11 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol, runtime_checkable

 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import _StringCompareEnum
-from airflow.typing_compat import Protocol, runtime_checkable

 if TYPE_CHECKING:
     from botocore.waiter import Waiter
@@ -20,6 +20,7 @@ from __future__ import annotations

 import base64
 import json
+import os
 import sys
 import tempfile
 from collections.abc import Generator
@@ -32,6 +33,7 @@ from botocore.exceptions import ClientError
 from botocore.signers import RequestSigner

 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.sts import StsHook
 from airflow.utils import yaml
 from airflow.utils.json import AirflowJsonEncoder

@@ -612,9 +614,14 @@ class EksHook(AwsBaseHook):
     def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
         session = self.get_session()
         service_id = self.conn.meta.service_model.service_id
-        sts_url = (
-            f"https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15"
-        )
+        # This env variable is required so that we get a regionalized endpoint for STS in regions that
+        # otherwise default to global endpoints. The mechanism below to generate the token is very picky that
+        # the endpoint is regional.
+        os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
+        try:
+            sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
+        finally:
+            del os.environ["AWS_STS_REGIONAL_ENDPOINTS"]

         signer = RequestSigner(
             service_id=service_id,
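
Note: the EKS token URL is now derived from a regional STS endpoint instead of a hard-coded https://sts.{region}.amazonaws.com address. A hedged sketch of the same environment-override idea using plain boto3; the context manager is illustrative and, unlike the hook, restores any pre-existing value:

import os
from contextlib import contextmanager

import boto3


@contextmanager
def sts_regional_endpoints():
    # Temporarily force botocore to resolve a regional STS endpoint
    # (e.g. https://sts.eu-west-1.amazonaws.com instead of https://sts.amazonaws.com).
    previous = os.environ.get("AWS_STS_REGIONAL_ENDPOINTS")
    os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
    try:
        yield
    finally:
        if previous is None:
            del os.environ["AWS_STS_REGIONAL_ENDPOINTS"]
        else:
            os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = previous


with sts_regional_endpoints():
    sts_client = boto3.client("sts", region_name="eu-west-1")
    sts_url = f"{sts_client.meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
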
@@ -22,7 +22,9 @@ import time
 import warnings
 from typing import Any

+import tenacity
 from botocore.exceptions import ClientError
+from tenacity import retry_if_exception, stop_after_attempt, wait_fixed

 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -311,6 +313,15 @@ class EmrServerlessHook(AwsBaseHook):
         return count


+def is_connection_being_updated_exception(exception: BaseException) -> bool:
+    return (
+        isinstance(exception, ClientError)
+        and exception.response["Error"]["Code"] == "ValidationException"
+        and "is not reachable as its connection is currently being updated"
+        in exception.response["Error"]["Message"]
+    )
+
+
 class EmrContainerHook(AwsBaseHook):
     """
     Interact with Amazon EMR Containers (Amazon EMR on EKS).
@@ -348,6 +359,15 @@ class EmrContainerHook(AwsBaseHook):
         super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
         self.virtual_cluster_id = virtual_cluster_id

+    # Retry this method when the ``create_virtual_cluster`` raises
+    # "Cluster XXX is not reachable as its connection is currently being updated".
+    # Even though the EKS cluster status is ``ACTIVE``, ``create_virtual_cluster`` can raise this error.
+    # Retrying is the only option. Retry up to 3 minutes
+    @tenacity.retry(
+        retry=retry_if_exception(is_connection_being_updated_exception),
+        stop=stop_after_attempt(12),
+        wait=wait_fixed(15),
+    )
     def create_emr_on_eks_cluster(
         self,
         virtual_cluster_name: str,
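
Note: create_emr_on_eks_cluster is now retried while the EKS cluster's connection is being updated (12 attempts, 15 seconds apart, roughly 3 minutes in total). A self-contained sketch of the same tenacity pattern against the raw emr-containers client; the wrapper function below is illustrative, not the hook method:

import botocore.exceptions
import tenacity
from tenacity import retry_if_exception, stop_after_attempt, wait_fixed


def _is_transient_validation_error(exc: BaseException) -> bool:
    # Mirror of the predicate above: only retry the specific transient ValidationException.
    return (
        isinstance(exc, botocore.exceptions.ClientError)
        and exc.response["Error"]["Code"] == "ValidationException"
        and "currently being updated" in exc.response["Error"]["Message"]
    )


@tenacity.retry(
    retry=retry_if_exception(_is_transient_validation_error),
    stop=stop_after_attempt(12),  # up to 12 tries
    wait=wait_fixed(15),          # 15 seconds apart
    reraise=True,                 # surface the original ClientError if all attempts fail
)
def create_virtual_cluster(client, name: str, eks_cluster: str, namespace: str) -> str:
    # Hypothetical wrapper around the EMR-on-EKS API call used by the hook.
    response = client.create_virtual_cluster(
        name=name,
        containerProvider={
            "id": eks_cluster,
            "type": "EKS",
            "info": {"eksInfo": {"namespace": namespace}},
        },
    )
    return response["id"]
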
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS MWAA hook."""
+
+from __future__ import annotations
+
+from botocore.exceptions import ClientError
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class MwaaHook(AwsBaseHook):
+    """
+    Interact with AWS Manager Workflows for Apache Airflow.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "mwaa"
+        super().__init__(*args, **kwargs)
+
+    def invoke_rest_api(
+        self,
+        env_name: str,
+        path: str,
+        method: str,
+        body: dict | None = None,
+        query_params: dict | None = None,
+    ) -> dict:
+        """
+        Invoke the REST API on the Airflow webserver with the specified inputs.
+
+        .. seealso::
+            - :external+boto3:py:meth:`MWAA.Client.invoke_rest_api`
+
+        :param env_name: name of the MWAA environment
+        :param path: Apache Airflow REST API endpoint path to be called
+        :param method: HTTP method used for making Airflow REST API calls
+        :param body: Request body for the Apache Airflow REST API call
+        :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        """
+        body = body or {}
+        api_kwargs = {
+            "Name": env_name,
+            "Path": path,
+            "Method": method,
+            # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+            "Body": {k: v for k, v in body.items() if v is not None},
+            "QueryParameters": query_params if query_params else {},
+        }
+        try:
+            result = self.conn.invoke_rest_api(**api_kwargs)
+            # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
+            # in XComs and logs, or redundant given the data already included in the response
+            result.pop("ResponseMetadata", None)
+            return result
+        except ClientError as e:
+            to_log = e.response
+            # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
+            # be useful in XComs and logs, or redundant given the data already included in the response
+            to_log.pop("ResponseMetadata", None)
+            to_log.pop("Error", None)
+            self.log.error(to_log)
+            raise e
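
Note: usage of the new MwaaHook might look like the sketch below. The environment name, DAG id, and request body are placeholders, and the response keys ("RestApiStatusCode", "RestApiResponse") follow the boto3 MWAA InvokeRestApi documentation:

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook(aws_conn_id="aws_default")

# Trigger a DAG run in a remote MWAA environment through the Airflow REST API.
response = hook.invoke_rest_api(
    env_name="my-mwaa-env",
    path="/dags/my_dag/dagRuns",
    method="POST",
    body={"note": "triggered via MwaaHook"},
)
print(response.get("RestApiStatusCode"), response.get("RestApiResponse"))
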
@@ -59,6 +59,7 @@ class SqsHook(AwsBaseHook):
         delay_seconds: int = 0,
         message_attributes: dict | None = None,
         message_group_id: str | None = None,
+        message_deduplication_id: str | None = None,
     ) -> dict:
         """
         Send message to the queue.
@@ -71,6 +72,7 @@ class SqsHook(AwsBaseHook):
         :param delay_seconds: seconds to delay the message
         :param message_attributes: additional attributes for the message (default: None)
         :param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None)
+        :param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
         :return: dict with the information about the message sent
         """
         params = {
@@ -81,5 +83,7 @@ class SqsHook(AwsBaseHook):
         }
         if message_group_id:
             params["MessageGroupId"] = message_group_id
+        if message_deduplication_id:
+            params["MessageDeduplicationId"] = message_deduplication_id

         return self.get_conn().send_message(**params)
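
Note: a hedged usage sketch for the new message_deduplication_id argument; the queue URL and ids are placeholders. Both FIFO-only arguments map directly to the SQS SendMessage parameters:

from airflow.providers.amazon.aws.hooks.sqs import SqsHook

hook = SqsHook(aws_conn_id="aws_default")

# For FIFO queues, the group id orders messages and the deduplication id
# suppresses duplicates within the 5-minute deduplication window.
hook.send_message(
    queue_url="https://sqs.eu-west-1.amazonaws.com/123456789012/my-queue.fifo",
    message_body='{"order_id": 42}',
    message_group_id="orders",
    message_deduplication_id="order-42-v1",
)
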
@@ -17,10 +17,19 @@

 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.types import NOTSET, ArgNotSet

+if TYPE_CHECKING:
+    from airflow.sdk.execution_time.secrets_masker import mask_secret
+else:
+    try:
+        from airflow.sdk.execution_time.secrets_masker import mask_secret
+    except ImportError:
+        from airflow.utils.log.secrets_masker import mask_secret
+

 class SsmHook(AwsBaseHook):
     """
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class ComprehendPiiEntitiesDetectionLink(BaseAwsLink):
+    """Helper class for constructing Amazon Comprehend PII Detection console link."""
+
+    name = "PII Detection Job"
+    key = "comprehend_pii_detection"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK
+        + "/comprehend/home?region={region_name}#"
+        + "/analysis-job-details/pii/{job_id}"
+    )
+
+
+class ComprehendDocumentClassifierLink(BaseAwsLink):
+    """Helper class for constructing Amazon Comprehend Document Classifier console link."""
+
+    name = "Document Classifier"
+    key = "comprehend_document_classifier"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/comprehend/home?region={region_name}#" + "classifier-version-details/{arn}"
+    )
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class DataSyncTaskLink(BaseAwsLink):
+    """Helper class for constructing AWS DataSync Task console link."""
+
+    name = "DataSync Task"
+    key = "datasync_task"
+    format_str = BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#" + "/tasks/{task_id}"
+
+
+class DataSyncTaskExecutionLink(BaseAwsLink):
+    """Helper class for constructing AWS DataSync TaskExecution console link."""
+
+    name = "DataSync Task Execution"
+    key = "datasync_task_execution"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/datasync/home?region={region_name}#/history/{task_id}/{task_execution_id}"
+    )
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class EC2InstanceLink(BaseAwsLink):
+    """Helper class for constructing Amazon EC2 instance links."""
+
+    name = "Instance"
+    key = "_instance_id"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#InstanceDetails:instanceId={instance_id}"
+    )
+
+
+class EC2InstanceDashboardLink(BaseAwsLink):
+    """
+    Helper class for constructing Amazon EC2 console links.
+
+    This is useful for displaying the list of EC2 instances, rather
+    than a single instance.
+    """
+
+    name = "EC2 Instances"
+    key = "_instance_dashboard"
+    format_str = BASE_AWS_CONSOLE_LINK + "/ec2/home?region={region_name}#Instances:instanceId=:{instance_ids}"
+
+    @staticmethod
+    def format_instance_id_filter(instance_ids: list[str]) -> str:
+        return ",:".join(instance_ids)
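
Note: the dashboard link filters the EC2 console by several instance ids at once; format_instance_id_filter joins them with ",:" so that each id appears as its own ":<id>" token in the URL fragment. A tiny illustration (instance ids are placeholders):

from airflow.providers.amazon.aws.links.ec2 import EC2InstanceDashboardLink

# The console URL ends with "#Instances:instanceId=:{instance_ids}";
# joining with ",:" yields one ":<id>" token per instance in the filter.
ids = ["i-0abc123", "i-0def456"]
print(EC2InstanceDashboardLink.format_instance_id_filter(ids))
# -> "i-0abc123,:i-0def456"
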
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class SageMakerTransformJobLink(BaseAwsLink):
+    """Helper class for constructing AWS Transform Run Details Link."""
+
+    name = "Amazon SageMaker Transform Job Details"
+    key = "sagemaker_transform_job_details"
+    format_str = BASE_AWS_CONSOLE_LINK + "/sagemaker/home?region={region_name}#/transform-jobs/{job_name}"
@@ -177,14 +177,16 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         return self.query_execution_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(
+                f"Error while waiting for operation on cluster to complete: {validated_event}"
+            )

         # Save query_execution_id to be later used by listeners
-        self.query_execution_id = event["value"]
-        return event["value"]
+        self.query_execution_id = validated_event["value"]
+        return validated_event["value"]

     def on_kill(self) -> None:
         """Cancel the submitted Amazon Athena query."""
@@ -95,6 +95,7 @@ class BatchOperator(BaseOperator):
         If it is an array job, only the logs of the first task will be printed.
     :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec.
     :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling.
+    :param submit_job_timeout: Execution timeout in seconds for submitted batch job.

     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -184,6 +185,7 @@ class BatchOperator(BaseOperator):
         poll_interval: int = 30,
         awslogs_enabled: bool = False,
         awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+        submit_job_timeout: int | None = None,
         **kwargs,
     ) -> None:
         BaseOperator.__init__(self, **kwargs)
@@ -208,6 +210,7 @@ class BatchOperator(BaseOperator):
         self.poll_interval = poll_interval
         self.awslogs_enabled = awslogs_enabled
         self.awslogs_fetch_interval = awslogs_fetch_interval
+        self.submit_job_timeout = submit_job_timeout

         # params for hook
         self.max_retries = max_retries
@@ -264,13 +267,13 @@ class BatchOperator(BaseOperator):
         return self.job_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Job completed.")
-        return event["job_id"]
+        return validated_event["job_id"]

     def on_kill(self):
         response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
@@ -315,6 +318,9 @@ class BatchOperator(BaseOperator):
             "schedulingPriorityOverride": self.scheduling_priority_override,
         }

+        if self.submit_job_timeout:
+            args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout}
+
         try:
             response = self.hook.client.submit_job(**trim_none_values(args))
         except Exception as e:
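
Note: submit_job_timeout is forwarded to Batch SubmitJob as timeout={"attemptDurationSeconds": ...}. A hedged DAG-level usage sketch; the job name, queue, definition, and override are placeholders:

from airflow.providers.amazon.aws.operators.batch import BatchOperator

# Fail the Batch job attempt if it runs longer than one hour.
submit = BatchOperator(
    task_id="submit_batch_job",
    job_name="example-job",
    job_queue="example-queue",
    job_definition="example-job-definition",
    container_overrides={"command": ["echo", "hello"]},
    submit_job_timeout=3600,
)
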
@@ -534,8 +540,10 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         return arn

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
-        return event["value"]
+        if validated_event["status"] != "success":
+            raise AirflowException(
+                f"Error while waiting for the compute environment to be ready: {validated_event}"
+            )
+        return validated_event["value"]
@@ -198,13 +198,13 @@ class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]):
         self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Bedrock model customization job `%s` complete.", self.job_name)
-        return self.hook.conn.get_model_customization_job(jobIdentifier=event["job_name"])["jobArn"]
+        return self.hook.conn.get_model_customization_job(jobIdentifier=validated_event["job_name"])["jobArn"]

     def execute(self, context: Context) -> dict:
         response = {}
@@ -353,13 +353,15 @@ class BedrockCreateProvisionedModelThroughputOperator(AwsBaseOperator[BedrockHook]):
         return provisioned_model_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

-        self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"])
-        return event["provisioned_model_id"]
+        self.log.info(
+            "Bedrock provisioned throughput job `%s` complete.", validated_event["provisioned_model_id"]
+        )
+        return validated_event["provisioned_model_id"]


 class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
@@ -449,13 +451,13 @@ class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
         self.deferrable = deferrable

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

         self.log.info("Bedrock knowledge base creation job `%s` complete.", self.name)
-        return event["knowledge_base_id"]
+        return validated_event["knowledge_base_id"]

     def execute(self, context: Context) -> str:
         def _create_kb():
@@ -634,14 +636,14 @@ class BedrockIngestDataOperator(AwsBaseOperator[BedrockAgentHook]):
         self.deferrable = deferrable

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running ingestion job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running ingestion job: {validated_event}")

-        self.log.info("Bedrock ingestion job `%s` complete.", event["ingestion_job_id"])
+        self.log.info("Bedrock ingestion job `%s` complete.", validated_event["ingestion_job_id"])

-        return event["ingestion_job_id"]
+        return validated_event["ingestion_job_id"]

     def execute(self, context: Context) -> str:
         ingestion_job_id = self.hook.conn.start_ingestion_job(