apache-airflow-providers-amazon 8.19.0rc1__py3-none-any.whl → 8.20.0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +4 -2
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +22 -7
- airflow/providers/amazon/aws/auth_manager/{cli → avp}/schema.json +34 -2
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +91 -170
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +7 -32
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +1 -0
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/batch/__init__.py +16 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +420 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor_config.py +87 -0
- airflow/providers/amazon/aws/executors/batch/boto_schema.py +67 -0
- airflow/providers/amazon/aws/executors/batch/utils.py +160 -0
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +61 -18
- airflow/providers/amazon/aws/executors/ecs/utils.py +8 -13
- airflow/providers/amazon/aws/executors/utils/base_config_keys.py +25 -0
- airflow/providers/amazon/aws/hooks/athena.py +1 -0
- airflow/providers/amazon/aws/hooks/base_aws.py +1 -0
- airflow/providers/amazon/aws/hooks/batch_client.py +4 -3
- airflow/providers/amazon/aws/hooks/batch_waiters.py +1 -0
- airflow/providers/amazon/aws/hooks/bedrock.py +59 -0
- airflow/providers/amazon/aws/hooks/chime.py +1 -0
- airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/hooks/datasync.py +1 -0
- airflow/providers/amazon/aws/hooks/dynamodb.py +1 -0
- airflow/providers/amazon/aws/hooks/eks.py +1 -0
- airflow/providers/amazon/aws/hooks/glue.py +13 -5
- airflow/providers/amazon/aws/hooks/glue_catalog.py +1 -0
- airflow/providers/amazon/aws/hooks/kinesis.py +1 -0
- airflow/providers/amazon/aws/hooks/lambda_function.py +1 -0
- airflow/providers/amazon/aws/hooks/rds.py +1 -0
- airflow/providers/amazon/aws/hooks/s3.py +24 -30
- airflow/providers/amazon/aws/hooks/ses.py +1 -0
- airflow/providers/amazon/aws/hooks/sns.py +1 -0
- airflow/providers/amazon/aws/hooks/sqs.py +1 -0
- airflow/providers/amazon/aws/operators/athena.py +2 -2
- airflow/providers/amazon/aws/operators/base_aws.py +4 -1
- airflow/providers/amazon/aws/operators/batch.py +4 -2
- airflow/providers/amazon/aws/operators/bedrock.py +252 -0
- airflow/providers/amazon/aws/operators/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/operators/datasync.py +1 -0
- airflow/providers/amazon/aws/operators/ecs.py +9 -10
- airflow/providers/amazon/aws/operators/eks.py +1 -0
- airflow/providers/amazon/aws/operators/emr.py +57 -7
- airflow/providers/amazon/aws/operators/s3.py +1 -0
- airflow/providers/amazon/aws/operators/sns.py +1 -0
- airflow/providers/amazon/aws/operators/sqs.py +1 -0
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -0
- airflow/providers/amazon/aws/secrets/systems_manager.py +1 -0
- airflow/providers/amazon/aws/sensors/base_aws.py +4 -1
- airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
- airflow/providers/amazon/aws/sensors/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/sensors/eks.py +3 -4
- airflow/providers/amazon/aws/sensors/sqs.py +2 -1
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -2
- airflow/providers/amazon/aws/transfers/base.py +1 -0
- airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -0
- airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +21 -19
- airflow/providers/amazon/aws/triggers/bedrock.py +61 -0
- airflow/providers/amazon/aws/triggers/eks.py +1 -1
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +1 -0
- airflow/providers/amazon/aws/triggers/s3.py +4 -2
- airflow/providers/amazon/aws/triggers/sagemaker.py +6 -4
- airflow/providers/amazon/aws/utils/emailer.py +1 -0
- airflow/providers/amazon/aws/waiters/bedrock.json +42 -0
- airflow/providers/amazon/get_provider_info.py +86 -1
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/METADATA +10 -9
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/RECORD +77 -66
- /airflow/providers/amazon/aws/executors/{ecs/Dockerfile → Dockerfile} +0 -0
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py:

@@ -576,7 +576,7 @@ class EmrContainerOperator(BaseOperator):
                 stacklevel=2,
             )
             if max_polling_attempts and max_polling_attempts != max_tries:
-                raise
+                raise ValueError("max_polling_attempts must be the same value as max_tries")
             else:
                 self.max_polling_attempts = max_tries

@@ -1253,27 +1253,77 @@ class EmrServerlessStartJobOperator(BaseOperator):
         op_extra_links = []

         if isinstance(self, MappedOperator):
+            operator_class = self.operator_class
             enable_application_ui_links = self.partial_kwargs.get(
                 "enable_application_ui_links"
             ) or self.expand_input.value.get("enable_application_ui_links")
-            job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get(
+            job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get(
+                "job_driver", {}
+            )
             configuration_overrides = self.partial_kwargs.get(
                 "configuration_overrides"
             ) or self.expand_input.value.get("configuration_overrides")

+            # Configuration overrides can either be a list or a dictionary, depending on whether it's passed in as partial or expand.
+            if isinstance(configuration_overrides, list):
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="s3MonitoringConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="cloudWatchLoggingConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+            else:
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="s3MonitoringConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="cloudWatchLoggingConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         else:
+            operator_class = self
             enable_application_ui_links = self.enable_application_ui_links
             configuration_overrides = self.configuration_overrides
             job_driver = self.job_driver

+            if operator_class.is_monitoring_in_job_override(
+                "s3MonitoringConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessS3LogsLink()])
+            if operator_class.is_monitoring_in_job_override(
+                "cloudWatchLoggingConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         if enable_application_ui_links:
             op_extra_links.extend([EmrServerlessDashboardLink()])
-            if
+            if isinstance(job_driver, list):
+                if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver):
+                    op_extra_links.extend([EmrServerlessLogsLink()])
+            elif "sparkSubmit" in job_driver:
                 op_extra_links.extend([EmrServerlessLogsLink()])
-        if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessS3LogsLink()])
-        if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])

         return tuple(op_extra_links)
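To see why the mapped branch above must handle `configuration_overrides` and `job_driver` as either a dict or a list, here is a hypothetical dynamic-task-mapping sketch (not taken from this package; the application id, role ARN, and S3 URIs are placeholders, and the `application_id`/`execution_role_arn` arguments are assumed from the operator's usual signature):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

with DAG(dag_id="emr_serverless_mapped_example", start_date=datetime(2024, 1, 1), schedule=None):
    # Scalar arguments go through .partial(); the mapped argument goes through .expand(),
    # so operator_extra_links may see configuration_overrides as a list of per-task overrides.
    EmrServerlessStartJobOperator.partial(
        task_id="run_spark_job",
        application_id="my-emr-serverless-application-id",  # placeholder
        execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-role",  # placeholder
        job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/app.py"}},
        enable_application_ui_links=True,
    ).expand(
        configuration_overrides=[
            {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://my-bucket/logs/"}}},
            {"monitoringConfiguration": {"cloudWatchLoggingConfiguration": {"enabled": True}}},
        ],
    )
```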
airflow/providers/amazon/aws/sensors/base_aws.py:

@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.utils.mixins import (
     aws_template_fields,
 )
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.types import NOTSET, ArgNotSet


 class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):

@@ -84,10 +85,12 @@ class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
         region_name: str | None = None,
         verify: bool | str | None = None,
         botocore_config: dict | None = None,
+        region: str | None | ArgNotSet = NOTSET,  # Required for `.partial` signature check
         **kwargs,
     ):
+        additional_params = {} if region is NOTSET else {"region": region}
         hook_params = AwsHookParams.from_constructor(
-            aws_conn_id, region_name, verify, botocore_config, additional_params=
+            aws_conn_id, region_name, verify, botocore_config, additional_params=additional_params
         )
         super().__init__(**kwargs)
         self.aws_conn_id = hook_params.aws_conn_id
airflow/providers/amazon/aws/sensors/bedrock.py (new file):

@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+
+
+class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
+    """
+    Poll the state of the model customization job until it reaches a terminal state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor`
+
+
+    :param job_name: The name of the Bedrock model customization job.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES = ("InProgress",)
+    FAILURE_STATES = ("Failed", "Stopping", "Stopped")
+    SUCCESS_STATES = ("Completed",)
+    FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields("job_name")
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockCustomizeModelCompletedTrigger(
+                    job_name=self.job_name,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def poke(self, context: Context) -> bool:
+        state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
+        self.log.info("Job '%s' state: %s", self.job_name, state)
+
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
airflow/providers/amazon/aws/sensors/eks.py:

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tracking the state of Amazon EKS Clusters, Amazon EKS managed node groups, and AWS Fargate profiles."""
+
 from __future__ import annotations

 from abc import abstractmethod

@@ -114,12 +115,10 @@ class EksBaseSensor(BaseSensorOperator):
         return state == self.target_state

     @abstractmethod
-    def get_state(self) -> ClusterStates | NodegroupStates | FargateProfileStates:
-        ...
+    def get_state(self) -> ClusterStates | NodegroupStates | FargateProfileStates: ...

     @abstractmethod
-    def get_terminal_states(self) -> frozenset:
-        ...
+    def get_terminal_states(self) -> frozenset: ...


 class EksClusterStateSensor(EksBaseSensor):
airflow/providers/amazon/aws/sensors/sqs.py:

@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Reads and then deletes the message from SQS queue."""
+
 from __future__ import annotations

 from datetime import timedelta

@@ -72,7 +73,7 @@ class SqsSensor(AwsBaseSensor[SqsHook]):
     :param delete_message_on_reception: Default to `True`, the messages are deleted from the queue
         as soon as being consumed. Otherwise, the messages remain in the queue after consumption and
         should be deleted manually.
-    :param deferrable: If True, the sensor will operate in deferrable
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
        module to be installed.
        (default: False, but can be overridden in config file by setting default_deferrable to True)
    :param aws_conn_id: The Airflow connection used for AWS credentials.
airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py:

@@ -120,8 +120,10 @@ class AzureBlobStorageToS3Operator(BaseOperator):
         )

         self.log.info(
-
-
+            "Getting list of the files in Container: %r; Prefix: %r; Delimiter: %r.",
+            self.container_name,
+            self.prefix,
+            self.delimiter,
         )

         files = wasb_hook.get_blobs_list_recursive(
airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py:

@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module allows you to transfer mail attachments from a mail server into s3 bucket."""
+
 from __future__ import annotations

 from typing import TYPE_CHECKING, Sequence
airflow/providers/amazon/aws/transfers/redshift_to_s3.py:

@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Transfers data from AWS Redshift into a S3 Bucket."""
+
 from __future__ import annotations

 import re

@@ -109,35 +110,19 @@ class RedshiftToS3Operator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
         self.s3_bucket = s3_bucket
-        self.s3_key =
+        self.s3_key = s3_key
         self.schema = schema
         self.table = table
         self.redshift_conn_id = redshift_conn_id
         self.aws_conn_id = aws_conn_id
         self.verify = verify
-        self.unload_options
+        self.unload_options = unload_options or []
         self.autocommit = autocommit
         self.include_header = include_header
         self.parameters = parameters
         self.table_as_file_name = table_as_file_name
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
-
-        if select_query:
-            self.select_query = select_query
-        elif self.schema and self.table:
-            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
-        else:
-            raise ValueError(
-                "Please provide both `schema` and `table` params or `select_query` to fetch the data."
-            )
-
-        if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
-            self.unload_options = [*self.unload_options, "HEADER"]
-
-        if self.redshift_data_api_kwargs:
-            for arg in ["sql", "parameters"]:
-                if arg in self.redshift_data_api_kwargs:
-                    raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
+        self.select_query = select_query

     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str

@@ -153,9 +138,26 @@ class RedshiftToS3Operator(BaseOperator):
         """

     def execute(self, context: Context) -> None:
+        if self.table and self.table_as_file_name:
+            self.s3_key = f"{self.s3_key}/{self.table}_"
+
+        if self.schema and self.table:
+            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
+
+        if self.select_query is None:
+            raise ValueError(
+                "Please provide both `schema` and `table` params or `select_query` to fetch the data."
+            )
+
+        if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
+            self.unload_options = [*self.unload_options, "HEADER"]
+
         redshift_hook: RedshiftDataHook | RedshiftSQLHook
         if self.redshift_data_api_kwargs:
             redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+            for arg in ["sql", "parameters"]:
+                if arg in self.redshift_data_api_kwargs:
+                    raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
         else:
             redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
         conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
airflow/providers/amazon/aws/triggers/bedrock.py (new file):

@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock model customization job is complete.
+
+    :param job_name: The name of the Bedrock model customization job.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"job_name": job_name},
+            waiter_name="model_customization_job_complete",
+            waiter_args={"jobIdentifier": job_name},
+            failure_message="Bedrock model customization failed.",
+            status_message="Status of Bedrock model customization job is",
+            status_queries=["status"],
+            return_key="job_name",
+            return_value=job_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/triggers/eks.py:

@@ -214,7 +214,7 @@ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
                 )
             self.log.info("All Fargate profiles deleted")
         else:
-            self.log.info(
+            self.log.info("No Fargate profiles associated with cluster %s", self.cluster_name)


 class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger):
airflow/providers/amazon/aws/triggers/s3.py:

@@ -98,8 +98,9 @@ class S3KeyTrigger(BaseTrigger):
                     )
                     await asyncio.sleep(self.poke_interval)
                     yield TriggerEvent({"status": "running", "files": s3_objects})
-
-
+                else:
+                    yield TriggerEvent({"status": "success"})
+                    return

                 self.log.info("Sleeping for %s seconds", self.poke_interval)
                 await asyncio.sleep(self.poke_interval)

@@ -204,6 +205,7 @@ class S3KeysUnchangedTrigger(BaseTrigger):
                 )
                 if result.get("status") in ("success", "error"):
                     yield TriggerEvent(result)
+                    return
                 elif result.get("status") == "pending":
                     self.previous_objects = result.get("previous_objects", set())
                     self.last_activity_time = result.get("last_activity_time")
airflow/providers/amazon/aws/triggers/sagemaker.py:

@@ -245,8 +245,8 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
         job_already_completed = status not in self.hook.non_terminal_states
         state = LogState.COMPLETE if job_already_completed else LogState.TAILING
         last_describe_job_call = time.time()
-
-
+        try:
+            while True:
                 (
                     state,
                     last_description,

@@ -267,6 +267,7 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
                     reason = last_description.get("FailureReason", "(No reason provided)")
                     error_message = f"SageMaker job failed because {reason}"
                     yield TriggerEvent({"status": "error", "message": error_message})
+                    return
                 else:
                     billable_seconds = SageMakerHook.count_billable_seconds(
                         training_start_time=last_description["TrainingStartTime"],

@@ -275,5 +276,6 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
                     )
                     self.log.info("Billable seconds: %d", billable_seconds)
                     yield TriggerEvent({"status": "success", "message": last_description})
-
-
+                    return
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
airflow/providers/amazon/aws/waiters/bedrock.json (new file):

@@ -0,0 +1,42 @@
+{
+    "version": 2,
+    "waiters": {
+        "model_customization_job_complete": {
+            "delay": 120,
+            "maxAttempts": 75,
+            "operation": "GetModelCustomizationJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "InProgress",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopping",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
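Outside the provider's own waiter wiring, a custom waiter definition like the JSON above can be exercised directly with botocore's generic waiter machinery; the sketch below assumes the file is readable at the packaged path and uses a placeholder job name:

```python
import json

import boto3
from botocore.waiter import WaiterModel, create_waiter_with_client

# Load the waiter definition shown above (path is illustrative).
with open("airflow/providers/amazon/aws/waiters/bedrock.json") as f:
    waiter_model = WaiterModel(json.load(f))

client = boto3.client("bedrock")
waiter = create_waiter_with_client("model_customization_job_complete", waiter_model, client)

# Calls GetModelCustomizationJob until "status" leaves "InProgress", per the acceptors above.
waiter.wait(jobIdentifier="my-model-customization-job")
```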